返回介绍

数学基础

统计学习

深度学习

工具

Scala

五、ResNet

发布于 2023-07-17 23:38:25 字数 12803 浏览 0 评论 0 收藏 0

  1. ResNet 提出了一种残差学习框架来解决网络退化问题,从而训练更深的网络。这种框架可以结合已有的各种网络结构,充分发挥二者的优势。

  2. ResNet以三种方式挑战了传统的神经网络架构:

    • ResNet 通过引入跳跃连接来绕过残差层,这允许数据直接流向任何后续层。

      这与传统的、顺序的pipeline 形成鲜明对比:传统的架构中,网络依次处理低级feature 到高级feature

    • ResNet 的层数非常深,高达1202层。而ALexNet 这样的架构,网络层数要小两个量级。

    • 通过实验发现,训练好的 ResNet 中去掉单个层并不会影响其预测性能。而训练好的AlexNet 等网络中,移除层会导致预测性能损失。

  3. ImageNet分类数据集中,拥有152层的残差网络,以3.75% top-5 的错误率获得了ILSVRC 2015 分类比赛的冠军。

  4. 很多证据表明:残差学习是通用的,不仅可以应用于视觉问题,也可应用于非视觉问题。

5.1 网络退化问题

  1. 学习更深的网络的一个障碍是梯度消失/爆炸,该问题可以通过Batch Normalization 在很大程度上解决。

  2. ResNet 论文作者发现:随着网络的深度的增加,准确率达到饱和之后迅速下降,而这种下降不是由过拟合引起的。这称作网络退化问题。

    如果更深的网络训练误差更大,则说明是由于优化算法引起的:越深的网络,求解优化问题越难。如下所示:更深的网络导致更高的训练误差和测试误差。

  3. 理论上讲,较深的模型不应该比和它对应的、较浅的模型更差。因为较深的模型是较浅的模型的超空间。较深的模型可以这样得到:先构建较浅的模型,然后添加很多恒等映射的网络层。

    实际上我们的较深的模型后面添加的不是恒等映射,而是一些非线性层。因此,退化问题表明:通过多个非线性层来近似横等映射可能是困难的。

  4. 解决网络退化问题的方案:学习残差。

5.2 残差块

  1. 假设需要学习的是映射 $ MathJax-Element-88 $ ,残差块使用堆叠的非线性层拟合残差: $ MathJax-Element-89 $ 。

    其中:

    • $ MathJax-Element-698 $ 和 $ MathJax-Element-91 $ 是块的输入和输出向量。

    • $ MathJax-Element-105 $ 是要学习的残差映射。因为 $ MathJax-Element-93 $ ,因此称 $ MathJax-Element-116 $ 为残差。

    • + :通过快捷连接逐个元素相加来执行。快捷连接 指的是那些跳过一层或者更多层的连接。

      • 快捷连接简单的执行恒等映射,并将其输出添加到堆叠层的输出。
      • 快捷连接既不增加额外的参数,也不增加计算复杂度。
    • 相加之后通过非线性激活函数,这可以视作对整个残差块添加非线性,即 $ MathJax-Element-95 $ 。

  2. 前面给出的残差块隐含了一个假设: $ MathJax-Element-105 $ 和 $ MathJax-Element-698 $ 的维度相等。如果它们的维度不等,则需要在快捷连接中对 $ MathJax-Element-698 $ 执行线性投影来匹配维度: $ MathJax-Element-99 $ 。

    事实上当它们维度相等时,也可以执行线性变换。但是实践表明:使用恒等映射足以解决退化问题,而使用线性投影会增加参数和计算复杂度。因此 $ MathJax-Element-100 $ 仅在匹配维度时使用。

  3. 残差函数 $ MathJax-Element-116 $ 的形式是可变的。

    • 层数可变:论文中的实验包含有两层堆叠、三层堆叠,实际任务中也可以包含更多层的堆叠。

      如果 $ MathJax-Element-116 $ 只有一层,则残差块退化线性层: $ MathJax-Element-103 $ 。此时对网络并没有什么提升。

    • 连接形式可变:不仅可用于全连接层,可也用于卷积层。此时 $ MathJax-Element-116 $ 代表多个卷积层的堆叠,而最终的逐元素加法+ 在两个feature map 上逐通道进行。

      此时 x 也是一个feature map,而不再是一个向量。

  4. 残差学习成功的原因:学习残差 $ MathJax-Element-105 $ 比学习原始映射 $ MathJax-Element-106 $ 要更容易。

    • 当原始映射 $ MathJax-Element-118 $ 就是一个恒等映射时, $ MathJax-Element-116 $ 就是一个零映射。此时求解器只需要简单的将堆叠的非线性连接的权重推向零即可。

      实际任务中原始映射 $ MathJax-Element-118 $ 可能不是一个恒等映射:

      • 如果 $ MathJax-Element-118 $ 更偏向于恒等映射(而不是更偏向于非恒等映射),则 $ MathJax-Element-116 $ 就是关于恒等映射的抖动,会更容易学习。
      • 如果原始映射 $ MathJax-Element-118 $ 更偏向于零映射,那么学习 $ MathJax-Element-118 $ 本身要更容易。但是在实际应用中,零映射非常少见,因为它会导致输出全为0。
    • 如果原始映射 $ MathJax-Element-118 $ 是一个非恒等映射,则可以考虑对残差模块使用缩放因子。如Inception-Resnet 中:在残差模块与快捷连接叠加之前,对残差进行缩放。

      注意:ResNet 作者在随后的论文中指出:不应该对恒等映射进行缩放。因此Inception-Resnet对残差模块进行缩放。

    • 可以通过观察残差 $ MathJax-Element-116 $ 的输出来判断:如果 $ MathJax-Element-116 $ 的输出均为0附近的、较小的数,则说明原始映射 $ MathJax-Element-118 $ 更偏向于恒等映射;否则,说明原始映射 $ MathJax-Element-118 $ 更偏向于非横等映射。

