返回介绍

数学基础

统计学习

深度学习

工具

Scala

三、算法实现

发布于 2023-07-17 23:38:25 字数 9143 浏览 0 评论 0 收藏 0

3.1 符号-数值 / 符号-符号方法

  1. 代数表达式和计算图都是对符号symbol进行操作,这些基于代数的表达式或者基于图的表达式称作符号表达式。

    当训练神经网络时,必须给这些符号赋值。如:对于符号 $ MathJax-Element-344 $ 赋予一个实际的数值,如 $ MathJax-Element-207 $ 。

  2. 符号到数值的方法:给定计算图,以及图的一组输入的数值,然后返回在这些输入值处的梯度。

    这种方法用于TorchCaffe之类的库中。

  3. 符号到符号的方法:给定计算图,算法会添加额外的一些节点到计算图中,这些额外的节点提供了所需的导数的符号描述。

    这种方法用于TheanoTensorFlow之类的库中。

    下图左侧为 $ MathJax-Element-208 $ 的计算图,右侧添加了若干节点从而给出了计算 $ MathJax-Element-209 $ 的计算图。

  4. 符号到符号的方法的优点:导数可以使用与原始表达式相同的编程语言来描述。

    导数只是另外一张计算图,因此可以再次运行反向传播算法对导数再进行求导,从而获取更高阶的导数。

  5. 推荐使用符号到符号的方法来求导数。一旦构造出了添加导数后的计算图,那么随后如果给出了输入的数值,可以对图中的任意子节点求值。

    目前通用的计算图求解引擎的做法是:任意时刻,一旦一个节点的父节点都求值完毕,那么该节点将能够立即求值。

  6. 事实上符号到数值的方法与符号到符号的方法执行了同样的计算过程,区别在于:

    • 符号到数值的方法并没有暴露内部的计算过程。
    • 符号到符号的方法将各种求导运算暴露出来,添加到计算图中成为了节点。

3.2 算法框架

  1. 假设计算图 $ MathJax-Element-332 $ 中的每个节点对应一个变量。这里将变量描述为一个张量 $ MathJax-Element-259 $ ,它可以具有任意维度并且可能是标量、向量、矩阵。

    根据前面介绍的张量的链式法则, $ MathJax-Element-212 $ ,则张量的链式法则为:

    $ \nabla_{\mathbf X} z=\sum_{j}(\nabla_{\mathbf X} y_j)\frac{\partial z}{\partial y_j} $

    其中 $ MathJax-Element-217 $ 为张量 $ MathJax-Element-216 $ 展平为一维向量后的索引, $ MathJax-Element-215 $ 为张量 $ MathJax-Element-216 $ 展平为一维向量之后的第 $ MathJax-Element-217 $ 个元素。

3.2.1 三个子过程

  1. $ MathJax-Element-218 $ :返回用于计算 $ MathJax-Element-259 $ 的操作operation 。它就是tensorflow 中的Operation 对象。

    该函数通常返回一个操作对象 $ MathJax-Element-220 $ :

    • 该对象有个 f方法,该方法给出了父节点到 $ MathJax-Element-259 $ 的函数: $ MathJax-Element-958 $ 。

      其中 $ MathJax-Element-222 $ 为 $ MathJax-Element-259 $ 的父节点集合: $ MathJax-Element-242 $

    • 该操作对象有个bprop方法。给定 $ MathJax-Element-259 $ 的某个子节点 $ MathJax-Element-260 $ ,该方法用于已知 $ MathJax-Element-260 $ 的梯度 $ MathJax-Element-228 $ ,求解 $ MathJax-Element-260 $ 对于 $ MathJax-Element-259 $ 的梯度的贡献: $ MathJax-Element-963 $ 。

      如果考虑 $ MathJax-Element-259 $ 的所有子节点集合 $ MathJax-Element-237 $ ,则 它们的梯度贡献之和就是总的梯度:

      $ \nabla_{\mathbf V}z=\sum_{\mathbf C\in \mathbb C_{\mathbf V}^{(\mathcal G)}} (\nabla_{\mathbf V} \mathbf C )\mathbf G_C $
  2. $ MathJax-Element-233 $ :返回图 $ MathJax-Element-332 $ 中节点 $ MathJax-Element-259 $ 的子节点列表,也就是节点 $ MathJax-Element-259 $ 的子节点集合: $ MathJax-Element-237 $ 。

  3. $ MathJax-Element-238 $ :返回图 $ MathJax-Element-332 $ 中节点 $ MathJax-Element-259 $ 的父节点列表,也就是 $ MathJax-Element-259 $ 的父节点集合: $ MathJax-Element-242 $ 。

  4. op.bprop方法总是假定其输入节点各不相同。

    如果定义了一个乘法操作,而且每条输入节点都是x,则op.bprop方法也会认为它们是不同的:

    op.bprop会认为其输入分别为yz,然后求出表达式之后再代入y=x,z=x

  5. 大多数反向传播算法的实现都提供了operation对象以及它的bprop方法。

    如果希望添加自己的反向传播过程,则只需要派生出op.bprop方法即可。

