返回介绍

数学基础

统计学习

深度学习

工具

Scala

八、t-SNE

发布于 2023-07-17 23:38:26 字数 15363 浏览 0 评论 0 收藏 0

  1. t-SNE:t-distributed stochastic neighbor embedding 是一种非线性降维算法,它是由SNE 发展而来。

    tsne

8.1 SNE

  1. SNE 的基本思想:如果两个样本在高维相似,则它们在低维也相似。

  2. SNE 主要包含两步:

    • 构建样本在高维的概率分布。
    • 在低维空间里重构这些样本的概率分布,使得这两个概率分布之间尽可能相似。
  3. 在数据集 $ MathJax-Element-778 $ 中,给定一个样本 $ MathJax-Element-917 $ ,然后计算 $ MathJax-Element-780 $ 是 $ MathJax-Element-917 $ 的邻居的概率。

    SNE 假设:如果 $ MathJax-Element-856 $ 与 $ MathJax-Element-917 $ 越相似,则 $ MathJax-Element-856 $ 是 $ MathJax-Element-917 $ 的邻居的概率越大。

    • 相似度通常采用欧几里得距离来衡量,两个样本距离越近则它们越相似。

    • 概率 $ MathJax-Element-786 $ 通常采用指数的形式:

      $ p( \mathbf{\vec x}_j\mid \mathbf{\vec x}_i) \propto \exp\left(-||\mathbf{\vec x}_j-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right) $

      对 $ MathJax-Element-787 $ 进行归一化有:

      $ p( \mathbf{\vec x}_j\mid \mathbf{\vec x}_i) = \frac{\exp\left(-||\mathbf{\vec x}_j-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)}{\sum_{k\ne i}\exp\left(-||\mathbf{\vec x}_k-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)} $

      其中 $ MathJax-Element-840 $ 是与 $ MathJax-Element-917 $ 相关的、待求得参数,它用于对距离进行归一化。

    • 定义 $ MathJax-Element-790 $ 。由于挑选 $ MathJax-Element-856 $ 时排除了 $ MathJax-Element-917 $ ,因此有 $ MathJax-Element-793 $ 。

    • 定义概率分布 $ MathJax-Element-794 $ ,它刻画了所有其它样本是 $ MathJax-Element-917 $ 的邻居的概率分布。

  4. 假设经过降维,样本 $ MathJax-Element-796 $ 在低维空间的表示为 $ MathJax-Element-797 $ ,其中 $ MathJax-Element-798 $ 。

    • 定义:

      $ q_{j\mid i} = q( \mathbf{\vec z}_j\mid \mathbf{\vec z}_i) = \frac{\exp\left(-||\mathbf{\vec z}_j-\mathbf{\vec z}_i||^2 \right)}{\sum_{k\ne i}\exp\left(-||\mathbf{\vec z}_k-\mathbf{\vec z}_i||^2\right)} $

      其中 $ MathJax-Element-869 $ 表示给定一个样本 $ MathJax-Element-937 $ ,然后计算 $ MathJax-Element-801 $ 是 $ MathJax-Element-932 $ 的邻居的概率。

      • 这里选择 $ MathJax-Element-803 $ 为固定值。
      • 同样地,有 $ MathJax-Element-804 $ 。
    • 定义概率分布 $ MathJax-Element-805 $ ,它刻画了所有其它样本是 $ MathJax-Element-937 $ 的邻居的概率分布。

  5. 对于样本 $ MathJax-Element-917 $ ,如果降维的效果比较好,则有 $ MathJax-Element-808 $ 。即:降维前后不改变 $ MathJax-Element-917 $ 周围的样本分布。

    对于 $ MathJax-Element-917 $ ,定义其损失函数为分布 $ MathJax-Element-836 $ 和 $ MathJax-Element-812 $ 的距离,通过 KL 散度来度量。

    对于全体数据集 $ MathJax-Element-841 $ ,整体损失函数为:

    $ \mathcal L = \sum_{i=1}^N KL(P_i||Q_i) =\sum_{i=1}^N\sum_{j=1}^N p_{j\mid i} \log \frac{p_{j\mid i}}{q_{j\mid i}} $
  6. KL 散度具有不对称性,因此不同的错误对应的代价是不同的。给定样本 $ MathJax-Element-917 $ :

    • 对于高维距离较远的点 $ MathJax-Element-856 $ ,假设 $ MathJax-Element-816 $ , 如果在低维映射成距离比较近的点,假设 $ MathJax-Element-817 $ 。则该点的代价为 $ MathJax-Element-818 $ 。
    • 对于高维距离较近的点 $ MathJax-Element-856 $ ,假设 $ MathJax-Element-820 $ , 如果在低维映射成距离比较远的点,假设 $ MathJax-Element-821 $ 。则该点的代价为 $ MathJax-Element-822 $ 。

    因此SNE 倾向于将高维空间较远的点映射成低位空间中距离较近的点。这意味着SNE 倾向于保留高维数据中的局部特征(因为远处的特征会被扭曲)。因此SNE 更关注局部结构而忽视了全局结构。

  7. 从 $ MathJax-Element-823 $ 可以看到: $ MathJax-Element-840 $ 是与 $ MathJax-Element-917 $ 相关的、用于对距离进行归一化的参数。

    • 若 $ MathJax-Element-840 $ 较大,则概率 $ MathJax-Element-831 $ 的样本 $ MathJax-Element-856 $ 更多,它们覆盖的范围更广,概率分布 $ MathJax-Element-836 $ 的熵越大。
    • 若 $ MathJax-Element-840 $ 较小,则概率 $ MathJax-Element-831 $ 的样本 $ MathJax-Element-856 $ 更少,它们覆盖的范围更窄,概率分布 $ MathJax-Element-836 $ 的熵越小。

    定义困惑度为: $ MathJax-Element-834 $ , 其中 $ MathJax-Element-835 $ 表示概率分布 $ MathJax-Element-836 $ 的熵。

    • 困惑度刻画了 $ MathJax-Element-917 $ 附近的有效近邻点个数。
    • 通常选择困惑度为 5~50 之间。它表示:对于给定的 $ MathJax-Element-917 $ ,只考虑它周围最近的 5~50 个样本的分布。
    • 给定困惑度之后,可以用二分搜索来寻找合适的 $ MathJax-Element-840 $ 。
  8. 当 $ MathJax-Element-840 $ 已经求得,可以根据数据集 $ MathJax-Element-841 $ 以及公式 $ MathJax-Element-842 $ 来求出 $ MathJax-Element-843 $ 。

    剔除 $ MathJax-Element-844 $ 中的已知量( $ MathJax-Element-845 $ ),则有: $ MathJax-Element-846 $ 。

    可以通过梯度下降法求解损失函数的极小值。

    • 记 $ MathJax-Element-847 $ ,则有 $ MathJax-Element-848 $ 。

      考虑到softmax 交叉熵的损失函数 $ MathJax-Element-849 $ 的梯度为 $ MathJax-Element-850 $ 。令分布 $ MathJax-Element-851 $ 为样本的真实标记 $ MathJax-Element-852 $ ,则有:

      $ \nabla_{y_{i,j}}\left(\sum_{j=1}^N p_{j\mid i} \log q_{j\mid i}\right) = p_{j\mid i} - q_{j\mid i}\\ \nabla_{\mathbf{\vec z}_i}\left(\sum_{j=1}^N p_{j\mid i} \log q_{j\mid i}\right) = \nabla_{y_{i,j}}\left(\sum_{j=1}^N - p_{j\mid i} \log q_{j\mid i}\right)\times \nabla_{\mathbf{\vec z}_i} y_{i,j} \\ = -2( p_{j\mid i} - q_{j\mid i})\times(\mathbf{\vec z}_i-\mathbf{\vec z}_j)\\ \nabla_{\mathbf{\vec z}_j}\left(\sum_{i=1}^N p_{j\mid i} \log q_{j\mid i}\right) = \nabla_{y_{i,j}}\left(\sum_{i=1}^N - p_{j\mid i} \log q_{j\mid i}\right)\times \nabla_{\mathbf{\vec z}_j} y_{i,j} \\ = -2( p_{j\mid i} - q_{j\mid i})\times(\mathbf{\vec z}_j-\mathbf{\vec z}_i) $

      考虑梯度 $ MathJax-Element-920 $ ,有两部分对它产生贡献:

      • 给定 $ MathJax-Element-917 $ 时,梯度的贡献为 : $ MathJax-Element-855 $ 。
      • 给定 $ MathJax-Element-856 $ 时,梯度的贡献为: $ MathJax-Element-857 $ 。

      因此有:

      $ \nabla_{\mathbf{\vec z}_i} \mathcal L=-\sum_j( -2( p_{j\mid i} - q_{j\mid i})\times(\mathbf{\vec z}_i-\mathbf{\vec z}_j)) + ( -2( p_{i\mid j} - q_{i\mid j})\times(\mathbf{\vec z}_i-\mathbf{\vec z}_j))\\ =\sum_j 2( p_{j\mid i} - q_{j\mid i}+ p_{i\mid j} - q_{i\mid j})(\mathbf{\vec z}_i-\mathbf{\vec z}_j) $

      该梯度可以用分子之间的引力和斥力进行解释:低维空间中的点 $ MathJax-Element-937 $ 的位置是由其它所有点对其作用力的合力决定。

      • 某个点 $ MathJax-Element-864 $ 对 $ MathJax-Element-937 $ 的作用力的方向:沿着 $ MathJax-Element-861 $ 的方向。
      • 某个点 $ MathJax-Element-864 $ 对 $ MathJax-Element-937 $ 的作用力的大小:取决于 $ MathJax-Element-864 $ 和 $ MathJax-Element-937 $ 之间的距离。
    • 为了避免陷入局部最优解,可以采用如下的方法:

      • 采用基于动量的随机梯度下降法:

        $ \mathbf{\vec v}\leftarrow \alpha\mathbf{\vec v}-\epsilon\nabla_{\mathbf{\vec z}_i} \mathcal L\\ \mathbf{\vec z}_i\leftarrow \mathbf{\vec z}_i+\mathbf{\vec v} $

        其中 $ MathJax-Element-866 $ 为学习率, $ MathJax-Element-867 $ 为权重衰减系数。

      • 每次迭代过程中引入一些高斯噪声,然后逐渐减小该噪声。

