如何在 PyTorch 中使用真实世界权重交叉熵损失

发布于 2025-01-14 19:15:01 字数 1202 浏览 2 评论 0原文

我正在研究多类分类,其中一些错误比其他错误更严重。因此,我想将成本纳入我的损失函数中。我在 Real-World-Weight Cross-Entropy 的名称下发现了这个,在本文中进行了描述。公式如下:

在此处输入图像描述

除了标准 CrossEntropyLossweight 参数之外,我还没有找到任何现成的实现,其中我相信效果完全不同到我的用例(据我所知,错误分类一个类别的成本是相同的,无论它与哪个类别混淆)。

我如何在 PyTorch 中应用它?

import torch.nn as nn
import torch

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

cost_matrix = torch.zeros((5, 5))
cost_matrix[1, 0] = 0.4
cost_matrix[2, 0] = 0.1
cost_matrix[2, 1] = 0.9
cost_matrix[3, 0] = 0.4
cost_matrix[3, 1] = 0.9
cost_matrix[3, 2] = 0.1
cost_matrix[4, 0] = 0.1
cost_matrix[4, 1] = 0.4
cost_matrix[4, 2] = 0.9
cost_matrix[4, 3] = 0.1
cost_matrix[0, 1] = 0.4
cost_matrix[0, 2] = 0.1
cost_matrix[1, 2] = 0.9
cost_matrix[0, 3] = 0.4
cost_matrix[1, 3] = 0.9
cost_matrix[2, 3] = 0.1
cost_matrix[0, 4] = 0.1
cost_matrix[1, 4] = 0.4
cost_matrix[2, 4] = 0.9
cost_matrix[3, 4] = 0.1
 

I'm working on multiclass classification where some mistakes are more severe than others. Therefore, I would like to incorporate the costs into my loss function. I found this under the name Real-World-Weight Cross-Entropy, described in this paper. The formula goes as below:

enter image description here

I haven't find any ready-to-use implementation, apart from weight argument of standard CrossEntropyLoss, which I believe works quite different to my use-case (as far as I understand the cost of incorrectly classifying one category is the same no matter with which category it was confused).

How can I apply this in PyTorch?

import torch.nn as nn
import torch

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

cost_matrix = torch.zeros((5, 5))
cost_matrix[1, 0] = 0.4
cost_matrix[2, 0] = 0.1
cost_matrix[2, 1] = 0.9
cost_matrix[3, 0] = 0.4
cost_matrix[3, 1] = 0.9
cost_matrix[3, 2] = 0.1
cost_matrix[4, 0] = 0.1
cost_matrix[4, 1] = 0.4
cost_matrix[4, 2] = 0.9
cost_matrix[4, 3] = 0.1
cost_matrix[0, 1] = 0.4
cost_matrix[0, 2] = 0.1
cost_matrix[1, 2] = 0.9
cost_matrix[0, 3] = 0.4
cost_matrix[1, 3] = 0.9
cost_matrix[2, 3] = 0.1
cost_matrix[0, 4] = 0.1
cost_matrix[1, 4] = 0.4
cost_matrix[2, 4] = 0.9
cost_matrix[3, 4] = 0.1
 

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

冬天的雪花 2025-01-21 19:15:01

我建议您研究这个笔记本,它包含您问题中指出的论文作者的完整实现:
真实世界-权重交叉熵损失函数

I recommend you investigating this notebook, it contains the full implementation by the author of the paper indicated in your question:
The-Real-World-Weight-Crossentropy-Loss-Function.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文