3.2.2 反向传播过程

  1. build_grad 过程采用符号-符号方法 ,用于求解单个结点 $ MathJax-Element-259 $ 的梯度 $ MathJax-Element-289 $ 。

  2. build_grad 在求解过程中会用到裁剪的计算图 $ MathJax-Element-968 $ , $ MathJax-Element-968 $ 会剔除所有与 $ MathJax-Element-973 $ 梯度无关的节点,保留与 $ MathJax-Element-973 $ 梯度有关的节点。

  3. build_grad 过程:

    • 输入:

      • 待求梯度的节点 $ MathJax-Element-259 $
      • 计算图 $ MathJax-Element-332 $
      • 被裁减的计算图 $ MathJax-Element-279 $
      • 梯度表 $ MathJax-Element-290 $
    • 输出: $ MathJax-Element-289 $

    • 算法步骤:

      • 如果 $ MathJax-Element-259 $ 已经就在 $ MathJax-Element-290 $ 中,则直接返回 $ MathJax-Element-250 $ 。

        这样可以节省大量的重复计算

      • 初始化 $ MathJax-Element-986 $ 。

      • 在图 $ MathJax-Element-279 $ 中, 迭代遍历 $ MathJax-Element-259 $ 的子节点的集合 $ MathJax-Element-254 $ : $ MathJax-Element-255 $ :

        • 获取计算 $ MathJax-Element-260 $ 的操作: $ MathJax-Element-257 $

        • 获取该子节点的梯度,这是通过递归来实现的: $ MathJax-Element-980 $ 。

          因为子节点更靠近输出端,因此子节点 $ MathJax-Element-260 $ 的梯度一定是先于 $ MathJax-Element-259 $ 的梯度被计算。

        • 计算子节点 $ MathJax-Element-260 $ 对于 $ MathJax-Element-289 $ 的贡献: $ MathJax-Element-995 $ 。

        • 累加子节点 $ MathJax-Element-260 $ 对于 $ MathJax-Element-289 $ 的贡献: $ MathJax-Element-999 $ 。

      • 存储梯度来更新梯度表: $ MathJax-Element-264 $ 。

      • 在 $ MathJax-Element-332 $ 中插入节点 $ MathJax-Element-316 $ 来更新计算图 $ MathJax-Element-332 $ 。插入过程不仅增加了节点 $ MathJax-Element-316 $ ,还增加了 $ MathJax-Element-316 $ 的父节点到 $ MathJax-Element-316 $ 的边。

      • 返回 $ MathJax-Element-316 $ 。

  4. 反向传播过程:

    • 输入:

      • 计算图 $ MathJax-Element-332 $
      • 目标变量 $ MathJax-Element-284 $
      • 待计算梯度的变量的集合 $ MathJax-Element-282 $
    • 输出: $ MathJax-Element-276 $

    • 算法步骤:

      • 裁剪 $ MathJax-Element-332 $ 为 $ MathJax-Element-279 $ ,使得 $ MathJax-Element-279 $ 仅包含 $ MathJax-Element-284 $ 的祖先之中,那些同时也是 $ MathJax-Element-282 $ 的后代的节点。

        因为这里只关心 $ MathJax-Element-282 $ 中节点的梯度

      • 初始化 $ MathJax-Element-290 $ ,它是一个表,各表项存储的是 $ MathJax-Element-284 $ 对于对应节点的偏导数。

      • 初始化 $ MathJax-Element-285 $ (因为 $ MathJax-Element-286 $ )。

      • 迭代:对每个 $ MathJax-Element-287 $ ,执行 $ MathJax-Element-288 $ 。

      • 返回 $ MathJax-Element-290 $ 。