5.3 ResNet 分析

  1. Veit et al. 认为ResNet 工作较好的原因是:一个ResNet 网络可以看做是一组较浅的网络的集成模型。

    但是ResNet 的作者认为这个解释是不正确的。因为集成模型要求每个子模型是独立训练的,而这组较浅的网络是共同训练的。

  2. 论文《Residual Networks Bahave Like Ensemble of Relatively Shallow Networks》ResNet 进行了深入的分析。

    • 通过分解视图表明:ResNet 可以被视作许多路径的集合。

    • 通过研究ResNet 的梯度流表明:网络训练期间只有短路径才会产生梯度流,深的路径不是必须的。

    • 通过破坏性实验,表明:

      • 即使这些路径是共同训练的,它们也不是相互依赖的。
      • 这些路径的行为类似集成模型,其预测准确率平滑地与有效路径的数量有关。

5.3.1 分解视图

  1. 考虑从输出 $ MathJax-Element-119 $ 到 $ MathJax-Element-120 $ 的三个ResNet 块构建的网络。根据:

    $ \mathbf {\vec y}_3=\mathbf {\vec y}_2+f_3(\mathbf {\vec y}_2) =[\mathbf {\vec y}_1+f_2(\mathbf {\vec y}_1)]+f_3(\mathbf {\vec y}_1+f_2(\mathbf {\vec y}_1))\\ =[\mathbf {\vec y}_0+f_1(\mathbf {\vec y}_0)+f_2(\mathbf {\vec y}_0+f_1(\mathbf {\vec y}_0))]+f_3(\mathbf {\vec y}_0+f_1(\mathbf {\vec y}_0)+f_2(\mathbf {\vec y}_0+f_1(\mathbf {\vec y}_0))) $

    下图中:左图为原始形式,右图为分解视图。分解视图中展示了数据从输入到输出的多条路径。

    对于严格顺序的网络(如VGG ),这些网络中的输入总是在单个路径中从第一层直接流到最后一层。如下图所示。

  2. 分解视图中, 每条路径可以通过二进制编码向量 $ MathJax-Element-121 $ 来索引:如果流过残差块 $ MathJax-Element-128 $ ,则 $ MathJax-Element-123 $ ;如果跳过残差块 $ MathJax-Element-128 $ ,则 $ MathJax-Element-125 $ 。

    因此ResNet 从输入到输出具有 $ MathJax-Element-140 $ 条路径,第 $ MathJax-Element-619 $ 个残差块 $ MathJax-Element-128 $ 的输入汇聚了之前的 $ MathJax-Element-129 $ 个残差块的 $ MathJax-Element-130 $ 条路径。

  3. 普通的前馈神经网络也可以在单个神经元(而不是网络层)这一粒度上运用分解视图,这也可以将网络分解为不同路径的集合。

    它与ResNet 分解的区别是:

    • 普通前馈神经网络的神经元分解视图中,所有路径都具有相同的长度。
    • ResNet 网络的残差块分解视图中,所有路径具有不同的路径长度。

5.3.2 路径长度分析

  1. ResNet 中,从输入到输出存在许多条不同长度的路径。这些路径长度的分布服从二项分布。对于 $ MathJax-Element-566 $ 层深的ResNet,大多数路径的深度为 $ MathJax-Element-132 $ 。

    下图为一个 54 个块的ResNet 网络的路径长度的分布 ,其中95% 的路径只包含 19~35个块。

