Is this custom PyTorch loss function differentiable?

Posted 2025-01-30 11:08:40


I have a custom forward implementation for a PyTorch loss. The training works well. I've checked the loss.grad_fn and it is not None.
I'm trying to understand two things:

  1. How can this function be differentiable, given that there is an if-else statement on the path from the input to the output?

  2. Does the path from gt (the ground-truth input) to loss (the output) need to be differentiable? Or only the path from pred (the prediction input)?

Here is the source code:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        # Masks selecting positive (gt == 1) and negative (gt < 1) locations.
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        # Down-weight negatives whose ground-truth value is close to 1.
        neg_weights = torch.pow(1 - gt, 4)

        # Focal terms; each element contributes to only one of the two sums.
        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()
        # Normalise by the number of positives unless there are none.
        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos

        return loss
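
For reference, the grad_fn check mentioned above can be reproduced with a short snippet like the one below; the shapes and values are arbitrary placeholders, and the snippet assumes the FocalLoss class defined above:

import torch

# Placeholder inputs: pred must require gradients; gt is a plain target in [0, 1].
pred = torch.rand(4, 4).clamp(1e-4, 1 - 1e-4).requires_grad_(True)
gt = torch.zeros(4, 4)
gt[0, 0] = 1.0  # one "positive" location

loss = FocalLoss()(pred, gt)
print(loss.grad_fn)            # not None: the loss is attached to the autograd graph
loss.backward()
print(pred.grad is not None)   # True: gradients flow back to pred
print(gt.requires_grad)        # False: nothing is tracked for the ground truth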

Comments (2)

记忆里有你的影子 2025-02-06 11:08:40

The if statement is not part of the computational graph. It is part of the code used to build this graph dynamically (i.e. the forward function), but it is not itself part of the graph. The principle to follow is to ask yourself whether you can backtrack from the loss to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) using the grad_fn callback of each node, backpropagating through the graph. The answer is that you can only do so if each of the operators is differentiable: in programming terms, they implement a backward operation (a.k.a. grad_fn).

  1. In your example, whether or not num_pos is equal to 0, the resulting loss tensor will depend either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:

    • in one case, through the "neg_loss_s" node;
    • in the other, through the "pos_loss_s" and "neg_loss_s" nodes.

In your setup, either way, the operation is differentiable.

  2. If gt is a ground-truth tensor, then it doesn't require gradients, and the path from it to the final loss doesn't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable because they come from boolean comparison operators.
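
To make the graph-construction point concrete, here is a small self-contained sketch (the tensors and the toy branch below are invented for illustration, not taken from the question): only the branch that actually executes is recorded, and the masks derived from gt carry no gradient.

import torch

pred = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
gt = torch.tensor([0.0, 1.0, 0.0])        # ground truth: no gradient required

pos_inds = gt.eq(1).float()               # boolean op: a constant as far as autograd is concerned
print(pos_inds.requires_grad)             # False

# Python control flow: only the branch that runs ends up in the graph.
if pos_inds.sum() == 0:
    loss = -(torch.log(1 - pred) * pred.pow(2)).sum()
else:
    loss = -(torch.log(pred) * (1 - pred).pow(2) * pos_inds).sum() / pos_inds.sum()

print(loss.grad_fn)                       # e.g. <DivBackward0 ...>: the graph of the taken branch
loss.backward()
print(pred.grad)                          # gradients reach pred only through that branch
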
向日葵 2025-02-06 11:08:40


PyTorch does not compute gradients w.r.t. the loss function itself. PyTorch records the sequence of standard mathematical operations performed during the forward pass, such as log, exponentiation, multiplication, addition, etc., and computes gradients through those operations when backward() is called. Thus, the presence of if-else conditions doesn't matter to PyTorch, provided you use only standard math operations to compute your loss.
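
A tiny made-up example of that idea (toy_loss and the tensor values below are purely illustrative): the Python if only decides which operations run; each forward pass records exactly the operations it executed, and backward() replays that recording.

import torch

def toy_loss(pred, gt):
    # The if itself is never recorded; it only selects which ops run.
    if gt.sum() == 0:
        return (pred ** 2).sum()            # records: pow, sum
    return (pred ** 2).sum() / gt.sum()     # records: pow, sum, div

pred = torch.tensor([0.3, 0.8], requires_grad=True)

loss_a = toy_loss(pred, torch.tensor([0.0, 0.0]))
loss_b = toy_loss(pred, torch.tensor([0.0, 1.0]))

print(loss_a.grad_fn)   # e.g. <SumBackward0 ...>: graph from the first branch
print(loss_b.grad_fn)   # e.g. <DivBackward0 ...>: graph from the second branch

loss_a.backward()
grad_a = pred.grad.clone()
pred.grad = None
loss_b.backward()
print(grad_a, pred.grad)  # both graphs backpropagate without issue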
