请问下torch loss function输入的维度,如何去扩展高维?

发布于 2022-09-12 23:06:01 字数 1523 浏览 35 评论 0

前提:

b_xent=torch.nn.CrossEntropyLoss()

不报错案例

没问题1

a=torch.tensor([[0.02,0.3],[0.3,0.3],[0.3,0.3]])
b=torch.tensor([0,1,1])
b_xent(a,b)

Out[3]: tensor(0.7680)

报错案例

若直接扩展维度,报错

a=a.unsqueeze(0)
b=b.unsqueeze(0)
a.shape  # Out[20]: torch.Size([1, 3, 2])
b.shape  # Out[21]: torch.Size([1, 3])
b_xent(a,b)

维度出错2

a=torch.randn((1,273,512))
b=torch.randn((1,273))
b_xent(a,b)

报错内容:

Traceback (most recent call last):
  File "D:\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-b7d4d1dd28a0>", line 3, in <module>
    b_xent(a,b)
  File "D:\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\lib\site-packages\torch\nn\modules\loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "D:\lib\site-packages\torch\nn\functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\lib\site-packages\torch\nn\functional.py", line 2273, in nll_loss
    raise ValueError('Expected target size {}, got {}'.format(
ValueError: Expected target size (1, 512), got torch.Size([1, 273])

那么如何在input和label都升高维度,并且使用CELoss呢?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

野生奥特曼 2022-09-19 23:06:01

在计算交叉熵损失函数前,一般都是使用view()把输入的input拉成[m,c],c为分类,label拉成一维,再计算交叉熵,有点不太明白为什么要升高维度,可以看一下pytorh中关于交叉熵损失函数的介绍
交叉熵损失函数

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文