返回介绍

数学基础

统计学习

深度学习

工具

Scala

二、近似推断

发布于 2023-07-17 23:38:25 字数 10068 浏览 0 评论 0 收藏 0

  1. 精确推断方法通常需要很大的计算开销,因此在现实应用中近似推断方法更为常用。

    近似推断方法可以分作两类:

    • 采样sampling。通过使用随机化方法完成近似。
    • 使用确定性近似完成近似推断,典型代表为变分推断variantional inference

2.1 MCMC 采样

  1. MCMC 采样是一种常见的采样方法,可以用于概率图模型的近似推断。其原理部分参考数学基础部分的蒙特卡洛方法与 MCMC 采样

2.2 变分推断

  1. 变分推断通过使用已知简单分布来逼近需要推断的复杂分布,并通过限制近似分布的类型,从而得到一种局部最优、但具有确定解的近似后验分布。

  2. 给定多维随机变量 $ MathJax-Element-57 $ ,其中每个分量都依赖于随机变量 $ MathJax-Element-120 $ 。假定 $ MathJax-Element-79 $ 是观测变量, $ MathJax-Element-120 $ 是隐含变量。

    推断任务是:由观察到的随机变量 $ MathJax-Element-79 $ 来估计隐变量 $ MathJax-Element-120 $ 和分布参数变量 $ MathJax-Element-71 $ , 即求解 $ MathJax-Element-69 $ 和 $ MathJax-Element-71 $ 。

    $ MathJax-Element-71 $ 的估计可以使用EM算法:(设数据集 $ MathJax-Element-163 $ )

    • E步:根据 $ MathJax-Element-72 $ 时刻的参数 $ MathJax-Element-186 $ ,计算 $ MathJax-Element-74 $ 函数:

      $ Q(\Theta;\Theta^{}) =\sum_{i=1}^N\sum_{\mathbf {\vec z}} \ln p(\mathbf {\vec x}=\mathbf {\vec x}_i,\mathbf {\vec z};\Theta) p(\mathbf {\vec z}\mid \mathbf {\vec x}=\mathbf {\vec x}_i;\Theta^{}) $
    • M 步:基于 E 步的结果进行最大化寻优: $ MathJax-Element-731 $ 。

  3. 根据 EM 算法的原理知道, $ MathJax-Element-210 $ 是隐变量 $ MathJax-Element-120 $ 的一个近似后验分布。

    事实上我们可以人工构造一个概率分布 $ MathJax-Element-566 $ 来近似后验分布 $ MathJax-Element-102 $ ,其中 $ MathJax-Element-572 $ 为参数。

    如: $ MathJax-Element-808 $ ,其中 $ MathJax-Element-814 $ 为参数, $ MathJax-Element-815 $ 表示正态分布。

    • 这样构造的 $ MathJax-Element-566 $ 与 $ MathJax-Element-210 $ 的作用相同,它们都是对 $ MathJax-Element-102 $ 的一个近似。

      但是选择构造 $ MathJax-Element-566 $ 的优势是:可以选择一些性质较好的分布。

    • 根据后验概率的定义,对于每个 $ MathJax-Element-831 $ 都需要构造对应的 $ MathJax-Element-840 $ 。

  4. 根据 $ MathJax-Element-570 $ ,两边同时取对数有:

    $ \log p(\mathbf{\vec x}) = \log \frac{p(\mathbf{\vec x},\mathbf{\vec z})}{q(\mathbf{\vec z};\lambda)} - \log \frac{ p(\mathbf{\vec z}\mid \mathbf{\vec x})}{q(\mathbf{\vec z};\lambda)} $

    同时对两边对分布 $ MathJax-Element-566 $ 求期望,由于 $ MathJax-Element-248 $ 与 $ MathJax-Element-416 $ 无关,因此有:

    $ \log p(\mathbf{\vec x}) = \mathbb E_q\left[ \log \frac{p(\mathbf{\vec x},\mathbf{\vec z})}{q(\mathbf{\vec z};\lambda)}\right] - \mathbb E_q\left[ \log \frac{ p(\mathbf{\vec z}\mid \mathbf{\vec x})}{q(\mathbf{\vec z};\lambda)}\right]\\ = \mathbb E_q\left[ \log \frac{p(\mathbf{\vec x},\mathbf{\vec z})}{q(\mathbf{\vec z};\lambda)}\right]+KL(q(\mathbf{\vec z};\lambda)||p(\mathbf{\vec z}\mid \mathbf{\vec x})) $

    其中 $ MathJax-Element-551 $ 为KL 散度(Kullback-Leibler divergence),其定义为:

    $ KL(p || q)=\int_{-\infty}^{+\infty} p(x)\log\frac {p(x)}{q(x)}dx $

    我们的目标是使得 $ MathJax-Element-596 $ 尽可能靠近 $ MathJax-Element-621 $ ,即: $ MathJax-Element-634 $ 。

    考虑到 $ MathJax-Element-248 $ 与 $ MathJax-Element-416 $ 无关,因此上述目标等价于:

    $ \max_\lambda \mathbb E_q\left[ \log \frac{p(\mathbf{\vec x},\mathbf{\vec z})}{q(\mathbf{\vec z};\lambda)}\right] $

    称 $ MathJax-Element-675 $ 为 ELBO:Evidence Lower Bound

    $ MathJax-Element-679 $ 是观测变量的概率,一般被称作 evidence 。因为 $ MathJax-Element-686 $ ,所以有:

    $ MathJax-Element-693 $ 。因此它被称作 Evidence Lower Bound

  5. 考虑 ELBO

    $ ELBO = \mathbb E_q [\log p(\mathbf{\vec x},\mathbf{\vec z}) - \log q(\mathbf{\vec z};\lambda)]=\mathbb E_q \log p(\mathbf{\vec x},\mathbf{\vec z}) - H(q) $
    • 第一项称作能量函数。为了使得ELBO 最大,则它倾向于在 $ MathJax-Element-865 $ 较大的地方 $ MathJax-Element-596 $ 也较大。
    • 第二项为 $ MathJax-Element-596 $ 分布的熵。为了使得ELBO 最大,则它倾向于 $ MathJax-Element-596 $ 为均匀分布。
  6. 假设 $ MathJax-Element-120 $ 可以拆解为一系列相互独立的子变量 $ MathJax-Element-877 $ ,则有: $ MathJax-Element-979 $ 。这被称作平均场mean field approximation

    此时 ELBO 为:

    $ ELBO =\int_{\mathbf {\vec z}_1}\int_{\mathbf {\vec z}_2}\cdots\int_{\mathbf {\vec z}_K}\prod_{k=1}^{K}q_k(\mathbf {\vec z}_k;\lambda_k) \log p(\mathbf {\vec x},\mathbf {\vec z}_1,\mathbf {\vec z}_2,\cdots,\mathbf {\vec z}_K) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K\\ -\int_{\mathbf {\vec z}_1}\int_{\mathbf {\vec z}_2}\cdots\int_{\mathbf {\vec z}_K}\prod_{k=1}^{K}q_k(\mathbf {\vec z}_k;\lambda_k) \sum_{k=1}^{K}\log q_k(\mathbf {\vec z}_k;\lambda_k) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K $

    定义 $ MathJax-Element-1032 $ ,它就是 $ MathJax-Element-120 $ 中去掉 $ MathJax-Element-877 $ 的剩余部分。定义 $ MathJax-Element-1037 $ 为 $ MathJax-Element-572 $ 中去掉 $ MathJax-Element-1040 $ 的剩余部分。

    • 考虑第一项:

      考虑到括号内的内容为:

      $ \int \prod_{k=1,k\ne j}^{K}q_k(\mathbf {\vec z}_k;\lambda_k)\log p(\mathbf {\vec x},\mathbf {\vec z}_1,\mathbf {\vec z}_2,\cdots,\mathbf {\vec z}_K) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d\mathbf {\vec z}_{j-1} d \mathbf {\vec z}_{j+1} \cdots d \mathbf {\vec z}_K \\ =\int q(\bar{\mathbf{\vec z}}_j;\bar\lambda_j) \log p(\mathbf{\vec x},\mathbf{\vec z}) d \bar{\mathbf{\vec z}}_j = \mathbb E_{ q(\bar{\mathbf{\vec z}}_j;\bar\lambda_j)} [\log p(\mathbf{\vec x},\mathbf{\vec z}) ] $

      因此第一项为: $ MathJax-Element-1221 $ 。

    • 考虑第二项:

      $ \int_{\mathbf {\vec z}_1}\int_{\mathbf {\vec z}_2}\cdots\int_{\mathbf {\vec z}_K}\prod_{k=1}^{K}q_k(\mathbf {\vec z}_k;\lambda_k) \sum_{k=1}^{K}\log q_k(\mathbf {\vec z}_k;\lambda_k) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K \\ =\int_{\mathbf {\vec z}_1}q_1(\mathbf {\vec z}_1;\lambda_1)\int_{\mathbf {\vec z}_2}q_2(\mathbf {\vec z}_2;\lambda_2)\cdots\int_{\mathbf {\vec z}_K}q_K(\mathbf {\vec z}_K;\lambda_K)(\log q_1(\mathbf {\vec z}_1;\lambda_1)+\log q_2(\mathbf {\vec z}_2;\lambda_2)+\cdots\\ +\log q_K(\mathbf {\vec z}_K;\lambda_K)) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K\\ =\int_{\mathbf {\vec z}_1}q_1(\mathbf {\vec z}_1;\lambda_1)\int_{\mathbf {\vec z}_2}q_2(\mathbf {\vec z}_2;\lambda_2)\cdots\int_{\mathbf {\vec z}_K}q_K(\mathbf {\vec z}_K;\lambda_K) \log q_1(\mathbf {\vec z}_1;\lambda_1)d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K\\ + \int_{\mathbf {\vec z}_1}q_1(\mathbf {\vec z}_1;\lambda_1)\int_{\mathbf {\vec z}_2}q_2(\mathbf {\vec z}_2;\lambda_2)\cdots\int_{\mathbf {\vec z}_K}q_K(\mathbf {\vec z}_K;\lambda_K) \log q_2(\mathbf {\vec z}_2;\lambda_2)d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K\\ +\cdots+ \int_{\mathbf {\vec z}_1}q_1(\mathbf {\vec z}_1;\lambda_1)\int_{\mathbf {\vec z}_2}q_2(\mathbf {\vec z}_2;\lambda_2)\cdots\int_{\mathbf {\vec z}_K}q_K(\mathbf {\vec z}_K;\lambda_K)\log q_K(\mathbf {\vec z}_K;\lambda_K) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K $

      由于 $ MathJax-Element-1261 $ 构成了一个分布函数,因此 :

      $ \int \prod_{k=1,k\ne j}^{K} q_k(\mathbf {\vec z}_k;\lambda_k)d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d\mathbf {\vec z}_{j-1} d \mathbf {\vec z}_{j+1} \cdots d \mathbf {\vec z}_K=1 $

      则有:

      $ \int_{\mathbf {\vec z}_1}\int_{\mathbf {\vec z}_2}\cdots\int_{\mathbf {\vec z}_K}\prod_{k=1}^{K}q_k(\mathbf {\vec z}_k;\lambda_k) \sum_{k=1}^{K}\log q_k(\mathbf {\vec z}_k;\lambda_k) d \mathbf {\vec z}_1 d \mathbf {\vec z}_2 \cdots d \mathbf {\vec z}_K \\ =\int_{\mathbf {\vec z}_1}q_1(\mathbf {\vec z}_1;\lambda_1)\log q_1(\mathbf {\vec z}_1;\lambda_1)d \mathbf {\vec z}_1+ \int_{\mathbf {\vec z}_2}q_2(\mathbf {\vec z}_2;\lambda_2)\log q_2(\mathbf {\vec z}_2;\lambda_2)d \mathbf {\vec z}_2+\cdots\\ +\int_{\mathbf {\vec z}_K}q_K(\mathbf {\vec z}_K;\lambda_K)\log q_K(\mathbf {\vec z}_K;\lambda_K)d \mathbf {\vec z}_K \\=\sum_{k=1}^K\int_{\mathbf {\vec z}_k}q_k(\mathbf {\vec z}_k;\lambda_k)\log q_k(\mathbf {\vec z}_k;\lambda_k)d \mathbf {\vec z}_k $

    即:

    $ ELBO = \int_{\mathbf {\vec z}_j}q_j(\mathbf {\vec z}_j;\lambda_j) \mathbb E_{ q(\bar{\mathbf{\vec z}}_j;\bar\lambda_j)} [\log p(\mathbf{\vec x},\mathbf{\vec z}) ] d \mathbf {\vec z}_j - \sum_{k=1}^K\int_{\mathbf {\vec z}_k}q_k(\mathbf {\vec z}_k;\lambda_k)\log q_k(\mathbf {\vec z}_k;\lambda_k)d \mathbf {\vec z}_k $
  7. 定义一个概率分布 $ MathJax-Element-1398 $ ,其中 $ MathJax-Element-1400 $ 是与 $ MathJax-Element-1402 $ 有关、与 $ MathJax-Element-1404 $ 无关的常数项。

    则有:

    $ ELBO = \int_{\mathbf {\vec z}_j}q_j(\mathbf {\vec z}_j;\lambda_j) [\log C + \log q_j^{*}(\mathbf{\vec z}_j,\lambda_j) ]d \mathbf {\vec z}_j - \sum_{k=1}^K\int_{\mathbf {\vec z}_k}q_k(\mathbf {\vec z}_k;\lambda_k)\log q_k(\mathbf {\vec z}_k;\lambda_k)d \mathbf {\vec z}_k\\ = \log C + \int_{\mathbf {\vec z}_j}q_j(\mathbf {\vec z}_j;\lambda_j) \log q_j^{*}(\mathbf{\vec z}_j,\lambda_j) - \int_{\mathbf {\vec z}_j}q_j(\mathbf {\vec z}_j;\lambda_j)\log q_j(\mathbf {\vec z}_j;\lambda_j)d \mathbf {\vec z}_j\\ -\sum_{k=1,k\ne j}^K\int_{\mathbf {\vec z}_k}q_k(\mathbf {\vec z}_k;\lambda_k)\log q_k(\mathbf {\vec z}_k;\lambda_k)d \mathbf {\vec z}_k $

    其中 $ MathJax-Element-1727 $ ,因此有:

    $ ELBO = \log C - KL(q_j(\mathbf {\vec z}_j;\lambda_j)||q_j^{*}(\mathbf {\vec z}_j;\lambda_j)) + H(q(\bar{\mathbf{\vec z}}_j,\bar \lambda_j)) $

    为求解 $ MathJax-Element-1575 $ ,则可以看到当 $ MathJax-Element-1583 $ 时, $ MathJax-Element-1584 $ 取最大值。 因此得到 $ MathJax-Element-1587 $ 的更新规则:

    $ \begin{aligned} &q_1(\mathbf {\vec z}_1;\lambda_1)=q_1^* (\mathbf {\vec z}_1;\lambda_1)\\ &q_2(\mathbf {\vec z}_2;\lambda_2)=q_2^* (\mathbf {\vec z}_2;\lambda_2)\\ &q_3(\mathbf {\vec z}_3;\lambda_3)=q_3^* (\mathbf {\vec z}_3;\lambda_3)\\ &... \end{aligned} $

    根据 $ MathJax-Element-1371 $ 可知:在对 $ MathJax-Element-1350 $ 进行更新时,融合了 $ MathJax-Element-116 $ 之外的其他 $ MathJax-Element-1746 $ 的信息。

  8. 在实际应用变分法时,最重要的是考虑如何对隐变量 $ MathJax-Element-120 $ 进行拆解,以及假设各种变量子集服从何种分布。

    如果隐变量 $ MathJax-Element-120 $ 的拆解或者变量子集的分布假设不当,则会导致变分法效率低、效果差。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文