如何在 PyTorch 中使用真实世界权重交叉熵损失
我正在研究多类分类,其中一些错误比其他错误更严重。因此,我想将成本纳入我的损失函数中。我在 Real-World-Weight Cross-Entropy 的名称下发现了这个,在本文中进行了描述。公式如下:
除了标准 CrossEntropyLoss
的 weight
参数之外,我还没有找到任何现成的实现,其中我相信效果完全不同到我的用例(据我所知,错误分类一个类别的成本是相同的,无论它与哪个类别混淆)。
我如何在 PyTorch 中应用它?
import torch.nn as nn
import torch
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
cost_matrix = torch.zeros((5, 5))
cost_matrix[1, 0] = 0.4
cost_matrix[2, 0] = 0.1
cost_matrix[2, 1] = 0.9
cost_matrix[3, 0] = 0.4
cost_matrix[3, 1] = 0.9
cost_matrix[3, 2] = 0.1
cost_matrix[4, 0] = 0.1
cost_matrix[4, 1] = 0.4
cost_matrix[4, 2] = 0.9
cost_matrix[4, 3] = 0.1
cost_matrix[0, 1] = 0.4
cost_matrix[0, 2] = 0.1
cost_matrix[1, 2] = 0.9
cost_matrix[0, 3] = 0.4
cost_matrix[1, 3] = 0.9
cost_matrix[2, 3] = 0.1
cost_matrix[0, 4] = 0.1
cost_matrix[1, 4] = 0.4
cost_matrix[2, 4] = 0.9
cost_matrix[3, 4] = 0.1
I'm working on multiclass classification where some mistakes are more severe than others. Therefore, I would like to incorporate the costs into my loss function. I found this under the name Real-World-Weight Cross-Entropy, described in this paper. The formula goes as below:
I haven't find any ready-to-use implementation, apart from weight
argument of standard CrossEntropyLoss
, which I believe works quite different to my use-case (as far as I understand the cost of incorrectly classifying one category is the same no matter with which category it was confused).
How can I apply this in PyTorch?
import torch.nn as nn
import torch
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
cost_matrix = torch.zeros((5, 5))
cost_matrix[1, 0] = 0.4
cost_matrix[2, 0] = 0.1
cost_matrix[2, 1] = 0.9
cost_matrix[3, 0] = 0.4
cost_matrix[3, 1] = 0.9
cost_matrix[3, 2] = 0.1
cost_matrix[4, 0] = 0.1
cost_matrix[4, 1] = 0.4
cost_matrix[4, 2] = 0.9
cost_matrix[4, 3] = 0.1
cost_matrix[0, 1] = 0.4
cost_matrix[0, 2] = 0.1
cost_matrix[1, 2] = 0.9
cost_matrix[0, 3] = 0.4
cost_matrix[1, 3] = 0.9
cost_matrix[2, 3] = 0.1
cost_matrix[0, 4] = 0.1
cost_matrix[1, 4] = 0.4
cost_matrix[2, 4] = 0.9
cost_matrix[3, 4] = 0.1
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
data:image/s3,"s3://crabby-images/d5906/d59060df4059a6cc364216c4d63ceec29ef7fe66" alt="扫码二维码加入Web技术交流群"
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
我建议您研究这个笔记本,它包含您问题中指出的论文作者的完整实现:
真实世界-权重交叉熵损失函数。
I recommend you investigating this notebook, it contains the full implementation by the author of the paper indicated in your question:
The-Real-World-Weight-Crossentropy-Loss-Function.