返回介绍

VAE(1) - 从 KL 说起

发布于 2025-02-25 23:04:58 字数 7791 浏览 0 评论 0 收藏 0

前面我们介绍了 GAN - Generative Adversarial Network,这个网络组是站在对抗博弈的角度去展现生成模型和判别模型各自的威力的,下面我们来看看这种生成模型和判别模型组合的另一个套路 - Variational autoencoder,简称 VAE。

突然想起来,他也叫 VAE,我觉得他还是有点音乐才华的。不过我们今天不去讨论他。

Variational autoencoder 的概念相对复杂一些,它涉及到一些比较复杂的公式推导。在开始正式的推导之前,我们先来看看一个基础概念 - KL divergence,翻译过来叫做 KL 散度。

什么是 KL 散度

无论从概率论的角度,还是从信息论的角度,我们都可以很好地给出 KL 散度测量的意义。这里不是基础的概念介绍,所以有关 KL 的概念就不介绍了。在 Variational Inference 中,我们希望能够找到一个相对简单好算的概率分布 q,使它尽可能地近似我们待分析的后验概率 p(z|x),其中 z 是隐变量,x 是显变量。在这里我们的“loss 函数”就是 KL 散度,他可以很好地测量两个概率分布之间的距离。如果两个分布越接近,那么 KL 散度越小,如果越远,KL 散度就会越大。

KL 散度的公式为:

KL(p||q)=\sum{p(x)log\frac{p(x)}{q(x)}},这个是离散概率分布的公式,

KL(p||q)=\int{p(x)log{\frac{p(x)}{q(x)}}dx},这个是连续概率分布的公式。

关于其他 KL 散度的性质,这里就不赘述了。

KL 散度的实战 - 1 维高斯分布

我们先来一个相对简单的例子。假设我们有两个随机变量 x1,x2,各自服从一个高斯分布N_1(\mu_1,\sigma_1^2),N_2(\mu_2,\sigma_2^2),那么这两个分布的 KL 散度该怎么计算呢?

我们知道

N(\mu,\sigma)=\frac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{(x-\mu)^2}{2\sigma^2}}

那么 KL(p1,p2) 就等于

\int{p_1(x)log\frac{p_1(x)}{p_2(x)}}dx
=\int{p_1(x)(log{p_1(x)}}dx-log{p_2(x)})dx=\int{p_1(x)}*(log{\frac{1}{\sqrt{2\pi\sigma_1^2}}e^{\frac{(x-\mu_1)^2}{2\sigma_1^2}}}-log{\frac{1}{\sqrt{2\pi\sigma_2^2}}e^{\frac{(x-\mu_2)^2}{2\sigma_2^2}}})dx
=\int{p_1(x)
*(-\frac{1}{2}log2\pi-log{\sigma_1}-\frac{(x-\mu_1)^2}{2\sigma_1^2}}+
\frac{1}{2}log{2\pi}+log{\sigma_2}+\frac{(x-\mu_2)^2}{2\sigma_2^2})dx
=\int{p_1(x)(log\frac{\sigma_2}{\sigma_1}+[\frac{(x-\mu_2)^2}{2\sigma_2^2}-\frac{(x-\mu_1)^2}{2\sigma_1^2}])dx}
=\int(log\frac{\sigma_2}{\sigma_1})p_1(x)dx+\int(\frac{(x-\mu_2)^2}{2\sigma_2^2})p_1(x)dx-\int(\frac{(x-\mu_1)^2}{2\sigma_1^2})p_1(x)dx
=log\frac{\sigma_2}{\sigma_1}+\frac{1}{2\sigma_2^2}\int((x-\mu_2)^2)p_1(x)dx-\frac{1}{2\sigma_1^2}\int((x-\mu_1)^2)p_1(x)dx

(更新)到这里停一下,有童鞋问这里右边最后一项的化简,这时候积分符号里面的东西是不看着很熟悉?没错,就是我们常见的方差嘛,于是括号内外一约分,就得到了最终的结果 - \frac{1}{2}

