返回介绍

数学基础

统计学习

深度学习

工具

Scala

二、反向传播

发布于 2023-07-17 23:38:25 字数 9742 浏览 0 评论 0 收藏 0

2.1 前向传播

  1. 考虑计算单个标量 $ MathJax-Element-123 $ 的计算图:

    • 假设有 $ MathJax-Element-81 $ 个输入节点: $ MathJax-Element-82 $ 。它们对应的是模型的参数和输入。
    • 假设 $ MathJax-Element-83 $ 为中间节点。
    • 假设 $ MathJax-Element-123 $ 为输出节点,它对应的是模型的代价函数。
    • 对于每个非输入节点 $ MathJax-Element-124 $ ,定义其双亲节点的集合为 $ MathJax-Element-86 $ 。
    • 假设每个非输入节点 $ MathJax-Element-124 $ ,操作 $ MathJax-Element-88 $ 与其关联,并且通过对该函数求值得到: $ MathJax-Element-89 $ 。

    通过仔细排序(有向无环图的拓扑排序算法),使得可以依次计算 $ MathJax-Element-95 $ 。

  2. 前向传播算法:

    • 输入:

      • 计算图 $ MathJax-Element-332 $
      • 初始化向量 $ MathJax-Element-121 $
    • 输出: $ MathJax-Element-123 $ 的值

    • 算法步骤:

      • 初始化输入节点: $ MathJax-Element-94 $ 。

      • 根据计算图,从前到后计算 $ MathJax-Element-95 $ 。对于 $ MathJax-Element-96 $ 计算过程为:

        • 计算 $ MathJax-Element-118 $ 的双亲节点集合 $ MathJax-Element-98 $ 。
        • 计算 $ MathJax-Element-118 $ : $ MathJax-Element-100 $ 。
      • 输出 $ MathJax-Element-123 $ 。

2.2 反向传播

  1. 计算 $ MathJax-Element-428 $ 时需要构造另一张计算图 $ MathJax-Element-127 $ : 它的节点与 $ MathJax-Element-332 $ 中完全相同,但是计算顺序完全相反。

    计算图 $ MathJax-Element-127 $ 如下图所示:

  2. 对于图中的任意一非输出节点 $ MathJax-Element-118 $ (非 $ MathJax-Element-123 $ ),根据链式法则:

    $ \frac{\partial u_n}{\partial u_j}=\sum_{(\partial u_i,\partial u_j) \in \mathcal B}\frac{\partial u_n}{\partial u_i}\frac{\partial u_i}{\partial u_j} $

    其中 $ MathJax-Element-108 $ 表示图 $ MathJax-Element-127 $ 中的边 $ MathJax-Element-112 $ 。

    • 若图 $ MathJax-Element-127 $ 中存在边 $ MathJax-Element-112 $ ,则在图 $ MathJax-Element-332 $ 中存在边 $ MathJax-Element-114 $ ,则 $ MathJax-Element-124 $ 为 $ MathJax-Element-118 $ 的子节点。
    • 设图 $ MathJax-Element-332 $ 中 $ MathJax-Element-118 $ 的子节点的集合为 $ MathJax-Element-119 $ ,则上式改写作:
    $ \frac{\partial u_n}{\partial u_j}=\sum_{u_i \in \mathbb C_j}\frac{\partial u_n}{\partial u_i}\frac{\partial u_i}{\partial u_j} $
  3. 反向传播算法:

    • 输入:

      • 计算图 $ MathJax-Element-332 $
      • 初始化参数向量 $ MathJax-Element-121 $
    • 输出: $ MathJax-Element-428 $

    • 算法步骤:

      • 运行计算 $ MathJax-Element-123 $ 的前向算法,获取每个节点的值。

      • 给出一个 grad_table表,它存储的是已经计算出来的偏导数。

        $ MathJax-Element-124 $ 对应的表项存储的是偏导数 $ MathJax-Element-128 $ 。

      • 初始化 $ MathJax-Element-126 $ 。

      • 沿着计算图 $ MathJax-Element-127 $ 计算偏导数。遍历 $ MathJax-Element-217 $ 从 $ MathJax-Element-445 $ 到 $ MathJax-Element-447 $ :

        • 计算 $ MathJax-Element-434 $ 。其中: $ MathJax-Element-128 $ 是已经存储的 $ MathJax-Element-129 $ , $ MathJax-Element-333 $ 为实时计算的。

          图 $ MathJax-Element-332 $ 中的边 $ MathJax-Element-132 $ 定义了一个操作,而该操作的偏导只依赖于这两个变量,因此可以实时求解 $ MathJax-Element-333 $ 。

        • 存储 $ MathJax-Element-133 $ 。

      • 返回 $ MathJax-Element-453 $ 。

  1. 反向传播算法计算所有的偏导数,计算量与 $ MathJax-Element-332 $ 中的边的数量成正比。

    其中每条边的计算包括计算偏导数,以及执行一次向量点积。

  2. 上述反向传播算法为了减少公共子表达式的计算量 ,并没有考虑存储的开销。这避免了重复子表达式的指数级的增长。

    • 某些算法可以通过对计算图进行简化从而避免更多的子表达式。
    • 有些算法会重新计算这些子表达式而不是存储它们,从而节省内存。

