Is this custom PyTorch loss function differentiable?
I have a custom forward implementation for a PyTorch loss. Training works well, and I've checked that loss.grad_fn is not None.

I'm trying to understand two things:

1. How can this function be differentiable, given that there is an if-else statement on the path from input to output?
2. Does the path from gt (the ground-truth input) to the loss (the output) need to be differentiable, or only the path from pred (the prediction input)?

Here is the source code:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        # Masks selecting positive (gt == 1) and negative (gt < 1) locations.
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()

        # Down-weight negatives whose ground-truth value is close to 1.
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()

        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos
        return loss
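For reference, a minimal sketch of the grad_fn check mentioned above (the shapes and values are made up; pred is clamped into (0, 1) because the loss takes log(pred) and log(1 - pred)):

import torch

# Hypothetical inputs: pred is a prediction in (0, 1), gt is a ground-truth heatmap.
pred = torch.rand(2, 3, 8, 8).clamp(1e-4, 1 - 1e-4).requires_grad_()
gt = torch.zeros(2, 3, 8, 8)
gt[0, 0, 4, 4] = 1.0

criterion = FocalLoss()
loss = criterion(pred, gt)
print(loss.grad_fn)            # not None, e.g. <DivBackward0 object at ...>
loss.backward()
print(pred.grad is not None)   # True: gradients flow back to pred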
The if statement is not part of the computational graph. It is part of the code used to build this graph dynamically (i.e. the forward function), but it is not itself part of it. The principle to follow is to ask yourself whether you can backtrack to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) using the grad_fn callbacks of each node, backpropagating through the graph. The answer is that you can only do so if each of the operators is differentiable: in programming terms, they implement a backward function operation (a.k.a. grad_fn).

In your example, whether or not num_pos is equal to 0, the resulting loss tensor will depend either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, however, the resulting loss tensor remains attached to the input pred: either through the "neg_loss_s" node alone, or through both the "pos_loss_s" and "neg_loss_s" nodes. In your setup, either way, the operation is differentiable; a small sketch below illustrates this.
If gt is a ground-truth tensor then it doesn't require gradients, and the operations from it to the final loss don't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable, because they are produced by boolean operators.
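A small sketch of that point (example values made up): after backward(), only pred receives a gradient; gt, and the boolean masks derived from it, never need one:

import torch

criterion = FocalLoss()
pred = torch.tensor([0.2, 0.8, 0.5], requires_grad=True)
gt = torch.tensor([0.0, 1.0, 0.0])   # ground truth: requires_grad is False, as usual

loss = criterion(pred, gt)
loss.backward()

print(pred.grad)        # a tensor: the path from pred to the loss is differentiable
print(gt.grad)          # None: no gradient is (or needs to be) computed for gt
print(gt.eq(1).dtype)   # torch.bool: the comparison output could not carry gradients anyway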
PyTorch does not compute gradients w.r.t. the loss function itself. PyTorch records the sequence of standard mathematical operations performed during the forward pass, such as log, exponentiation, multiplication, addition, etc., and computes gradients through exactly those operations when backward() is called. Thus, the presence of an if-else condition doesn't matter to PyTorch, provided you use only standard math operations to compute your loss.
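A toy illustration of this behaviour, unrelated to the focal loss itself: the Python if/else only decides which operations are executed, and autograd differentiates exactly the operations that were recorded:

import torch

def piecewise(x):
    # The if/else runs in Python at graph-building time; only the branch that
    # actually executes is recorded by autograd.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (3 * x).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
piecewise(x).backward()
print(x.grad)   # tensor([2., 4.]): gradient of the x**2 branch

x = torch.tensor([-1.0, -2.0], requires_grad=True)
piecewise(x).backward()
print(x.grad)   # tensor([3., 3.]): gradient of the 3*x branch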