8.2 对称 SNE

  1. SNE 中使用的是条件概率分布 $ MathJax-Element-868 $ 和 $ MathJax-Element-869 $ ,它们分别表示在高维和低维下,给定第 $ MathJax-Element-874 $ 个样本的情况下,第 $ MathJax-Element-875 $ 个样本的分布。

    而对称 SNE 中使用联合概率分布 $ MathJax-Element-903 $ 和 $ MathJax-Element-902 $ ,它们分别表示在高维和低维下,第 $ MathJax-Element-874 $ 个样本和第 $ MathJax-Element-875 $ 个样本的联合分布。其中:

    $ p_{i,j} = \frac{\exp\left(-||\mathbf{\vec x}_i-\mathbf{\vec x}_j||^2/(2\sigma^2)\right)}{\sum_{k}\sum_{l,k\ne l}\exp\left(-||\mathbf{\vec x}_k-\mathbf{\vec x}_l||^2/(2\sigma^2)\right)}\\ q_{i,j} = \frac{\exp\left(-||\mathbf{\vec z}_i-\mathbf{\vec z}_j||^2 \right)}{\sum_{k}\sum_{l,k\ne l}\exp\left(-||\mathbf{\vec z}_k-\mathbf{\vec z}_l||^2\right)}\\ p_{i,i} = 0,\quad q_{i,i}=0 $

    根据定义可知 $ MathJax-Element-903 $ 和 $ MathJax-Element-902 $ 都满足对称性: $ MathJax-Element-878 $ 、 $ MathJax-Element-879 $ 。

  2. 上述定义的 $ MathJax-Element-903 $ 存在异常值问题。

    • 当 $ MathJax-Element-917 $ 是异常值时,对所有的 $ MathJax-Element-882 $ , 有 $ MathJax-Element-883 $ 都很大。这使得 $ MathJax-Element-884 $ 都几乎为 0 。

      这就使得 $ MathJax-Element-917 $ 的代价: $ MathJax-Element-886 $ 。即:无论 $ MathJax-Element-937 $ 周围的点的分布如何,它们对于代价函数的影响忽略不计。

      而在原始 SNE 中,可以保证 $ MathJax-Element-888 $ , $ MathJax-Element-937 $ 周围的点的分布会影响代价函数。

    • 为解决异常值问题,定义: $ MathJax-Element-890 $ 。这就使得 $ MathJax-Element-891 $ ,从而使得 $ MathJax-Element-937 $ 周围的点的分布对代价函数有一定的贡献。

      注意:这里并没有调整 $ MathJax-Element-902 $ 的定义。

  3. 对称SNE 的目标函数为: $ MathJax-Element-894 $ 。

    根据前面的推导有: $ MathJax-Element-895 $ 。其中: $ MathJax-Element-896 $ 。

  4. 实际上对称SNE 的效果只是略微优于原始SNE 的效果。

