自定义损失函数 IoU 是不可微分的。你能为 ML 创建一个可微分的 IoU 损失函数吗?

发布于 2025-01-20 17:16:13 字数 1694 浏览 1 评论 0原文

我目前使用CNN开发了峰值检测算法来确定理想的卷积内核,该内核可表示为理想的母波函数,它将最大程度地提高峰值检测精度。

我试图为CNN培训模型创建自己的IOU损失功能,但失败了。 我自己的损失功能如下所述。


'''
1D intersection over union loss function class
'''
class IoU(nn.Module):
  def __init__(self, thresh: float = 0.5):
    super().__init__()
    self.thresh = thresh

  def forward(self, inputs: torch.Tensor, targets:torch.Tensor, weights: Optional[torch.Tensor] = None, smooth: float = 0.0) -> Tensor:

    inputs = torch.where(inputs < self.thresh, 0, 1)
    batch_size = targets.shape[0]

    intersect = torch.logical_and(inputs, targets)
    intersect = intersect.view(batch_size, -1).sum(-1)

    union = torch.logical_and(inputs, targets)
    union = union.view(batch_size, -1).sum(-1)

    IoU = (intersect + smooth) / (union + smooth)
    IoU = IoU.mean()

    return IoU

我尝试通过这些简单模型如下测试是否有效。


x = torch.tensor(0, 1001, 256) # e.g. [0, 200, 30, 1000, ...]
true = torch.tensor(0, 2, 256) # e.g. [0, 1, 1, 0, 1, ...]

model = nn.Linear(256, 256)
criterion = IoU()

output = model(x)
loss = criterion(output, true)
loss.backward() # I'm stuck on here, cause my loss func IoU is not differentiable

print(f"loss ={loss}")
print(f"model weight: {model.weight.grad}")
print(f"model params: [x.grad for x in {model.parameters()]")

终端上的输出是RuntimeError:变量的元素0不需要毕业,并且没有Grad_fn

这个项目是我使用Pytorch的时间,所以我不知道它是什么乍一看,但是经过快速的研究,我弄清楚了为什么这种损失功能失败(虽然这是正确的),

但我的损失功能并非可区分。

就是此损失函数中断的链条规则中断的地方。

IoU = torch.nan_to_num(IoU)
IoU = IoU.mean()

注意到这一点后不久,我对GitHub或堆栈溢出进行了更深入的了解以查找任何其他可区分的损失功能,但是我仍然不确定如何创建一个可区分的IOU损失功能(尤其是对于一维数据)。

谢谢

I'm currently developing the peak detection algorithm using CNN to determine the ideal convolution kernel which is representable as the ideal mother wavelet function that will maximize the peak detection accuracy.

I've tried to create my own IoU loss function for the CNN training model, but I failed.
My own loss function is described as below.


'''
1D intersection over union loss function class
'''
class IoU(nn.Module):
  def __init__(self, thresh: float = 0.5):
    super().__init__()
    self.thresh = thresh

  def forward(self, inputs: torch.Tensor, targets:torch.Tensor, weights: Optional[torch.Tensor] = None, smooth: float = 0.0) -> Tensor:

    inputs = torch.where(inputs < self.thresh, 0, 1)
    batch_size = targets.shape[0]

    intersect = torch.logical_and(inputs, targets)
    intersect = intersect.view(batch_size, -1).sum(-1)

    union = torch.logical_and(inputs, targets)
    union = union.view(batch_size, -1).sum(-1)

    IoU = (intersect + smooth) / (union + smooth)
    IoU = IoU.mean()

    return IoU

and I tried to test whether this works or not by these simple model like below.


x = torch.tensor(0, 1001, 256) # e.g. [0, 200, 30, 1000, ...]
true = torch.tensor(0, 2, 256) # e.g. [0, 1, 1, 0, 1, ...]

model = nn.Linear(256, 256)
criterion = IoU()

output = model(x)
loss = criterion(output, true)
loss.backward() # I'm stuck on here, cause my loss func IoU is not differentiable

print(f"loss ={loss}")
print(f"model weight: {model.weight.grad}")
print(f"model params: [x.grad for x in {model.parameters()]")

And the output on the terminal is RuntimeError: element 0 of variables does not require grad and does not have a grad_fn

This project is the time ever for me to use PyTorch, so I didn't know what it meant at the first glance, but after my quick research, I figured out why this loss function fails (I'm not sure this is correct though)

my loss function IoU is not differentiable.

and

This is where the chain rule of this loss function break.

IoU = torch.nan_to_num(IoU)
IoU = IoU.mean()

Soon after I noticed this, I took a deeper look at the GitHub or stack overflow to find any other differentiable IoU loss function, but I'm still not sure how to create a differentiable IoU loss function (especially for 1D data).

Thank you

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文