返回介绍

CenterLoss - 实战&源码

发布于 2025-02-25 23:05:00 字数 3053 浏览 0 评论 0 收藏 0

本来要继续连载 CRF 的,但是最近看到了知乎上各位大神的论文解读 - 《A Discriminative Feature Learning Approach for Deep Face Recognition》,其中介绍了他们设计的一个用户提高类别的区分度的损失函数 - Center Loss,并且有一个基于 MNist 数据库的小实验,于是本人就简单实践了一下。

本着其他大神已经深度解析了这篇文章的内容,所以我觉得我再凑热闹把那些理论的东西讲一遍意义不大,倒不如抡起袖子干点实事 - 把代码写出来跑一跑。在我完成实验前,已经看到 github 上有人完成了 MXNet 的 center loss 代码,所以作为 Caffe 派的革命小斗士,我得抓紧时间了。

不说废话,先用一个 8 拍把问题说明白:

一个 8 拍的 center loss

1:同样是分类问题,我们之前只关注了待识别的图像应该属于哪个类别,但是并没有关心一个同样重要的问题:最终分类器的分界面区域内的空间是不是都应该属于这个类别?空间内这些长得很像的图像,它们的特征会不会其实差距有点大?

2:于是乎作者设计了一个网络,在倒数第二层全连接层输出了一个 2 维的特征向量,并以此进行进一步的分类。我们把 MNist 的 Test 集合数据通过最终训练好的模型进行预测,倒数第二层的样子是这样的:

嗯,貌似和论文上的图像差不多啊……

3:于是乎,作者就觉得,怎么每个类别的特征都是长长的一条啊?那么实际上同一个类内部的差距还很大呢……而且,同一类别下两个图像的距离可能比不同类的距离还大,这种现象如果一直存在,那么在作者关注的人脸识别领域,会不会出现一个人脸和别人的脸太相似而被误刷啊……

4:不能这样子,于是作者设计了一个新的 loss 叫 center loss。我们给每个 label 的数据定义一个 center,大家要向 center 靠近,离得远的要受惩罚,于是 center loss 就出现了:

CenterLoss=\frac{1}{2N}\sum_{i=1}^N|x_i-c|^2_2

5:众人纷纷表示这个思路很好,但是这个 c 怎么定义呢?首先拍脑袋想到的就是在 batch 训练的时候不断地计算更新它,每一轮我们计算一下当前数据和 center 的距离,然后把这个距离以某种形式 - 就是梯度叠加到 center 上:

\frac{\partial CenterLoss}{\partial x}=\frac{1}{N}\sum_{i=1}^N(c-x_i)

6:吃瓜群众立刻表示:每个 batch 的数据并不算多,这样更新会不会容易 center 产生抖动?数值上的不稳定在优化中是大忌啊!于是作者简单粗暴地讲:那我们加个 scale,让它不要太大:

\Delta c = \frac{\alpha}{N} \sum_{i=1}^N(c-x_i)

这个 scale 肯定是小于 1 的。

7:吃瓜群众满意了,吃瓜子的群众有发话了:现在你有两个 loss 了,我们该怎么平衡这两个 loss 之间的权重呢?作者心想:你这不是明知故问么……于是又加了一个超参数\lambda ,用于控制两个 loss 之间的比例。

反正多来个超参数无所谓,你们慢慢调去吧~

8:该定义的终于都结束了,我们加入新的 loss,训练之后得到了这个结果:

于是掌声雷动,这个效果看着确实不错啊~

一些细节

由于倒数第二层的特征维度被缩减成了 2,所以识别的精度肯定会受到些影响,不过这只是为了可视化的效果,所以在真正的实验中我们可以把这个数字调大。在我的实验中最终的 Test Accuracy 是 0.9888。比正常 LeNet 的 0.99 稍低一点。

直接修改 LeNet 倒数第二层的维度会造成无法训练,所以论文中的 LeNet++使用了 6 层卷积。对于 MNist 这样的小问题使用 6 层卷积也是没谁了,所以训练起来还是要费点时间的。

在我的实验中加入 center loss 后 Test Accuracy 实际上是下降了一点的。不过这点下降并不能说明 center loss 对这个问题起了反作用,后面还是需要尝试当倒数第二层的维度大于 2 时的情况。

干货来了

说了这么多废话,it's time to show me the code。链接: GitHub - hsmyy/CenterLoss_Caffe_Mnist: It's the script of Center loss on mnist dataset running on Caffe.

由于是快速尝试,只实现了 cpu 的版本,而且写得比较粗糙,望各位大神不吝赐教。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文