8.3 拥挤问题

  1. 拥挤问题Crowding Problem:指的是SNE 的可视化效果中,不同类别的簇挤在一起,无法区分开来。

    拥挤问题本质上是由于高维空间距离分布和低维空间距离分布的差异造成的。

  2. 考虑 $ MathJax-Element-897 $ 维空间中一个以原点为中心、半径为1 的超球体。在球体内部随机选取一个点,则该点距离原点的距离为 $ MathJax-Element-898 $ 的概率密度分布为:

    $ p(r) = \lim_{\Delta r\rightarrow 0 }\frac{(r+\Delta r)^n-r^n}{\Delta r} = nr^{n-1} $

    sne_crowding

    累计概率分布为: $ MathJax-Element-899 $ 。

    sne_crowding2

    可以看到:随着空间维度的增长,采样点在原点附近的概率越低、在球体表面附近的概率越大。

    如果直接将这种距离分布关系保留到低维,则就会出现拥挤问题。

8.4 t-SNE

  1. t-SNE 通过采用不同的分布来解决拥挤问题:

    • 在高维空间下使用高斯分布将距离转换为概率分布。
    • 在低维空间下使用 t 分布将距离转换为概率分布。这也是t-SNE 的名字的由来。
  2. t-SNE 使用自由度为1t 分布。此时有: $ MathJax-Element-900 $ 。

    则梯度为:

    $ \nabla_{\mathbf{\vec z}_i} \mathcal L= \sum_j 4(p_{i,j}-q_{i,j})(\mathbf{\vec z}_i-\mathbf{\vec z}_j)(1+||\mathbf{\vec z}_i-\mathbf{\vec z}_j||^2)^{-1} $

    也可以选择自由度超过 1t 分布。自由度越高,越接近高斯分布。

  3. t 分布相对于高斯分布更加偏重长尾。可以看到:

    • 对于高维空间相似度较大的点(如下图中的 q1 ),相比较于高斯分布, t 分布在低维空间中的距离要更近一点。
    • 对于高维空间相似度较小的点(如下图中的 q2 ),相比较于高斯分布,t 分布在低维空间中的距离要更远一点。

    即:同一个簇内的点(距离较近)聚合的更紧密,不同簇之间的点(距离较远)更加疏远。

    t_dist

  4. 优化过程中的技巧:

    • 提前压缩early compression:开始初始化的时候,各个点要离得近一点。这样小的距离,方便各个聚类中心的移动。可以通过引入L2正则项(距离的平方和)来实现。
    • 提前放大early exaggeration:在开始优化阶段, $ MathJax-Element-903 $ 乘以一个大于 1 的数进行扩大,来避免 $ MathJax-Element-902 $ 太小导致优化太慢的问题。比如前 50 轮迭代, $ MathJax-Element-903 $ 放大四倍。
  5. t-SNE 的主要缺点:

    • t-SNE 主要用于可视化,很难用于降维。有两个原因:

      • t-SNE 没有显式的预测部分,所以它无法对测试样本进行直接降维。

        一个解决方案是:构建一个回归模型来建立高维到低维的映射关系,然后通过该模型来对测试样本预测其低维坐标。

      • t-SNE 通常用于2维或者3维的可视化。如果数据集相互独立的特征数量如果较大,则映射到 2~3 维之后信息损失严重。

    • t-SNE 中的距离、概率本身没有意义,它们主要用于描述样本之间的概率分布。

    • t-SNE 代价函数是非凸的,可能得到局部最优解。

    • t-SNE 计算开销较大,训练速度慢。其计算复杂度为 $ MathJax-Element-904 $ 。经过优化之后可以达到 $ MathJax-Element-905 $ 。