3.3.3 算法复杂度

  1. 算法复杂度分析过程中,我们假设每个操作的执行都有大概相同的时间开销。

    实际上每个操作可能包含多个算术运算,如:将矩阵乘法视为单个操作的话,就包含了很多乘法和加法。因此每个操作的运行时间实际上相差非常大。

  2. 在具有 $ MathJax-Element-341 $ 个节点的计算图中计算梯度,不会执行超过 $ MathJax-Element-296 $ 的操作,也不会执行超过 $ MathJax-Element-296 $ 个存储。

    因为最坏的情况下前向传播将遍历执行图中的全部 $ MathJax-Element-341 $ 个节点,每两个节点之间定义了一个梯度。

  3. 大多数神经网络的代价函数的计算图是链式结构,因此不会执行超过 $ MathJax-Element-297 $ 的操作。

    从 $ MathJax-Element-296 $ 降低到 $ MathJax-Element-297 $ 是因为:并不是所有的两个节点之间都有数据通路。

  4. 如果直接用梯度计算公式来求解则会产生大量的重复子表达式,导致指数级的运行时间。

    反向传播过程是一种表填充算法,利用存储中间结果(存储子节点的梯度) 来对表进行填充。计算图中的每个节点对应了表中的一个位置,该位置存储的就是该节点的梯度。

    通过顺序填充这些表的条目,反向传播算法避免了重复计算公共表达式。这种策略也称作动态规划。

3.4、应用

  1. 考虑只有单个隐层的最简单的深度前馈网络,使用小批量(minibatch)随机梯度下降法训练模型。反向传播算法用于计算单个minibatch上的代价函数的梯度。

  2. 取训练集上的一组minibatch实例,记做输入矩阵 $ MathJax-Element-1064 $ ,矩阵的每一行就是一个实例,其中 $ MathJax-Element-330 $ 为样本数量, $ MathJax-Element-341 $ 为特征数量。同时给出标记 $ MathJax-Element-1089 $ ,它是每个样本的真实标记。

    设激活函数为 ReLU 激活函数,设模型不包含偏置。设输入层到隐层的权重矩阵为 $ MathJax-Element-320 $ ,则隐层的输出为: $ MathJax-Element-305 $ 。设隐层到输出层的权重矩阵为 $ MathJax-Element-314 $ ,则分类的非归一化对数概率为 $ MathJax-Element-319 $ 。

    假设程序包含了cross_entropy操作,用于计算未归一化对数概率分布定义的交叉熵 ,该交叉熵作为代价函数 $ MathJax-Element-308 $ 。引入正则化项,总的代价函数为: $ MathJax-Element-1114 $ 。

    交叉熵为 $ MathJax-Element-1107 $ 。最小化交叉熵就是最大化似然估计

    其计算图如下所示:

  3. 目标是通过小批量随机梯度下降法求解代价函数的最小值,因此需要计算 $ MathJax-Element-309 $ 。

    从图中看出有两种不同的路径从 $ MathJax-Element-310 $ 回退到 $ MathJax-Element-326 $ :

    • 一条路径是通过正则化项。

      这条路径对于梯度的贡献相对简单,它对于 $ MathJax-Element-312 $ 的梯度贡献为 $ MathJax-Element-313 $ 。

    • 一条路径是通过交叉熵。

      • 对于 $ MathJax-Element-314 $ ,这条分支其梯度的贡献为 $ MathJax-Element-315 $ ,其中 $ MathJax-Element-316 $ 为 $ MathJax-Element-317 $ ,将 $ MathJax-Element-318 $ 替换为 $ MathJax-Element-319 $

      • 对于 $ MathJax-Element-320 $ ,这条分支对于梯度的贡献计算为:

        • 首先计算 $ MathJax-Element-321 $ 。
        • 然后根据relu操作的反向传播规则:根据 $ MathJax-Element-322 $ 中小于零的部分,对应地将 $ MathJax-Element-323 $ 对应位置清零,记清零后的结果为 $ MathJax-Element-324 $ 。
        • 分支的梯度贡献为: $ MathJax-Element-325 $ 。
  4. 该算法的计算成本主要来源于矩阵乘法:

    • 前向传播阶段(为了计算对各节点求值):乘-加运算的数量为 $ MathJax-Element-327 $ ,其中 $ MathJax-Element-328 $ 为权重的数量。
    • 在反向传播阶段:具有相同的计算成本。
  5. 算法的主要存储成本是:需要存储隐层非线性函数的输入。因此存储成本是 $ MathJax-Element-329 $ ,其中 $ MathJax-Element-330 $ 为 minibatch中样例的数量, $ MathJax-Element-331 $ 是隐单元的数量。

  6. 这里描述的反向传播算法要比现实中实际使用的实现更简单。

    • 这里定义的operation限制为返回单个张量的函数,大多数软件实现支持返回多个张量的operation

    • 这里未指定如何控制反向传播的内存消耗。反向传播经常涉及将许多张量加在一起。

      • 朴素算法将分别计算这些张量,然后第二步中将所有张量求和,内存需求过高。
      • 可以通过维持一个buffer,并且在计算时将每个值加到buffer中来避免该瓶颈。
    • 反向传播的具体实现还需要处理各种数据类型,如32位浮点数、整数等。

    • 一些operation具有未定义的梯度,需要跟踪这些情况并向用户报告。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文