5.3.3 路径梯度分析

  1. ResNet 中,路径的梯度幅度随着它在反向传播中经过的残差块的数量呈指数减小。因此,训练期间大多数梯度来源于更短的路径。

  2. 对于一个包含 54 个残差块的ResNet 网络:

    • 下图表示:单条长度为 $ MathJax-Element-415 $ 的路径在反向传播到 input 处的梯度的幅度的均值,它刻画了长度为 $ MathJax-Element-415 $ 的单条路径的对于更新的影响。

      因为长度为 $ MathJax-Element-415 $ 的路径有多条,因此取其平均。

    • 下图表示:长度为 $ MathJax-Element-415 $ 的所有路径在反向传播到 input 处的梯度的幅度的和。它刻画了长度为 $ MathJax-Element-415 $ 的所有路径对于更新的影响。

      它不仅取决于长度为 $ MathJax-Element-415 $ 的单条路径的对于更新的影响,还取决于长度为 $ MathJax-Element-415 $ 的单条路径的数量。

  3. 有效路径:反向传播到 input 处的梯度幅度相对较大的路径。

    ResNet 中有效路径相对较浅,而且有效路径数量占比较少。在一个54 个块的ResNet 网络中:

    • 几乎所有的梯度更新都来自于长度为 5~17 的路径。
    • 长度为 5~17 的路径占网络所有路径的 0.45% 。
  4. 论文从头开始重新训练ResNet,同时在训练期间只保留有效路径,确保不使用长路径。实验结果表明:相比于完整模型的 6.10% 的错误率,这里实现了 5.96% 的错误率。二者没有明显的统计学上的差异,这表明确实只需要有效路径。

    因此,ResNet 不是让梯度流流通整个网络深度来解决梯度消失问题,而是引入能够在非常深的网络中传输梯度的短路径来避免梯度消失问题。

  5. ResNet 原理类似,随机深度网络起作用有两个原因:

    • 训练期间,网络看到的路径分布会发生变化,主要是变得更短。
    • 训练期间,每个mini-batch 选择不同的短路径的子集,这会鼓励各路径独立地产生良好的结果。

5.3.4 路径破坏性分析

  1. ResNet 网络训练完成之后,如果随机丢弃单个残差块,则测试误差基本不变。因为移除一个残差块时,ResNet 中路径的数量从 $ MathJax-Element-140 $ 减少到 $ MathJax-Element-141 $ ,留下了一半的路径。

    VGG 网络训练完成之后,如果随机丢弃单个块,则测试误差急剧上升,预测结果就跟随机猜测差不多。因为移除一个块时,VGG 中唯一可行的路径被破坏。

  2. 删除ResNet 残差块通常会删除长路径。

    当删除了 $ MathJax-Element-415 $ 个残差块时,长度为 $ MathJax-Element-684 $ 的路径的剩余比例由下式给定: $ MathJax-Element-144 $ 。

    下图中:

    • 删除10个残差模块,一部分有效路径(路径长度为5~17)仍然被保留,模型测试性能会部分下降。
    • 删除20个残差模块,绝大部分有效路径(路径长度为5~17)被删除,模型测试性能会大幅度下降。

  3. ResNet 网络中,路径的集合表现出一种类似集成模型的效果。一个关键证据是:它们的整体表现平稳地取决于路径的数量。随着网络删除越来越多的残差块,网络路径的数量降低,测试误差平滑地增加(而不是突变)。

  4. 如果在测试时重新排序网络的残差块,这意味着交换了低层映射和高层映射。采用Kendall Tau rank 来衡量网络结构被破坏的程度,结果表明:随着 Kendall Tau rank 的增加,预测错误率也在增加。

5.4 网络性能

  1. plain 网络:一些简单网络结构的叠加,如下图所示。图中给出了四种plain 网络,它们的区别主要是网络深度不同。其中,输入图片尺寸 224x224 。

    ResNet 简单的在plain 网络上添加快捷连接来实现。

    FLOPsfloating point operations 的缩写,意思是浮点运算量,用于衡量算法/模型的复杂度。

    FLOPSfloating point per second的缩写,意思是每秒浮点运算次数,用于衡量计算速度。

  2. 相对于输入的feature map,残差块的输出feature map 尺寸可能会发生变化:

    • 输出 feature map 的通道数增加,此时需要扩充快捷连接的输出feature map 。否则快捷连接的输出 feature map 无法和残差块的feature map 累加。

      有两种扩充方式:

      • 直接通过 0 来填充需要扩充的维度,在图中以实线标识。
      • 通过1x1 卷积来扩充维度,在图中以虚线标识。
    • 输出 feature map 的尺寸减半。此时需要对快捷连接执行步长为 2 的池化/卷积:如果快捷连接已经采用 1x1 卷积,则该卷积步长为2 ;否则采用步长为 2 的最大池化 。

  3. 计算复杂度:

    VGG-1934层 plain 网络Resnet-34
    计算复杂度(FLOPs)19.6 billion3.5 billion3.6 billion
  4. 模型预测能力:在ImageNet 验证集上执行10-crop 测试的结果。

    • A 类模型:快捷连接中,所有需要扩充的维度的填充 0 。
    • B 类模型:快捷连接中,所有需要扩充的维度通过1x1 卷积来扩充。
    • C 类模型:所有快捷连接都通过1x1 卷积来执行线性变换。

    可以看到C 优于BB 优于A。但是 C 引入更多的参数,相对于这种微弱的提升,性价比较低。所以后续的ResNet 均采用 B 类模型。

    模型top-1 误差率top-5 误差率
    VGG-1628.07%9.33%
    GoogleNet-9.15%
    PReLU-net24.27%7.38%
    plain-3428.54%10.02%
    ResNet-34 A25.03%7.76%
    ResNet-34 B24.52%7.46%
    ResNet-34 C24.19%7.40%
    ResNet-5022.85%6.71%
    ResNet-10121.75%6.05%
    ResNet-15221.43%5.71%

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文