返回介绍

数学基础

统计学习

深度学习

工具

Scala

三、最大熵的学习

发布于 2023-07-17 23:38:25 字数 9870 浏览 0 评论 0 收藏 0

  1. 最大熵模型的学习就是在给定训练数据集 $ MathJax-Element-1005 $ 时,对模型进行极大似然估计或者正则化的极大似然估计。

  2. 最大熵模型与 logistic 回归模型有类似的形式,它们又称为对数线性模型。

    • 它们的目标函数具有很好的性质:光滑的凸函数。因此有多种最优化方法可用,且保证能得到全局最优解。
    • 最常用的方法有:改进的迭代尺度法、梯度下降法、牛顿法、拟牛顿法。

3.1 改进的迭代尺度法

  1. 改进的迭代尺度法Improved Iterative Scaling:IIS是一种最大熵模型学习的最优化算法。

  2. 已知最大熵模型为:

    $ P_\mathbf{\vec w}(y\mid \vec{\mathbf x})=\frac{1}{Z_\mathbf{\vec w}(\vec{\mathbf x})} \exp\left(\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right) $

    其中

    $ Z_\mathbf{\vec w}(\vec{\mathbf x})=\sum_y \exp\left(\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right) $

    对数似然函数为:

    $ L(\mathbf{\vec w})=\log \prod_{\vec{\mathbf x},y}P_\mathbf{\vec w}(y\mid \vec{\mathbf x})^{\tilde P(\vec{\mathbf x},y)}=\sum_{\vec{\mathbf x},y}[\tilde P(\vec{\mathbf x},y) \log P_\mathbf{\vec w}(y\mid \vec{\mathbf x})]\\ =\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right)-\sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x})\log Z_\mathbf{\vec w}(\vec{\mathbf x})\right) $

    最大熵模型的目标是:通过极大化似然函数学习模型参数,求出使得对数似然函数最大的参数 $ MathJax-Element-177 $ 。

  3. IIS 原理:假设最大熵模型当前的参数向量是 $ MathJax-Element-178 $ ,希望找到一个新的参数向量 $ MathJax-Element-179 $ ,使得模型的对数似然函数值增大。

    • 若能找到这样的新参数向量,则更新 $ MathJax-Element-180 $ 。
    • 重复这一过程,直到找到对数似然函数的最大值。
  4. 对于给定的经验分布 $ MathJax-Element-181 $ ,模型参数从 $ MathJax-Element-182 $ 到 $ MathJax-Element-183 $ 之间,对数似然函数的改变量为:

    $ L(\mathbf{\vec w}+ \vec\delta)-L(\mathbf{\vec w})=\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)-\sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x})\log \frac{Z_{\mathbf{\vec w}+\vec\delta}(\vec{\mathbf x})}{Z_\mathbf{\vec w}(\vec{\mathbf x})}\right) $
    • 利用不等式:当 $ MathJax-Element-2607 $ 时 $ MathJax-Element-2604 $ 有:
    $ L(\mathbf{\vec w}+\vec\delta)-L(\mathbf{\vec w}) \ge \sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+\sum_{\vec{\mathbf x}}\left[\tilde P(\vec{\mathbf x})\left(1-\frac{Z_{\mathbf{\vec w}+\vec\delta}(\vec{\mathbf x})}{Z_\mathbf{\vec w}(\vec{\mathbf x})}\right)\right]\\ =\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+\sum_{\vec{\mathbf x}}\tilde P(\vec{\mathbf x})-\sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x}) \frac{Z_{\mathbf{\vec w}+\vec\delta}(\vec{\mathbf x})}{Z_\mathbf{\vec w}(\vec{\mathbf x})} \right) $
    • 考虑到 $ MathJax-Element-185 $ ,以及:

      $ \frac{Z_{\mathbf{\vec w}+\vec\delta}(\vec{\mathbf x})}{Z_\mathbf{\vec w}(\vec{\mathbf x})} =\frac{\sum_y \exp\left(\sum_{i=1}^{n}(w_i+\delta_i) f_i(\vec{\mathbf x},y)\right)}{Z_\mathbf{\vec w}(\vec{\mathbf x})}\\ =\frac {1}{Z_\mathbf{\vec w}(\vec{\mathbf x})}\sum_y \left[\exp\left(\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right)\cdot \exp\left(\sum_{i=1}^{n} \delta_i f_i(\vec{\mathbf x},y)\right)\right]\\ =\sum_y \left[\frac {1}{Z_\mathbf{\vec w}(\vec{\mathbf x})} \cdot \exp\left(\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right)\cdot \exp\left(\sum_{i=1}^{n} \delta_i f_i(\vec{\mathbf x},y)\right)\right] $

      根据 $ MathJax-Element-186 $ 有:

      $ \frac{Z_{\mathbf{\vec w}+\vec\delta}(\vec{\mathbf x})}{Z_\mathbf{\vec w}(\vec{\mathbf x})} =\sum_y \left[P_\mathbf{\vec w}(y\mid \vec{\mathbf x})\cdot \exp\left(\sum_{i=1}^{n} \delta_i f_i(\vec{\mathbf x},y)\right)\right] $

      则有:

      $ L(\mathbf{\vec w}+\vec\delta)-L(\mathbf{\vec w}) \ge \sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+1\\ -\sum_\vec{\mathbf x} \left[\tilde P(\vec{\mathbf x}) \sum_y\left(P_\mathbf{\vec w}(y\mid \vec{\mathbf x})\exp\sum_{i=1}^{n}\delta_if_i(\vec{\mathbf x},y)\right)\right] $
    • 则 $ MathJax-Element-187 $ 。

  5. 如果能找到合适的 $ MathJax-Element-201 $ 使得 $ MathJax-Element-196 $ 提高,则对数似然函数也会提高。但是 $ MathJax-Element-201 $ 是个向量,不容易同时优化。

    • 一个解决方案是:每次只优化一个变量 $ MathJax-Element-209 $ 。
    • 为达到这个目的,引入一个变量 $ MathJax-Element-210 $ 。
  6. $ MathJax-Element-196 $ 改写为:

    $ A(\vec\delta\mid\mathbf{\vec w})=\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+1\\ -\sum_\vec{\mathbf x} \left[\tilde P(\vec{\mathbf x}) \sum_y \left(P_\mathbf{\vec w}(y\mid \vec{\mathbf x})\exp \left(f^{o}(\vec{\mathbf x},y)\sum_{i=1}^{n}\frac{\delta_if_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}\right)\right)\right] $
    • 利用指数函数的凸性,根据

      $ \frac{f_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)} \ge 0,\quad \sum_{i=1}^{n}\frac{f_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}=1 $

      以及Jensen 不等式有:

      $ \exp\left(f^{o}(\vec{\mathbf x},y)\sum_{i=1}^{n}\frac{\delta_if_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}\right) \le \sum_{i=1}^{n}\left(\frac{f_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}\exp(\delta_i f^{o}(\vec{\mathbf x},y))\right) $

      于是:

      $ A(\vec\delta\mid\mathbf{\vec w}) \ge \sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+1\\ -\sum_\vec{\mathbf x} \left[\tilde P(\vec{\mathbf x}) \sum_y \left(P_\mathbf{\vec w}(y\mid \vec{\mathbf x})\sum_{i=1}^{n}\left(\frac{f_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}\exp(\delta_i f^{o}(\vec{\mathbf x},y))\right)\right)\right] $
    • $ B(\vec\delta\mid\mathbf{\vec w})=\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}\delta_i f_i(\vec{\mathbf x},y)\right)+1\\ -\sum_\vec{\mathbf x} \left[\tilde P(\vec{\mathbf x}) \sum_y \left(P_\mathbf{\vec w}(y\mid \vec{\mathbf x})\sum_{i=1}^{n}\left(\frac{f_i(\vec{\mathbf x},y)}{f^{o}(\vec{\mathbf x},y)}\exp(\delta_i f^{o}(\vec{\mathbf x},y))\right)\right)\right] $

      则: $ MathJax-Element-2642 $ 。这里 $ MathJax-Element-198 $ 是对数似然函数改变量的一个新的(相对不那么紧)的下界。

  7. 求 $ MathJax-Element-198 $ 对 $ MathJax-Element-199 $ 的偏导数:

    $ \frac{\partial B(\vec\delta\mid\mathbf{\vec w})}{\partial \delta_i}=\sum_{\vec{\mathbf x},y}[\tilde P(\vec{\mathbf x},y)f_i(\vec{\mathbf x},y)]-\sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x})\sum_y[P_{\mathbf{\vec w}}(y\mid \vec{\mathbf x})f_i(\vec{\mathbf x},y)\exp(\delta_if^{o}(\vec{\mathbf x},y))]\right) =0 $

    令偏导数为 0 即可得到 $ MathJax-Element-209 $ :

    $ \sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x})\sum_y[P_{\mathbf{\vec w}}(y\mid \vec{\mathbf x})f_i(\vec{\mathbf x},y)\exp(\delta_if^{o}(\vec{\mathbf x},y))]\right)=\mathbb E_{\tilde P}[f_i] $

    最终根据 $ MathJax-Element-2667 $ 可以得到 $ MathJax-Element-201 $ 。

  8. IIS 算法:

    • 输入:

      • 特征函数 $ MathJax-Element-215 $
      • 经验分布 $ MathJax-Element-216 $
      • 模型 $ MathJax-Element-204 $
    • 输出:

      • 最优参数 $ MathJax-Element-205 $
      • 最优模型 $ MathJax-Element-206 $
    • 算法步骤:

      • 初始化:取 $ MathJax-Element-207 $ 。

      • 迭代,迭代停止条件为:所有 $ MathJax-Element-212 $ 均收敛。迭代步骤为:

        • 求解 $ MathJax-Element-2712 $ ,求解方法为:对每一个 $ MathJax-Element-208 $ :

          • 求解 $ MathJax-Element-209 $ 。其中 $ MathJax-Element-209 $ 是方程: $ MathJax-Element-2686 $ 的解,其中: $ MathJax-Element-210 $ 。
          • 更新 $ MathJax-Element-211 $ 。
        • 判定迭代停止条件。若不满足停止条件,则继续迭代。

