数学基础
- 线性代数
- 概率论与随机过程
- 数值计算
- 蒙特卡洛方法与 MCMC 采样
- 机器学习方法概论
统计学习
深度学习
- 深度学习简介
- 深度前馈网络
- 反向传播算法
- 正则化
- 深度学习中的最优化问题
- 卷积神经网络
- CNN:图像分类
- 循环神经网络 RNN
- Transformer
- 一、Transformer [2017]
- 二、Universal Transformer [2018]
- 三、Transformer-XL [2019]
- 四、GPT1 [2018]
- 五、GPT2 [2019]
- 六、GPT3 [2020]
- 七、OPT [2022]
- 八、BERT [2018]
- 九、XLNet [2019]
- 十、RoBERTa [2019]
- 十一、ERNIE 1.0 [2019]
- 十二、ERNIE 2.0 [2019]
- 十三、ERNIE 3.0 [2021]
- 十四、ERNIE-Huawei [2019]
- 十五、MT-DNN [2019]
- 十六、BART [2019]
- 十七、mBART [2020]
- 十八、SpanBERT [2019]
- 十九、ALBERT [2019]
- 二十、UniLM [2019]
- 二十一、MASS [2019]
- 二十二、MacBERT [2019]
- 二十三、Fine-Tuning Language Models from Human Preferences [2019]
- 二十四 Learning to summarize from human feedback [2020]
- 二十五、InstructGPT [2022]
- 二十六、T5 [2020]
- 二十七、mT5 [2020]
- 二十八、ExT5 [2021]
- 二十九、Muppet [2021]
- 三十、Self-Attention with Relative Position Representations [2018]
- 三十一、USE [2018]
- 三十二、Sentence-BERT [2019]
- 三十三、SimCSE [2021]
- 三十四、BERT-Flow [2020]
- 三十五、BERT-Whitening [2021]
- 三十六、Comparing the Geometry of BERT, ELMo, and GPT-2 Embeddings [2019]
- 三十七、CERT [2020]
- 三十八、DeCLUTR [2020]
- 三十九、CLEAR [2020]
- 四十、ConSERT [2021]
- 四十一、Sentence-T5 [2021]
- 四十二、ULMFiT [2018]
- 四十三、Scaling Laws for Neural Language Models [2020]
- 四十四、Chinchilla [2022]
- 四十七、GLM-130B [2022]
- 四十八、GPT-NeoX-20B [2022]
- 四十九、Bloom [2022]
- 五十、PaLM [2022] (粗读)
- 五十一、PaLM2 [2023](粗读)
- 五十二、Self-Instruct [2022]
- 句子向量
- 词向量
- 传统CTR 预估模型
- CTR 预估模型
- 一、DSSM [2013]
- 二、FNN [2016]
- 三、PNN [2016]
- 四、DeepCrossing [2016]
- 五、Wide 和 Deep [2016]
- 六、DCN [2017]
- 七、DeepFM [2017]
- 八、NFM [2017]
- 九、AFM [2017]
- 十、xDeepFM [2018]
- 十一、ESMM [2018]
- 十二、DIN [2017]
- 十三、DIEN [2019]
- 十四、DSIN [2019]
- 十五、DICM [2017]
- 十六、DeepMCP [2019]
- 十七、MIMN [2019]
- 十八、DMR [2020]
- 十九、MiNet [2020]
- 二十、DSTN [2019]
- 二十一、BST [2019]
- 二十二、SIM [2020]
- 二十三、ESM2 [2019]
- 二十四、MV-DNN [2015]
- 二十五、CAN [2020]
- 二十六、AutoInt [2018]
- 二十七、Fi-GNN [2019]
- 二十八、FwFM [2018]
- 二十九、FM2 [2021]
- 三十、FiBiNET [2019]
- 三十一、AutoFIS [2020]
- 三十三、AFN [2020]
- 三十四、FGCNN [2019]
- 三十五、AutoCross [2019]
- 三十六、InterHAt [2020]
- 三十七、xDeepInt [2023]
- 三十九、AutoDis [2021]
- 四十、MDE [2020]
- 四十一、NIS [2020]
- 四十二、AutoEmb [2020]
- 四十三、AutoDim [2021]
- 四十四、PEP [2021]
- 四十五、DeepLight [2021]
- 图的表达
- 一、DeepWalk [2014]
- 二、LINE [2015]
- 三、GraRep [2015]
- 四、TADW [2015]
- 五、DNGR [2016]
- 六、Node2Vec [2016]
- 七、WALKLETS [2016]
- 八、SDNE [2016]
- 九、CANE [2017]
- 十、EOE [2017]
- 十一、metapath2vec [2017]
- 十二、GraphGAN [2018]
- 十三、struc2vec [2017]
- 十四、GraphWave [2018]
- 十五、NetMF [2017]
- 十六、NetSMF [2019]
- 十七、PTE [2015]
- 十八、HNE [2015]
- 十九、AANE [2017]
- 二十、LANE [2017]
- 二十一、MVE [2017]
- 二十二、PMNE [2017]
- 二十三、ANRL [2018]
- 二十四、DANE [2018]
- 二十五、HERec [2018]
- 二十六、GATNE [2019]
- 二十七、MNE [2018]
- 二十八、MVN2VEC [2018]
- 二十九、SNE [2018]
- 三十、ProNE [2019]
- Graph Embedding 综述
- 图神经网络
- 一、GNN [2009]
- 二、Spectral Networks 和 Deep Locally Connected Networks [2013]
- 三、Fast Localized Spectral Filtering On Graph [2016]
- 四、GCN [2016]
- 五、神经图指纹 [2015]
- 六、GGS-NN [2016]
- 七、PATCHY-SAN [2016]
- 八、GraphSAGE [2017]
- 九、GAT [2017]
- 十、R-GCN [2017]
- 十一、 AGCN [2018]
- 十二、FastGCN [2018]
- 十三、PinSage [2018]
- 十四、GCMC [2017]
- 十五、JK-Net [2018]
- 十六、PPNP [2018]
- 十七、VRGCN [2017]
- 十八、ClusterGCN [2019]
- 十九、LDS-GNN [2019]
- 二十、DIAL-GNN [2019]
- 二十一、HAN [2019]
- 二十二、HetGNN [2019]
- 二十三、HGT [2020]
- 二十四、GPT-GNN [2020]
- 二十五、Geom-GCN [2020]
- 二十六、Graph Network [2018]
- 二十七、GIN [2019]
- 二十八、MPNN [2017]
- 二十九、UniMP [2020]
- 三十、Correct and Smooth [2020]
- 三十一、LGCN [2018]
- 三十二、DGCNN [2018]
- 三十三、AS-GCN
- 三十四、DGI [2018]
- 三十五、DIFFPOLL [2018]
- 三十六、DCNN [2016]
- 三十七、IN [2016]
- 图神经网络 2
- 图神经网络 3
- 推荐算法(传统方法)
- 一、Tapestry [1992]
- 二、GroupLens [1994]
- 三、ItemBased CF [2001]
- 四、Amazon I-2-I CF [2003]
- 五、Slope One Rating-Based CF [2005]
- 六、Bipartite Network Projection [2007]
- 七、Implicit Feedback CF [2008]
- 八、PMF [2008]
- 九、SVD++ [2008]
- 十、MMMF 扩展 [2008]
- 十一、OCCF [2008]
- 十二、BPR [2009]
- 十三、MF for RS [2009]
- 十四、 Netflix BellKor Solution [2009]
- 推荐算法(神经网络方法 1)
- 一、MIND [2019](用于召回)
- 二、DNN For YouTube [2016]
- 三、Recommending What Video to Watch Next [2019]
- 四、ESAM [2020]
- 五、Facebook Embedding Based Retrieval [2020](用于检索)
- 六、Airbnb Search Ranking [2018]
- 七、MOBIUS [2019](用于召回)
- 八、TDM [2018](用于检索)
- 九、DR [2020](用于检索)
- 十、JTM [2019](用于检索)
- 十一、Pinterest Recommender System [2017]
- 十二、DLRM [2019]
- 十三、Applying Deep Learning To Airbnb Search [2018]
- 十四、Improving Deep Learning For Airbnb Search [2020]
- 十五、HOP-Rec [2018]
- 十六、NCF [2017]
- 十七、NGCF [2019]
- 十八、LightGCN [2020]
- 十九、Sampling-Bias-Corrected Neural Modeling [2019](检索)
- 二十、EGES [2018](Matching 阶段)
- 二十一、SDM [2019](Matching 阶段)
- 二十二、COLD [2020 ] (Pre-Ranking 模型)
- 二十三、ComiRec [2020](https://www.wenjiangs.com/doc/0b4e1736-ac78)
- 二十四、EdgeRec [2020]
- 二十五、DPSR [2020](检索)
- 二十六、PDN [2021](mathcing)
- 二十七、时空周期兴趣学习网络ST-PIL [2021]
- 推荐算法之序列推荐
- 一、FPMC [2010]
- 二、GRU4Rec [2015]
- 三、HRM [2015]
- 四、DREAM [2016]
- 五、Improved GRU4Rec [2016]
- 六、NARM [2017]
- 七、HRNN [2017]
- 八、RRN [2017]
- 九、Caser [2018]
- 十、p-RNN [2016]
- 十一、GRU4Rec Top-k Gains [2018]
- 十二、SASRec [2018]
- 十三、RUM [2018]
- 十四、SHAN [2018]
- 十五、Phased LSTM [2016]
- 十六、Time-LSTM [2017]
- 十七、STAMP [2018]
- 十八、Latent Cross [2018]
- 十九、CSRM [2019]
- 二十、SR-GNN [2019]
- 二十一、GC-SAN [2019]
- 二十二、BERT4Rec [2019]
- 二十三、MCPRN [2019]
- 二十四、RepeatNet [2019]
- 二十五、LINet(2019)
- 二十六、NextItNet [2019]
- 二十七、GCE-GNN [2020]
- 二十八、LESSR [2020]
- 二十九、HyperRec [2020]
- 三十、DHCN [2021]
- 三十一、TiSASRec [2020]
- 推荐算法(综述)
- 多任务学习
- 系统架构
- 实践方法论
- 深度强化学习 1
- 自动代码生成
工具
- CRF
- lightgbm
- xgboost
- scikit-learn
- spark
- numpy
- matplotlib
- pandas
- huggingface_transformer
- 一、Tokenizer
- 二、Datasets
- 三、Model
- 四、Trainer
- 五、Evaluator
- 六、Pipeline
- 七、Accelerate
- 八、Autoclass
- 九、应用
- 十、Gradio
Scala
- 环境搭建
- 基础知识
- 函数
- 类
- 样例类和模式匹配
- 测试和注解
- 集合 collection(一)
- 集合collection(二)
- 集成 Java
- 并发
八、 DenseNet
DenseNet
不是通过更深或者更宽的结构,而是通过特征重用来提升网络的学习能力。ResNet
的思想是:创建从“靠近输入的层” 到 “靠近输出的层” 的直连。而DenseNet
做得更为彻底:将所有层以前馈的形式相连,这种网络因此称作DenseNet
。DenseNet
具有以下的优点:- 缓解梯度消失的问题。因为每层都可以直接从损失函数中获取梯度、从原始输入中获取信息,从而易于训练。
- 密集连接还具有正则化的效应,缓解了小训练集任务的过拟合。
- 鼓励特征重用。网络将不同层学到的
feature map
进行组合。 - 大幅度减少参数数量。因为每层的卷积核尺寸都比较小,输出通道数较少 (由增长率 $ MathJax-Element-415 $ 决定)。
DenseNet
具有比传统卷积网络更少的参数,因为它不需要重新学习多余的feature map
。传统的前馈神经网络可以视作在层与层之间传递
状态
的算法,每一层接收前一层的状态
,然后将新的状态
传递给下一层。这会改变
状态
,但是也传递了需要保留的信息。ResNet
通过恒等映射来直接传递需要保留的信息,因此层之间只需要传递状态的变化
。DenseNet
会将所有层的状态
全部保存到集体知识
中,同时每一层增加很少数量的feture map
到网络的集体知识中
。
DenseNet
的层很窄(即:feature map
的通道数很小),如:每一层的输出只有 12 个通道。
8.1 DenseNet 块
具有 $ MathJax-Element-696 $ 层的传统卷积网络有 $ MathJax-Element-696 $ 个连接,每层仅仅与后继层相连。
具有 $ MathJax-Element-696 $ 个残差块的
ResNet
在每个残差块增加了跨层连接,第 $ MathJax-Element-642 $ 个残差块的输出为: $ MathJax-Element-332 $ 。其中 $ MathJax-Element-333 $ 是第 $ MathJax-Element-642 $ 个残差块的输入特征; $ MathJax-Element-335 $ 为一组与第 $ MathJax-Element-642 $ 个残差块相关的权重(包括偏置项), $ MathJax-Element-689 $ 是残差块中的层的数量; $ MathJax-Element-338 $ 代表残差函数。具有 $ MathJax-Element-696 $ 个层块的
DenseNet
块有 $ MathJax-Element-340 $ 个连接,每层以前馈的方式将该层与它后面的所有层相连。对于第 $ MathJax-Element-642 $ 层:所有先前层的feature map
都作为本层的输入,第 $ MathJax-Element-642 $ 层具有 $ MathJax-Element-642 $ 个输入feature map
;本层输出的feature map
都将作为后面 $ MathJax-Element-344 $ 层的输入。假设
DenseNet
块包含 $ MathJax-Element-696 $ 层,每一层都实现了一个非线性变换 $ MathJax-Element-371 $ ,其中 $ MathJax-Element-642 $ 表示层的索引。假设
$ \mathbf x_{l}=H_l([\mathbf x_0,\mathbf x_1,\cdots,\mathbf x_{l-1}]) $DenseNet
块的输入为 $ MathJax-Element-348 $ ,DenseNet
块的第 $ MathJax-Element-642 $ 层的输出为 $ MathJax-Element-350 $ ,则有:其中 $ MathJax-Element-351 $ 表示 $ MathJax-Element-352 $ 层输出的
feature map
沿着通道方向的拼接。ResNet
块与它不同。在ResNet
中,不同feature map
是通过直接相加来作为块的输出。当
feature map
的尺寸改变时,无法沿着通道方向进行拼接。此时将网络划分为多个DenseNet
块,每块内部的feature map
尺寸相同,块之间的feature map
尺寸不同。
8.1.1 增长率
DenseNet
块中,每层的 $ MathJax-Element-371 $ 输出的feature map
通道数都相同,都是 $ MathJax-Element-415 $ 个。 $ MathJax-Element-415 $ 是一个重要的超参数,称作网络的增长率。第 $ MathJax-Element-642 $ 层的输入
feature map
的通道数为: $ MathJax-Element-357 $ 。其中 $ MathJax-Element-375 $ 为输入层的通道数。DenseNet
不同于现有网络的一个重要地方是:DenseNet
的网络很窄,即输出的feature map
通道数较小,如: $ MathJax-Element-359 $ 。一个很小的增长率就能够获得不错的效果。一种解释是:
DenseNet
块的每层都可以访问块内的所有早前层输出的feature map
,这些feature map
可以视作DenseNet
块的全局状态。每层输出的feature map
都将被添加到块的这个全局状态中,该全局状态可以理解为网络块的“集体知识”,由块内所有层共享。增长率 $ MathJax-Element-415 $ 决定了新增特征占全局状态的比例。因此
feature map
无需逐层复制(因为它是全局共享),这也是DenseNet
与传统网络结构不同的地方。这有助于整个网络的特征重用,并产生更紧凑的模型。
8.1.2 非线性变换 $ MathJax-Element-731 $
- $ MathJax-Element-371 $ 可以是包含了
Batch Normalization(BN)
、ReLU
单元、池化或者卷积等操作的复合函数。 - 论文中 $ MathJax-Element-371 $ 的结构为:先执行
BN
,再执行ReLU
,最后接一个3x3
的卷积,即:BN-ReLU-Conv(3x3)
。
8.1.3 bottleneck
尽管
DenseNet
块中每层只产生 $ MathJax-Element-415 $ 个输出feature map
,但是它具有很多输入。当在 $ MathJax-Element-371 $ 之前采用1x1
卷积实现降维时,可以减小计算量。$ MathJax-Element-371 $ 的输入是由第 $ MathJax-Element-367 $ 层的输出
feature map
组成,其中第0
层的输出feature map
就是整个DensNet
块的输入feature map
。
事实上第 $ MathJax-Element-368 $ 层从DensNet
块的输入feature map
中抽取各种特征。即 $ MathJax-Element-371 $ 包含了DensNet
块的输入feature map
的冗余信息,这可以通过 1x1
卷积降维来去掉这种冗余性。
因此这种1x1
卷积降维对于DenseNet
块极其有效。
如果在 $ MathJax-Element-371 $ 中引入
1x1
卷积降维,则该版本的DenseNet
称作DenseNet-B
。其 $ MathJax-Element-371 $ 结构为:先执行BN
,再执行ReLU
,再接一个1x1
的卷积,再执行BN
,再执行ReLU
,最后接一个3x3
的卷积。即:BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3)
。其中
1x1
卷积的输出通道数是个超参数,论文中选取为 $ MathJax-Element-372 $ 。
8.2 过渡层
一个
DenseNet
网络具有多个DenseNet
块,DenseNet
块之间由过渡层连接。DenseNet
块之间的层称为过渡层,其主要作用是连接不同的DenseNet
块。过渡层可以包含卷积或池化操作,从而改变前一个
DenseNet
块的输出feature map
的大小(包括尺寸大小、通道数量)。论文中的过渡层由一个
BN
层、一个1x1
卷积层、一个2x2
平均池化层组成。其中1x1
卷积层用于减少DenseNet
块的输出通道数,提高模型的紧凑性。如果不减少
DenseNet
块的输出通道数,则经过了 $ MathJax-Element-543 $ 个DenseNet
块之后,网络的feature map
的通道数为: $ MathJax-Element-374 $ ,其中 $ MathJax-Element-375 $ 为输入图片的通道数, $ MathJax-Element-696 $ 为每个DenseNet
块的层数。如果
Dense
块输出feature map
的通道数为 $ MathJax-Element-636 $ ,则可以使得过渡层输出feature map
的通道数为 $ MathJax-Element-378 $ ,其中 $ MathJax-Element-379 $ 为压缩因子。当 $ MathJax-Element-380 $ 时,经过过渡层的
feature map
通道数不变。当 $ MathJax-Element-381 $ 时,经过过渡层的
feature map
通道数减小。此时的DenseNet
称做DenseNet-C
。结合了
DenseNet-C
和DenseNet-B
的改进的网络称作DenseNet-BC
。
8.3 网络性能
网络结构:
ImageNet
训练的DenseNet
网络结构,其中增长率 $ MathJax-Element-382 $ 。- 表中的
conv
代表的是BN-ReLU-Conv
的组合。如1x1 conv
表示:先执行BN
,再执行ReLU
,最后执行1x1
的卷积。 DenseNet-xx
表示DenseNet
块有xx
层。如:DenseNet-169
表示DenseNet
块有 $ MathJax-Element-383 $ 层 。- 所有的
DenseNet
使用的是DenseNet-BC
结构,输入图片尺寸为224x224
,初始卷积尺寸为7x7
、输出通道2k
、步长为2
,压缩因子 $ MathJax-Element-384 $ 。 - 在所有
DenseNet
块的最后接一个全局平均池化层,该池化层的结果作为softmax
输出层的输入。
- 表中的
在
ImageNet
验证集的错误率(single-crop/10-crop
):模型 top-1 error(%) top-5 error(%) DenseNet-121 25.02/23.61 7.71/6.66 DenseNet-169 23.80/22.08 6.85/5.92 DenseNet-201 22.58/21.46 6.34/5.54 DenseNet-264 22.15/20.80 6.12/5.29 下图是
DenseNet
和ResNet
在ImageNet
验证集的错误率的比较(single-crop
)。左图为参数数量,右图为计算量。从实验可见:
DenseNet
的参数数量和计算量相对ResNet
明显减少。- 具有
20M
个参数的DenseNet-201
与具有40M
个参数的ResNet-101
验证误差接近。 - 和
ResNet-101
验证误差接近的DenseNet-201
的计算量接近于ResNet-50
,几乎是ResNet-101
的一半。
- 具有
DenseNet
在CIFAR
和SVHN
验证集的表现:C10+
和C100+
:表示对CIFAR10/CIFAR100
执行数据集增强,包括平移和翻转。- 在
C10/C100/SVHN
三列上的DenseNet
采用了Dropout
。 DenseNet
的Depth
列给出的是 $ MathJax-Element-696 $ 参数。
从实验可见:
不考虑压缩因子和
bottleneck
, $ MathJax-Element-696 $ 和 $ MathJax-Element-415 $ 越大DenseNet
表现更好。这主要是因为模型容量相应地增长。网络可以利用更大、更深的模型提高其表达学习能力,这也表明了
DenseNet
不会受到优化困难的影响。DenseNet
的参数效率更高,使用了压缩因子和bottleneck
的DenseNet-BC
的参数利用率极高。这带来的一个效果是:
DenseNet-BC
不容易发生过拟合。事实上在
CIFAR10
上,DenseNet
从 $ MathJax-Element-388 $ 中,参数数量提升了4倍但是验证误差反而5.77
下降到5.83
,明显发生了过拟合。而DenseNet-BC
未观察到过拟合。
DenseNet
提高准确率的一个可能的解释是:各层通过较短的连接(最多需要经过两个或者三个过渡层)直接从损失函数中接收额外的监督信息。
8.4 内存优化
8.4.1 内存消耗
虽然
DenseNet
的计算效率较高、参数相对较少,但是DenseNet
对内存不友好。考虑到GPU
显存大小的限制,因此无法训练较深的DenseNet
。假设
DenseNet
块包含 $ MathJax-Element-696 $ 层,对于第 $ MathJax-Element-642 $ 层有: $ MathJax-Element-391 $ 。假设每层的输出
feature map
尺寸均为 $ MathJax-Element-392 $ 、通道数为 $ MathJax-Element-415 $ , $ MathJax-Element-731 $ 由BN-ReLU-Conv(3x3)
组成,则:- 拼接
Concat
操作 $ MathJax-Element-395 $ :需要生成临时feature map
作为第 $ MathJax-Element-642 $ 层的输入,内存消耗为 $ MathJax-Element-398 $ 。 BN
操作:需要生成临时feature map
作为ReLU
的输入,内存消耗为 $ MathJax-Element-398 $ 。ReLU
操作:可以执行原地修改,因此不需要额外的feature map
存放ReLU
的输出。Conv
操作:需要生成输出feature map
作为第 $ MathJax-Element-642 $ 层的输出,它是必须的开销。
因此除了第 $ MathJax-Element-433 $ 层的输出
feature map
需要内存开销之外,第 $ MathJax-Element-642 $ 层还需要 $ MathJax-Element-418 $ 的内存开销来存放中间生成的临时feature map
。整个
DenseNet Block
需要 $ MathJax-Element-403 $ 的内存开销来存放中间生成的临时feature map
。即DenseNet Block
的内存消耗为 $ MathJax-Element-412 $ ,是网络深度的平方关系。- 拼接
拼接
Concat
操作是必须的,因为当卷积的输入存放在连续的内存区域时,卷积操作的计算效率较高。而DenseNet Block
中,第 $ MathJax-Element-642 $ 层的输入feature map
由前面各层的输出feature map
沿通道方向拼接而成。而这些输出feature map
并不在连续的内存区域。另外,拼接
feature map
并不是简单的将它们拷贝在一起。由于feature map
在Tensorflow/Pytorch
等等实现中的表示为 $ MathJax-Element-406 $ (channel first
),或者 $ MathJax-Element-407 $ (channel last
),如果简单的将它们拷贝在一起则是沿着mini batch
维度的拼接,而不是沿着通道方向的拼接。DenseNet Block
的这种内存消耗并不是DenseNet Block
的结构引起的,而是由深度学习库引起的。因为Tensorflow/PyTorch
等库在实现神经网络时,会存放中间生成的临时节点(如BN
的输出节点),这是为了在反向传播阶段可以直接获取临时节点的值。这是在时间代价和空间代价之间的折中:通过开辟更多的空间来存储临时值,从而在反向传播阶段节省计算。
除了临时
feature map
的内存消耗之外,网络的参数也会消耗内存。设 $ MathJax-Element-731 $ 由BN-ReLU-Conv(3x3)
组成,则第 $ MathJax-Element-642 $ 层的网络参数数量为: $ MathJax-Element-410 $ (不考虑BN
)。整个
DenseNet Block
的参数数量为 $ MathJax-Element-411 $ ,即 $ MathJax-Element-412 $ 。因此网络参数的数量也是网络深度的平方关系。- 由于
DenseNet
参数数量与网络的深度呈平方关系,因此DenseNet
网络的参数更多、网络容量更大。这也是DenseNet
优于其它网络的一个重要因素。 - 通常情况下都有 $ MathJax-Element-413 $ ,其中 $ MathJax-Element-414 $ 为网络
feature map
的宽、高, $ MathJax-Element-415 $ 为网络的增长率。所以网络参数消耗的内存要远小于临时feature map
消耗的内存。
- 由于
8.4.2 内存优化
论文
《Memory-Efficient Implementation of DenseNets》
通过分配共享内存来降低内存需求,从而使得训练更深的DenseNet
成为可能。其思想是利用时间代价和空间代价之间的折中,但是侧重于牺牲时间代价来换取空间代价。其背后支撑的因素是:
Concat
操作和BN
操作的计算代价很低,但是空间代价很高。因此这种做法在DenseNet
中非常有效。传统的
DenseNet Block
实现与内存优化的DenseNet Block
对比如下(第 $ MathJax-Element-642 $ 层,该层的输入feature map
来自于同一个块中早前的层的输出):左图为传统的
DenseNet Block
的第 $ MathJax-Element-642 $ 层。首先将feature map
拷贝到连续的内存块,拷贝时完成拼接的操作。然后依次执行BN
、ReLU
、Conv
操作。该层的临时
feature map
需要消耗内存 $ MathJax-Element-418 $ ,该层的输出feature map
需要消耗内存 $ MathJax-Element-419 $ 。- 另外某些实现(如
LuaTorch
)还需要为反向传播过程的梯度分配内存,如左图下半部分所示。如:计算BN
层输出的梯度时,需要用到第 $ MathJax-Element-642 $ 层输出层的梯度和BN
层的输出。存储这些梯度需要额外的 $ MathJax-Element-421 $ 的内存。 - 另外一些实现(如
PyTorch,MxNet
)会对梯度使用共享的内存区域来存放这些梯度,因此只需要 $ MathJax-Element-422 $ 的内存。
- 另外某些实现(如
右图为内存优化的
DenseNet Block
的第 $ MathJax-Element-642 $ 层。采用两组预分配的共享内存区Shared memory Storage location
来存Concate
操作和BN
操作输出的临时feature map
。
第一组预分配的共享内存区:
Concat
操作共享区。第 $ MathJax-Element-433 $ 层的Concat
操作的输出都写入到该共享区,第 $ MathJax-Element-645 $ 层的写入会覆盖第 $ MathJax-Element-642 $ 层的结果。对于整个
Dense Block
,这个共享区只需要分配 $ MathJax-Element-436 $ (最大的feature map
)的内存,即内存消耗为 $ MathJax-Element-437 $ (对比传统DenseNet
的 $ MathJax-Element-438 $ )。后续的
BN
操作直接从这个共享区读取数据。由于第 $ MathJax-Element-645 $ 层的写入会覆盖第 $ MathJax-Element-642 $ 层的结果,因此这里存放的数据是临时的、易丢失的。因此在反向传播阶段还需要重新计算第 $ MathJax-Element-642 $ 层的
Concat
操作的结果。因为
Concat
操作的计算效率非常高,因此这种额外的计算代价很低。
第二组预分配的共享内存区:
BN
操作共享区。第 $ MathJax-Element-433 $ 层的BN
操作的输出都写入到该共享区,第 $ MathJax-Element-645 $ 层的写入会覆盖第 $ MathJax-Element-642 $ 层的结果。对于整个
Dense Block
,这个共享区也只需要分配 $ MathJax-Element-436 $ (最大的feature map
)的内存,即内存消耗为 $ MathJax-Element-437 $ (对比传统DenseNet
的 $ MathJax-Element-438 $ )。后续的卷积操作直接从这个共享区读取数据。
与
Concat
操作共享区同样的原因,在反向传播阶段还需要重新计算第 $ MathJax-Element-642 $ 层的BN
操作的结果。BN
的计算效率也很高,只需要额外付出大约 5% 的计算代价。
由于
BN
操作和Concat
操作在神经网络中大量使用,因此这种预分配共享内存区的方法可以广泛应用。它们可以在增加少量的计算时间的情况下节省大量的内存消耗。
8.4.3 优化结果
如下图所示,
DenseNet
不同实现方式的实验结果:Naive Implementation(LuaTorch)
:采用LuaTorch
实现的,不采用任何的内存共享区。Shared Gradient Strorage(LuaTorch)
:采用LuaTorch
实现的,采用梯度内存共享区。Shared Gradient Storage(PyTorch)
:采用PyTorch
实现的,采用梯度内存共享区。Shared Gradient+BN+Concat Strorate(LuaTorch)
:采用LuaTorch
实现的,采用梯度内存共享区、Concat
内存共享区、BN
内存共享区。Shared Gradient+BN+Concat Strorate(PyTorch)
:采用LuaTorch
实现的,采用梯度内存共享区、Concat
内存共享区、BN
内存共享区。
注意:
PyTorch
自动实现了梯度的内存共享区。内存消耗是参数数量的线性函数。因为参数数量本质上是网络深度的二次函数,而内存消耗也是网络深度的二次函数。
如前面的推导过程中,
DenseNet Block
参数数量 $ MathJax-Element-440 $ ,内存消耗 $ MathJax-Element-441 $ 。因此 $ MathJax-Element-442 $ ,即 $ MathJax-Element-443 $ 。
如下图所示,
DenseNet
不同实现方式的训练时间差异(NVIDIA Maxwell Titan-X
):- 梯度共享存储区不会带来额外时间的开销。
Concat
内存共享区、BN
内存共享区需要额外消耗 15%(LuaTorch
) 或者20% (PyTorch
) 的时间。
如下图所示,不同
DenseNet
的不同实现方式在ImageNet
上的表现(single-crop test
):DenseNet cosine
使用 $ MathJax-Element-444 $ 学习率。经过内存优化的
DenseNet
可以在单个工作站(8 NVIDIA Tesla M40 GPU
)上训练 264 层的网络,并取得了top-1 error=20.26%
的好成绩。网络参数数量:232 层
DenseNet
:k=48,55M
参数。 264 层DenseNet
:k=32,33M
参数;k=48,73M
参数。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论