CenterLoss - 实战&源码
本来要继续连载 CRF 的,但是最近看到了知乎上各位大神的论文解读 - 《A Discriminative Feature Learning Approach for Deep Face Recognition》,其中介绍了他们设计的一个用户提高类别的区分度的损失函数 - Center Loss,并且有一个基于 MNist 数据库的小实验,于是本人就简单实践了一下。
本着其他大神已经深度解析了这篇文章的内容,所以我觉得我再凑热闹把那些理论的东西讲一遍意义不大,倒不如抡起袖子干点实事 - 把代码写出来跑一跑。在我完成实验前,已经看到 github 上有人完成了 MXNet 的 center loss 代码,所以作为 Caffe 派的革命小斗士,我得抓紧时间了。
不说废话,先用一个 8 拍把问题说明白:
一个 8 拍的 center loss
1:同样是分类问题,我们之前只关注了待识别的图像应该属于哪个类别,但是并没有关心一个同样重要的问题:最终分类器的分界面区域内的空间是不是都应该属于这个类别?空间内这些长得很像的图像,它们的特征会不会其实差距有点大?
2:于是乎作者设计了一个网络,在倒数第二层全连接层输出了一个 2 维的特征向量,并以此进行进一步的分类。我们把 MNist 的 Test 集合数据通过最终训练好的模型进行预测,倒数第二层的样子是这样的:
嗯,貌似和论文上的图像差不多啊……
3:于是乎,作者就觉得,怎么每个类别的特征都是长长的一条啊?那么实际上同一个类内部的差距还很大呢……而且,同一类别下两个图像的距离可能比不同类的距离还大,这种现象如果一直存在,那么在作者关注的人脸识别领域,会不会出现一个人脸和别人的脸太相似而被误刷啊……
4:不能这样子,于是作者设计了一个新的 loss 叫 center loss。我们给每个 label 的数据定义一个 center,大家要向 center 靠近,离得远的要受惩罚,于是 center loss 就出现了:
5:众人纷纷表示这个思路很好,但是这个 c 怎么定义呢?首先拍脑袋想到的就是在 batch 训练的时候不断地计算更新它,每一轮我们计算一下当前数据和 center 的距离,然后把这个距离以某种形式 - 就是梯度叠加到 center 上:
6:吃瓜群众立刻表示:每个 batch 的数据并不算多,这样更新会不会容易 center 产生抖动?数值上的不稳定在优化中是大忌啊!于是作者简单粗暴地讲:那我们加个 scale,让它不要太大:
这个 scale 肯定是小于 1 的。
7:吃瓜群众满意了,吃瓜子的群众有发话了:现在你有两个 loss 了,我们该怎么平衡这两个 loss 之间的权重呢?作者心想:你这不是明知故问么……于是又加了一个超参数,用于控制两个 loss 之间的比例。
反正多来个超参数无所谓,你们慢慢调去吧~
8:该定义的终于都结束了,我们加入新的 loss,训练之后得到了这个结果:
于是掌声雷动,这个效果看着确实不错啊~
一些细节
由于倒数第二层的特征维度被缩减成了 2,所以识别的精度肯定会受到些影响,不过这只是为了可视化的效果,所以在真正的实验中我们可以把这个数字调大。在我的实验中最终的 Test Accuracy 是 0.9888。比正常 LeNet 的 0.99 稍低一点。
直接修改 LeNet 倒数第二层的维度会造成无法训练,所以论文中的 LeNet++使用了 6 层卷积。对于 MNist 这样的小问题使用 6 层卷积也是没谁了,所以训练起来还是要费点时间的。
在我的实验中加入 center loss 后 Test Accuracy 实际上是下降了一点的。不过这点下降并不能说明 center loss 对这个问题起了反作用,后面还是需要尝试当倒数第二层的维度大于 2 时的情况。
干货来了
说了这么多废话,it's time to show me the code。链接: GitHub - hsmyy/CenterLoss_Caffe_Mnist: It's the script of Center loss on mnist dataset running on Caffe.
由于是快速尝试,只实现了 cpu 的版本,而且写得比较粗糙,望各位大神不吝赐教。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论