Structural Similarity Loss implementation in PyTorch gives NaN
I decided to write my own loss function, Structural Similarity Loss, following the article
https://arxiv.org/pdf/1910.08711.pdf
While testing different segmentation models and different losses for them, I ran into a problem: sometimes there is numerical instability, and my self-written SegNet model produces NaN during training, so the loss becomes NaN as well. With other losses (BCE, dice loss, focal loss) everything is stable. After printing out the variables in detail, I found that the loss value is still reasonable right before y_pred becomes NaN, so my assumption is that the loss gradients are computed incorrectly, but it's not clear how to fix it.
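A minimal sketch of what I suspect is happening (toy tensors, not my real data): torch.sqrt has an unbounded gradient at zero, and when that meets a zero factor in the chain rule the result is NaN even though the forward loss is perfectly finite — which would match "loss looks fine, then predictions go NaN".

```python
import torch

# Toy reproduction of the suspected failure mode (my assumption):
# the forward value is finite, but backprop produces 0 * inf = NaN
# because d/dx sqrt(x) is unbounded at x = 0.
x = torch.zeros(1, requires_grad=True)
y = (torch.sqrt(x) * x).sum()  # forward: 0.0, perfectly finite
y.backward()
print(y.item())  # 0.0
print(x.grad)    # tensor([nan])
```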
import torch
import torchvision.transforms as T

def ssl_loss(y_real, y_pred, window_size=11, eps=0.01):
    beta = 0.1
    Lambda = 0.5
    # input size (B, C, H, W)
    # C = 1, because we compare monochrome segmentations
    y_real, y_pred = y_real.to(device).squeeze(), y_pred.to(device).squeeze()
    bce_matrix = (y_pred - y_real * y_pred + torch.log(1 + torch.exp(-y_pred)))
    y_pred = torch.sigmoid(y_pred)
    blurer = T.GaussianBlur(kernel_size=(window_size, window_size), sigma=(1.5, 1.5))
    mu_y = blurer(y_real)
    sigma_y = blurer((y_real - mu_y) ** 2)
    mu_p = blurer(y_pred)
    sigma_p = blurer((y_pred - mu_p) ** 2)
    errors = torch.abs(
        (y_real - mu_y + eps) / (torch.sqrt(sigma_y) + eps)
        - (y_pred - mu_p + eps) / (torch.sqrt(sigma_p) + eps)
    ).squeeze()
    f_n_c = (errors > beta * errors.max()).int()
    M = f_n_c.sum(dim=(1, 2)).unsqueeze(1).unsqueeze(2)
    ssl_matrix = (errors * f_n_c * bce_matrix / M)
    loss = Lambda * bce_matrix.mean() + (1 - Lambda) * ssl_matrix.mean()
    return loss
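In ssl_loss the blurred variances sigma_y and sigma_p can be exactly zero in flat regions of the mask, and writing torch.sqrt(sigma) + eps only protects the forward division, not the gradient of the sqrt itself. A sketch of the difference (same eps value as in my function; toy scalars, not my real tensors):

```python
import torch

eps = 0.01  # same eps as in ssl_loss

# eps added outside the sqrt: forward is safe, but the sqrt's own
# gradient at zero is still unbounded
s1 = torch.zeros(1, requires_grad=True)
(torch.sqrt(s1) + eps).sum().backward()
print(s1.grad)  # tensor([inf])

# eps moved inside the sqrt: the gradient stays finite
s2 = torch.zeros(1, requires_grad=True)
torch.sqrt(s2 + eps).sum().backward()
print(s2.grad)  # 0.5 / sqrt(0.01) = tensor([5.])
```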
And here is the meaningful part of my train function:
for epoch in range(epochs):
    avg_loss = 0
    model.train()
    for X_batch, Y_batch in data_tr:
        X_batch = X_batch.to(device)
        Y_batch = Y_batch.to(device)
        opt.zero_grad()
        Y_pred = model(X_batch)
        loss = loss_fn(Y_batch, Y_pred)
        loss.backward()
        opt.step()
        avg_loss += loss / len(data_tr)
    scheduler.step()
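While debugging, one generic guard I could add around opt.step() is to skip the update whenever any gradient is non-finite, so a single bad batch doesn't poison the weights (a sketch with a stand-in linear model and optimizer, not my real SegNet or training data):

```python
import torch

# Stand-in model/optimizer for illustration only
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
opt.zero_grad()
loss.backward()

# Only step if every gradient is finite
grads_finite = all(p.grad is None or torch.isfinite(p.grad).all()
                   for p in model.parameters())
if grads_finite:
    opt.step()
print(grads_finite)  # True for this well-behaved toy loss
```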