8.5 t-SNE 改进

  1. 2014Mattern 在论文Accelerating t-SNE using Tree-Based Algorithms 中对t-SNE 进行了改进,主要包括两部分:

    • 使用kNN 图来表示高维空间中点的相似度。
    • 优化了梯度的求解过程。

8.5.1 kNN 图的相似度表示

  1. 注意到 $ MathJax-Element-906 $ 的表达式:

    $ p( \mathbf{\vec x}_j\mid \mathbf{\vec x}_i) = \frac{\exp\left(-||\mathbf{\vec x}_j-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)}{\sum_{k\ne i}\exp\left(-||\mathbf{\vec x}_k-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)} $
    • 每个数据点 $ MathJax-Element-917 $ 都需要计算 $ MathJax-Element-908 $ ,这一项需要计算所有其他样本点到 $ MathJax-Element-917 $ 的距离。当数据集较大时,这一项的计算量非常庞大。

    • 事实上,如果两个点相距较远,则它们互为邻居的概率非常小,因为 $ MathJax-Element-910 $ 。

      因此在构建高维空间的点的相似度关系时,只需要考虑 $ MathJax-Element-917 $ 最近的若干个邻居点即可。

  2. 考虑与点 $ MathJax-Element-917 $ 最近的 $ MathJax-Element-929 $ 个点,其中 $ MathJax-Element-914 $ 为点 $ MathJax-Element-917 $ 的周围点的概率分布的困惑度。记这些邻居结点的集合为 $ MathJax-Element-916 $ ,则有:

    $ p_{j\mid i}=p( \mathbf{\vec x}_j\mid \mathbf{\vec x}_i) = \frac{\exp\left(-||\mathbf{\vec x}_j-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)}{\sum_{k \in \mathbb N_i}\exp\left(-||\mathbf{\vec x}_k-\mathbf{\vec x}_i||^2/(2\sigma_i^2)\right)},\; j\in \mathbb N_i\\ p_{j\mid i} = 0 ,\;j\ne \mathbb N_i $

    这种方法会大大降低计算量。但是需要首先构建高维空间的kNN 图,从而快速的获取 $ MathJax-Element-917 $ 最近的 $ MathJax-Element-929 $ 个点。

  3. Maaten 使用VP树:vantage-point tree 来构建kNN 图,可以在 $ MathJax-Element-919 $ 的计算复杂度内得到一个精确的 kNN 图。