3.2 拟牛顿法

  1. 若对数似然函数 $ MathJax-Element-2747 $ 最大,则 $ MathJax-Element-213 $ 最小。

    令 $ MathJax-Element-214 $ ,则最优化目标修改为:

    $ \min_{\mathbf{\vec w} \in \mathbb R^{n}}F(\mathbf{\vec w})= \min_{\mathbf{\vec w} \in \mathbb R^{n}} \sum_{\vec{\mathbf x}}\left(\tilde P(\vec{\mathbf x})\log\sum_y \exp\left(\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right)\right) -\sum_{\vec{\mathbf x},y}\left(\tilde P(\vec{\mathbf x},y)\sum_{i=1}^{n}w_i f_i(\vec{\mathbf x},y)\right) $

    计算梯度:

    $ \vec g(\mathbf{\vec w})=\left(\frac{\partial F(\mathbf{\vec w})}{\partial w_1},\frac{\partial F(\mathbf{\vec w})}{\partial w_2},\cdots,\frac{\partial F(\mathbf{\vec w})}{\partial w_n}\right)^{T},\\ \frac{\partial F(\mathbf{\vec w})}{\partial w_i}=\sum_{\vec{\mathbf x}}[\tilde P(\vec{\mathbf x})P_{\mathbf{\vec w}}(y\mid \vec{\mathbf x})f_i(\vec{\mathbf x},y)]- \mathbb E_{\tilde P}[f_i],\quad i=1,2,\cdots,n $
  2. 最大熵模型学习的 BFGS算法:

    • 输入:

      • 特征函数 $ MathJax-Element-215 $
      • 经验分布 $ MathJax-Element-216 $
      • 目标函数 $ MathJax-Element-217 $
      • 梯度 $ MathJax-Element-218 $
      • 精度要求 $ MathJax-Element-219 $
    • 输出:

      • 最优参数值 $ MathJax-Element-220 $
      • 最优模型 $ MathJax-Element-221 $
    • 算法步骤:

      • 选定初始点 $ MathJax-Element-222 $ ,取 $ MathJax-Element-223 $ 为正定对阵矩阵,迭代计数器 $ MathJax-Element-224 $ 。

      • 计算 $ MathJax-Element-225 $ :

        • 若 $ MathJax-Element-226 $ ,停止计算,得到 $ MathJax-Element-227 $

        • 若 $ MathJax-Element-228 $ :

          • 由 $ MathJax-Element-229 $ 求得 $ MathJax-Element-230 $

          • 一维搜索:求出 $ MathJax-Element-231 $ : $ MathJax-Element-2789 $

          • 置 $ MathJax-Element-233 $

          • 计算 $ MathJax-Element-234 $ 。 若 $ MathJax-Element-235 $ ,停止计算,得到 $ MathJax-Element-236 $ 。

          • 否则计算 $ MathJax-Element-237 $ :

            $ \mathbf B_{k+1}=\mathbf B_k+\frac{\mathbf{\vec y}_k \mathbf{\vec y}_k^{T}}{\mathbf{\vec y}_k^{T} \vec\delta_k}-\frac{\mathbf B_k \vec\delta_k \vec\delta_k^{T}\mathbf B_k}{\vec\delta_k^{T}\mathbf B_k \vec\delta_k} $

            其中: $ MathJax-Element-238 $ 。

          • 置 $ MathJax-Element-239 $ ,继续迭代。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文