RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast
I am trying to implement U^2 Net for salient object detection. Since the original code is not optimised for training, I followed the official documentation for AMP (automatic mixed precision) and made some changes to the original code in my fork to check the effects.
Apart from those changes, I have used the code exactly as it is. You can reproduce the problem by running my version of the training code on Colab as:
! git clone https://github.com/deshwalmahesh/U-2-Net
%cd ./U-2-Net/
!python u2net_train.py
It throws an error; the whole stack trace is posted at the end. I dug into it and found that the error comes from the custom loss function muti_bce_loss_fusion, which the authors define as:
bce_loss = nn.BCELoss(size_average=True)
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0,labels_v)
    loss1 = bce_loss(d1,labels_v)
    loss2 = bce_loss(d2,labels_v)
    loss3 = bce_loss(d3,labels_v)
    loss4 = bce_loss(d4,labels_v)
    loss5 = bce_loss(d5,labels_v)
    loss6 = bce_loss(d6,labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    return loss0, loss
Also, the last line of the model definition (line 526) returns 7 sigmoid outputs, which are then passed to the loss function:
F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
Now what can be done to avoid this error?
Error trace
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
  warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3704: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
  File "u2net_train.py", line 148, in <module>
    loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
  File "u2net_train.py", line 33, in muti_bce_loss_fusion
    loss0 = bce_loss(d0,labels_v)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 612, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 3065, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
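The advice in the error message to fuse the sigmoid and the loss can be motivated with a small numeric sketch. The two helper functions below are hypothetical illustrations written in plain Python (double precision), not PyTorch's implementation: once the sigmoid saturates to exactly 1.0, the separate "sigmoid then BCE" path has to take log(0), whereas the fused formula works directly on the logit. In half precision the sigmoid saturates at far smaller logits, which is presumably why autocast refuses the unfused version.

```python
import math

def naive_bce(logit, target):
    """Sigmoid followed by binary cross entropy, computed as two separate steps."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def bce_with_logits(logit, target):
    """Fused, numerically stable form: max(x, 0) - x*z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# For moderate logits the two forms agree.
print(naive_bce(2.0, 1.0), bce_with_logits(2.0, 1.0))

# For a confidently wrong prediction the sigmoid rounds to 1.0 and the
# naive form needs log(0), which Python reports as a domain error.
try:
    naive_bce(40.0, 0.0)
    naive_failed = False
except ValueError:
    naive_failed = True
print("naive form failed:", naive_failed)
print("fused form:", bce_with_logits(40.0, 0.0))
```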
Comments (1)
The main reason was the numerically unstable nature of Sigmoid + BCE. Referring to the documentation and the torch community, all I had to do was change the model's return values from F.sigmoid(d0)... to d0... (raw logits), and then in turn replace nn.BCELoss(size_average=True) with nn.BCEWithLogitsLoss(size_average=True). Now the model is running fine.
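A minimal sketch of that change, under some assumptions: the name muti_bce_loss_fusion_logits and the random tensors standing in for the seven side outputs are made up for illustration, and reduction="mean" is used as the current spelling of the deprecated size_average=True. It is not the repository's actual code, only the shape of the fix.

```python
import torch
import torch.nn as nn

# reduction="mean" replaces the deprecated size_average=True.
bce_with_logits = nn.BCEWithLogitsLoss(reduction="mean")

def muti_bce_loss_fusion_logits(outputs, labels_v):
    """Hypothetical variant of the loss that expects raw logits (no sigmoid)."""
    losses = [bce_with_logits(d, labels_v) for d in outputs]
    return losses[0], torch.stack(losses).sum()

torch.manual_seed(0)
# Stand-ins for d0..d6 as returned by the model WITHOUT F.sigmoid applied.
logits = [torch.randn(2, 1, 8, 8) for _ in range(7)]
labels_v = torch.randint(0, 2, (2, 1, 8, 8)).float()

# Autocast is only enabled when a GPU is present; on CPU this runs in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = [d.to(device) for d in logits]
labels_v = labels_v.to(device)
with torch.autocast(device_type=device, enabled=(device == "cuda")):
    loss0, loss = muti_bce_loss_fusion_logits(logits, labels_v)

# Reference: the original sigmoid + BCELoss path in full precision.
old_bce = nn.BCELoss(reduction="mean")
old_loss = sum(old_bce(torch.sigmoid(d), labels_v) for d in logits)
print(float(loss), float(old_loss))
```

Because the model now returns logits rather than probabilities, any inference or visualisation code that relied on the old outputs would need to apply torch.sigmoid itself.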