8.5.2 梯度求解优化

  1. 对 $ MathJax-Element-920 $ 进行变换。定义 $ MathJax-Element-921 $ 。则根据:

    $ q_{i,j}=\frac{(1+||\mathbf{\vec z}_i-\mathbf{\vec z}_j||^2)^{-1}}{\sum_{k}\sum_{l,l\ne k}(1+||\mathbf{\vec z}_k-\mathbf{\vec z}_l||^2)^{-1}} $

    有: $ MathJax-Element-922 $ 。

    则有:

    $ \nabla_{\mathbf{\vec z}_i} \mathcal L= \sum_j 4(p_{i,j}-q_{i,j})(\mathbf{\vec z}_i-\mathbf{\vec z}_j)(1+||\mathbf{\vec z}_i-\mathbf{\vec z}_j||^2)^{-1}\\ =\sum_j 4(p_{i,j}-q_{i,j})(\mathbf{\vec z}_i-\mathbf{\vec z}_j)q_{i,j}Z\\ =4\left[\sum_{j}p_{i,j}q_{i,j}Z (\mathbf{\vec z}_i-\mathbf{\vec z}_j)-\sum_{j}q_{i,j}^2Z(\mathbf{\vec z}_i-\mathbf{\vec z}_j)\right] $

    定义引力为: $ MathJax-Element-923 $ ,斥力为: $ MathJax-Element-924 $ 。则有:

    $ MathJax-Element-925 $ 。

  2. 引力部分的计算比较简单。

    考虑到 $ MathJax-Element-926 $ ,则有:

    $ F_{attr} = \sum_{j}p_{i,j}q_{i,j}Z (\mathbf{\vec z}_i-\mathbf{\vec z}_j)=\sum_jp_{i,j}\frac{\mathbf{\vec z}_i-\mathbf{\vec z}_j}{1+||\mathbf{\vec z}_i-\mathbf{\vec z}_j||^2} $

    根据 $ MathJax-Element-927 $ ,则只可以忽略较远的结点。 仅考虑与点 $ MathJax-Element-937 $ 最近的 $ MathJax-Element-929 $ 个点,则引力部分的计算复杂度为 $ MathJax-Element-930 $ 。

  3. 斥力部分的计算比较复杂,但是仍然有办法进行简化。

    • 考虑下图中的三个点,其中 $ MathJax-Element-931 $ 。此时认为点 $ MathJax-Element-932 $ 和 $ MathJax-Element-933 $ 对点 $ MathJax-Element-937 $ 的斥力是近似相等的。

    tsne2

    • 事实上这种情况在低维空间中很常见,甚至某片区域中每个点对 $ MathJax-Element-937 $ 的斥力都可以用同一个值来近似,如下图所示。

      假设区域 $ MathJax-Element-945 $ 中 4 个点对 $ MathJax-Element-937 $ 产生的斥力都是近似相等的,则可以计算这 4 个点的中心(虚拟的点)产生的斥力 $ MathJax-Element-938 $ ,则区域 $ MathJax-Element-945 $ 产生的总的斥力为 $ MathJax-Element-940 $ 。

    tsne3

    • Matten 使用四叉树来完成区域搜索任务,并用该区域中心点产生的斥力作为整个区域的斥力代表值。

      并非所有区域都满足该近似条件,这里使用Barnes-Hut 算法搜索并验证符合近似条件的点-区域 组合 。

    • 事实上可以进一步优化,近似区域到区域之间的斥力。

      • 如下所示为区域 $ MathJax-Element-945 $ 和区域 $ MathJax-Element-946 $ 中,任意两个结点之间的斥力都可以用 $ MathJax-Element-944 $ 来近似。其中 $ MathJax-Element-944 $ 代表区域 $ MathJax-Element-945 $ 的中心(虚拟的点)和区域 $ MathJax-Element-946 $ 的中心(虚拟的点)产生的斥力。
      • 同样也需要判断两个区域之间的斥力是否满足近似条件。这里采用了Dual-tree 算法搜索并验证符合近似条件的区域-区域 组合 。

      tsne4

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文