2.3 反向传播示例

  1. 对于 $ MathJax-Element-136 $ ,将公式拆分成 $ MathJax-Element-146 $ 和 $ MathJax-Element-138 $ ,则有:

    $ \frac{\partial q}{\partial x}=1,\quad \frac{\partial q}{\partial y}=1,\quad \frac{\partial f}{\partial q}=z,\quad \frac{\partial f}{\partial z}=q $

    根据链式法则,有 $ MathJax-Element-144 $ 。

    假设 $ MathJax-Element-140 $ ,则计算图如下。其中:绿色为前向传播的值,红色为反向传播的结果。

    • 前向传播,计算从输入到输出(绿色);反向传播,计算从尾部开始到输入(红色)。

    • 在整个计算图中,每个单元的操作类型,以及输入是已知的。通过这两个条件可以计算出两个结果:

      • 这个单元的输出值。
      • 这个单元的输出值关于输入值的局部梯度比如 $ MathJax-Element-141 $ 和 $ MathJax-Element-142 $ 。

      每个单元计算这两个结果是独立完成的,它不需要计算图中其他单元的任何细节。

      但是在反向传播过程中,单元将获取整个网络的最终输出值(这里是 $ MathJax-Element-143 $ )在单元的输出值上的梯度,即回传的梯度。

      链式法则指出:单元应该将回传的梯度乘以它对其输入的局部梯度,从而得到整个网络的输出对于该单元每个输入值的梯度。如: $ MathJax-Element-144 $ 。

  2. 在多数情况下,反向传播中的梯度可以被直观地解释。如:加法单元、乘法单元、最大值单元。

    假设: $ MathJax-Element-145 $ ,前向传播的计算从输入到输出(绿色),反向传播的计算从尾部开始到输入(红色)。

    • 加法单元 $ MathJax-Element-146 $ ,则 $ MathJax-Element-147 $ 。如果 $ MathJax-Element-153 $ ,则有:

      $ \frac{\partial f}{\partial x}=m,\quad \frac{\partial f}{\partial y}=m $

      这表明:加法单元将回传的梯度相等的分发给它的输入。

    • 乘法单元 $ MathJax-Element-149 $ ,则 $ MathJax-Element-150 $ 。如果 $ MathJax-Element-153 $ ,则有:

      $ \frac{\partial f}{\partial x}=my,\quad \frac{\partial f}{\partial y}=mx $

      这表明:乘法单元交换了输入数据,然后乘以回传的梯度作为每个输入的梯度。

    • 取最大值单元 $ MathJax-Element-152 $ ,则:

      $ \frac{\partial q}{\partial x}=\begin{cases}1,&x>=y\\0,&x=x\\0,&y如果 $ MathJax-Element-153 $ ,则有:

      $ \frac{\partial f}{\partial x}=\begin{cases}m,&x>=y\\0,&x=x\\0,&y这表明:取最大值单元将回传的梯度分发给最大的输入。

  3. 通常如果函数 $ MathJax-Element-154 $ 的表达式非常复杂,则当对 $ MathJax-Element-155 $ 进行微分运算,运算结束后会得到一个巨大而复杂的表达式。

    • 实际上并不需要一个明确的函数来计算梯度,只需要如何使用反向传播算法计算梯度即可。
    • 可以把复杂的表达式拆解成很多个简单的表达式(这些表达式的局部梯度是简单的、已知的),然后利用链式法则来求取梯度。
    • 在计算反向传播时,前向传播过程中得到的一些中间变量非常有用。实际操作中,最好对这些中间变量缓存。