好,继续。
=log\frac{\sigma_2}{\sigma_1}+\frac{1}{2\sigma_2^2}\int((x-\mu_2)^2)p_1(x)dx-\frac{1}{2}
=log\frac{\sigma_2}{\sigma_1}+\frac{1}{2\sigma_2^2}\int((x-\mu_1+\mu_1-\mu_2)^2)p_1(x)dx-\frac{1}{2}
=log\frac{\sigma_2}{\sigma_1}+\frac{1}{2\sigma_2^2}[\int{(x-\mu_1)^2}p_1(x)dx+\int{(\mu_1-\mu_2)^2}p_1(x)dx+2\int{(x-\mu_1)(\mu_1-\mu_2)]}p_1(x)dx-\frac{1}{2}
=log\frac{\sigma_2}{\sigma_1}+\frac{1}{2\sigma_2^2}[\int{(x-\mu_1)^2}p_1(x)dx+(\mu_1-\mu_2)^2]-\frac{1}{2}

=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}

说实话一直以来我不是很喜欢写这种大段推导公式的文章,一来原创性比较差(都是前人推过的,我就是大自然的搬运工),二来其中的逻辑性太强,容易让人看蒙。不过最终的结论还是得出来了,我们假设 N2 是一个正态分布,也就是说\mu_2=0,\sigma_2^2=1那么 N1 长成什么样子能够让 KL 散度尽可能地小呢?

也就是说KL(\mu_1,\sigma_1)=-log\sigma_1+\frac{\sigma_1^2+\mu_1^2}{2}-\frac{1}{2}

我们用“肉眼”看一下就能猜测到当\mu_1=0,\sigma_1=1时,KL 散度最小。从公式中可以看出,如果\mu_1偏离了 0,那么 KL 散度一定会变大。而方差的变化则有些不同:

\sigma_1大于 1 时,\frac{1}{2}\sigma_1^2将越变越大,而-log\sigma_1越变越小;

\sigma_1小于 1 时,\frac{1}{2}\sigma_1^2将越变越小,而-log\sigma_1越变越大;

那么哪边的力量更强大呢?我们可以作图出来:

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0.5,2,100)
y = -np.log(x)+x*x/2-0.5
plt.plot(x,y)
plt.show()

从图中可以看出

二次项的威力更大,函数一直保持为非负,这和我们前面提到的关于非负的定义是完全一致的。

好了,看完了这个简单的例子,下面让我们再看一个复杂的例子。

一个更为复杂的例子:多维高斯分布的 KL 散度

上一回我们看过了 1 维高斯分布间的 KL 散度计算,下面我们来看看多维高斯分布的 KL 散度是什么样子?说实话,这一次的公式将在后面介绍 VAE 时发挥很重要的作用!

首先给出多维高斯分布的公式:

p(x_1,x_2,...x_n)=\frac{1}{\sqrt{2\pi*det(\Sigma)}}e^{(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu))}

由于这次是多维变量,里面的大多数计算都变成了向量、矩阵之间的计算。我们常用的是各维间相互独立的分布,因此协方差矩阵实际上是个对角阵。

考虑到篇幅以及实际情况,下面直接给出结果,让我们忽略哪些恶心的推导过程:

KL(p1||p2)=\frac{1}{2}[log \frac{det(\Sigma_2)}{det(\Sigma_1)} - d + tr(\Sigma_2^{-1}\Sigma_1)+(\mu_2-\mu_1)^T \Sigma_2^{-1}(\mu_2-\mu_1)]

其实这一次我们并没有介绍关于 KL 的意义和作用,只是生硬地、莫名其妙地推导一堆公式,不过别着急,下一回,我们展示 VAE 效果的时候,就会让大家看到 KL 散度的作用。

坚持看到这里的童鞋是有福的,来展示一下 VAE 的解码器在 MNIST 数据库上产生的字符生成效果:

从这个效果上来看,它的功能和 GAN 是有点像的,那么让我们来进一步揭开它的庐山真面目吧!

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文