Flattening the input for nn.MSELoss()

Posted 2025-01-11 21:12:11

Here's a screenshot of a YouTube video implementing the loss function from the original YOLOv1 research paper. [Screenshot: the YOLOv1 loss implementation.]

What I don't understand is the need for torch.flatten() when passing the input to self.mse(), which is, in fact, nn.MSELoss().

The video just says that nn.MSELoss() expects its input in the shape (a, b), but I specifically don't understand how or why.

Video link, just in case. [For reference, N is the batch size and S is the grid size (split size).]
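
For context, the pattern being asked about looks roughly like this. This is a minimal sketch of my own, not the video's code; the trailing dimension of 4 for the box coordinates is an assumption:

import torch

# N is the batch size, S is the grid (split) size, per the question's notation
N, S = 2, 7
mse = torch.nn.MSELoss(reduction="sum")

# Hypothetical predictions and targets of shape (N, S, S, 4)
predictions = torch.randn((N, S, S, 4))
targets = torch.randn((N, S, S, 4))

# Flatten to shape (N * S * S, 4) before handing the tensors to MSELoss
loss = mse(
    torch.flatten(predictions, end_dim=-2),
    torch.flatten(targets, end_dim=-2),
)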


Comments (2)

抠脚大汉 2025-01-18 21:12:11

It helps to go back to the definitions. What is MSE? What is it computing?

MSE = mean squared error.

Here is rough Pythonic pseudocode to illustrate:

def mse(data, labels):
    total = 0
    for (x, y) in zip(data, labels):  # pair up predictions and targets
        total += (x - y) ** 2
    return total / len(labels)  # the average squared difference
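
As a quick sanity check (my own sketch, not part of the original answer), the loop above agrees with torch.nn.MSELoss on flat inputs; the sample numbers are made up:

import torch

data = [1.0, 2.0, 3.0]
labels = [1.5, 2.0, 2.0]

# Squared differences are 0.25, 0.0, 1.0, so the mean is 1.25 / 3 ≈ 0.4167
print(mse(data, labels))  # uses the mse() defined above
print(torch.nn.MSELoss()(torch.tensor(data), torch.tensor(labels)).item())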

For each pair of entries, it subtracts one value from the other, squares the difference, and returns the average (mean) over all pairs.
To rephrase the question: how would you interpret MSE without flattening? MSE, as described and implemented here, doesn't mean anything for higher dimensions. If you want to work with matrix outputs, you can use other loss functions, such as norms of the output matrices.
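
For instance, one matrix-level alternative would be to penalize the Frobenius norm of the difference matrix. This is my own hedged sketch, not the answerer's code:

import torch

pred = torch.randn((7, 7))
target = torch.randn((7, 7))

# Frobenius norm of the difference: sqrt of the sum of squared entries,
# a single number summarizing how far apart the two matrices are
loss = torch.linalg.matrix_norm(pred - target, ord="fro")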

Anyway, I hope that answers your question about why the flattening is needed.

怪我入戏太深 2025-01-18 21:12:11

I had the same question, so I tried different end_dim values, like this:

import torch

data = torch.randn((1, 7, 7, 4))
target = torch.randn((1, 7, 7, 4))

loss = torch.nn.MSELoss(reduction="sum")

# Flatten everything up to the last dim: shapes become (1*7*7, 4)
object_loss = loss(
    torch.flatten(data, end_dim=-2),
    torch.flatten(target, end_dim=-2),
)
# Flatten everything up to the last two dims: shapes become (1*7, 7, 4)
object_loss1 = loss(
    torch.flatten(data, end_dim=-3),
    torch.flatten(target, end_dim=-3),
)
print(object_loss)
print(object_loss1)

I got the same result both times, so I think the flattening just helps to interpret MSE.
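
One extra check worth adding (my addition, not the original commenter's): the equality also holds with the default reduction="mean", because flattening does not change the total number of elements, so the divisor of the mean is the same either way:

import torch

data = torch.randn((1, 7, 7, 4))
target = torch.randn((1, 7, 7, 4))

loss = torch.nn.MSELoss()  # default reduction="mean"

# numel() is identical before and after flattening, so the means match
print(loss(torch.flatten(data, end_dim=-2), torch.flatten(target, end_dim=-2)))
print(loss(data, target))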