2.4 深度前馈神经网络应用

  1. 给定一个样本,其定义代价函数为 $ MathJax-Element-492 $ ,其中 $ MathJax-Element-503 $ 为神经网络的预测值。

    考虑到正则化项,定义损失函数为: $ MathJax-Element-510 $ 。其中 $ MathJax-Element-188 $ 为正则化项,而 $ MathJax-Element-160 $ 包含了所有的参数(包括每一层的权重 $ MathJax-Element-312 $ 和每一层的偏置 $ MathJax-Element-162 $ )。

    这里给出的是单个样本的损失函数,而不是训练集的损失函数。

  2. 计算 $ MathJax-Element-503 $ 的计算图为:

  3. 前向传播用于计算深度前馈神经网络的损失函数。算法为:

    • 输入:

      • 网络层数 $ MathJax-Element-335 $

      • 每一层的权重矩阵 $ MathJax-Element-180 $

      • 每一层的偏置向量 $ MathJax-Element-181 $

      • 每一层的激活函数 $ MathJax-Element-182 $

        也可以对所有的层使用同一个激活函数

      • 输入 $ MathJax-Element-183 $ 和对应的标记 $ MathJax-Element-518 $ 。

      • 隐层到输出的函数 $ MathJax-Element-539 $ 。

    • 输出:损失函数 $ MathJax-Element-170 $

    • 算法步骤:

      • 初始化 $ MathJax-Element-171 $

      • 迭代: $ MathJax-Element-172 $ ,计算:

        • $ MathJax-Element-204 $
        • $ MathJax-Element-174 $
      • 计算 $ MathJax-Element-537 $ , $ MathJax-Element-547 $ 。

  4. 反向传播用于计算深度前馈网络的损失函数对于参数的梯度。

    梯度下降算法需要更新模型参数,因此只关注损失函数对于模型参数的梯度,不关心损失函数对于输入的梯度 $ MathJax-Element-178 $ 。

    • 根据链式法则有: $ MathJax-Element-196 $ 。

      考虑到 $ MathJax-Element-728 $ ,因此雅可比矩阵 $ MathJax-Element-200 $ 为对角矩阵,对角线元素 $ MathJax-Element-803 $ 。 $ MathJax-Element-805 $ 表示 $ MathJax-Element-774 $ 的第 $ MathJax-Element-199 $ 个元素。

      因此 $ MathJax-Element-742 $ ,其中 $ MathJax-Element-195 $ 表示Hadamard积。

    • 因为 $ MathJax-Element-204 $ ,因此:

      $ \nabla_{\mathbf{\vec b}_k}J=\nabla_{\mathbf{\vec a}_k}J,\quad \nabla_{\mathbf W_k}J=(\nabla_{\mathbf{\vec a}_k}J) \mathbf{\vec h}^{T}_{k-1} $

      上式仅仅考虑从 $ MathJax-Element-830 $ 传播到 $ MathJax-Element-852 $ 中的梯度。 考虑到损失函数中的正则化项 $ MathJax-Element-188 $ 包含了权重和偏置,因此需要增加正则化项的梯度。则有:

      $ \nabla_{\mathbf{\vec b}_k}J=\nabla_{\mathbf{\vec a}_k}J+\lambda \nabla_{\mathbf{\vec b}_k}\Omega(\vec\theta)\\ \nabla_{\mathbf W_k}J=(\nabla_{\mathbf{\vec a}_k}J) \mathbf{\vec h}^{T}_{k-1}+\lambda \nabla_{\mathbf W_k}\Omega(\vec\theta) $
    • 因为 $ MathJax-Element-204 $ ,因此: $ MathJax-Element-205 $ 。

  5. 反向传播算法:

    • 输入:

      • 网络层数 $ MathJax-Element-335 $
      • 每一层的权重矩阵 $ MathJax-Element-180 $
      • 每一层的偏置向量 $ MathJax-Element-181 $
      • 每一层的激活函数 $ MathJax-Element-182 $
      • 输入 $ MathJax-Element-183 $ 和对应的标记 $ MathJax-Element-561 $
    • 输出:梯度 $ MathJax-Element-185 $

    • 算法步骤:

      • 通过前向传播计算损失函数 $ MathJax-Element-186 $ 以及网络的输出 $ MathJax-Element-566 $ 。

      • 计算输出层的导数: $ MathJax-Element-655 $ 。

        这里等式成立的原因是:正则化项 $ MathJax-Element-188 $ 与模型输出 $ MathJax-Element-566 $ 无关。

      • 计算最后一层隐单元的梯度: $ MathJax-Element-689 $ 。

      • 迭代: $ MathJax-Element-190 $ ,迭代步骤如下:

        每一轮迭代开始之前,维持不变式: $ MathJax-Element-191 $ 。

        • 计算 $ MathJax-Element-698 $ : $ MathJax-Element-871 $ 。

        • 令: $ MathJax-Element-194 $ 。

        • 计算对权重和偏置的偏导数:

          $ \nabla_{\mathbf{\vec b}_k}J=\mathbf{\vec g}+\lambda \nabla_{\mathbf{\vec b}_k}\Omega(\vec\theta)\\ \nabla_{\mathbf W_k}J=\mathbf{\vec g} \mathbf{\vec h}^{T}_{k-1}+\lambda \nabla_{\mathbf W_k}\Omega(\vec\theta) $
        • 计算 $ MathJax-Element-897 $ : $ MathJax-Element-905 $ 。

        • 令: $ MathJax-Element-910 $ 。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文