Structural Similarity Loss implementation in PyTorch gives NaN


I decided to write my loss function, Structural Similarity Loss, based on this article:
https://arxiv.org/pdf/1910.08711.pdf

While testing different segmentation models with different losses, I ran into a problem: sometimes there is numerical instability, and my self-written SegNet model starts producing NaN during training, so the loss becomes NaN as well. With other losses (BCE, Dice loss, focal loss) everything is stable. After printing the variables in detail, I found that the loss value right before y_pred becomes NaN is still reasonable, so my assumption is that the loss gradients are computed incorrectly, but it's not clear how to fix that.

import torch
import torchvision.transforms as T

# device is assumed to be defined globally, e.g. device = torch.device("cuda")

def ssl_loss(y_real, y_pred, window_size=11, eps=0.01):
    beta = 0.1
    Lambda = 0.5

    # input size (B, C, H, W)
    # C = 1, because we compare monochrome segmentations
    y_real, y_pred = y_real.to(device).squeeze(), y_pred.to(device).squeeze()

    # per-pixel BCE with logits, written out explicitly
    bce_matrix = (y_pred - y_real * y_pred + torch.log(1 + torch.exp(-y_pred)))

    y_pred = torch.sigmoid(y_pred)

    # local mean and variance via Gaussian smoothing
    blurer = T.GaussianBlur(kernel_size=(11, 11), sigma=(1.5, 1.5))

    mu_y = blurer(y_real)
    sigma_y = blurer((y_real - mu_y) ** 2)

    mu_p = blurer(y_pred)
    sigma_p = blurer((y_pred - mu_p) ** 2)

    # locally normalised difference between target and prediction
    errors = torch.abs((y_real - mu_y + eps) / (torch.sqrt(sigma_y) + eps) - (y_pred - mu_p + eps) / (torch.sqrt(sigma_p) + eps)).squeeze()

    # mask of pixels whose error exceeds beta * max error
    f_n_c = (errors > beta * errors.max()).int()

    # number of selected pixels per image
    M = f_n_c.sum(dim=(1, 2)).unsqueeze(1).unsqueeze(2)

    ssl_matrix = (errors * f_n_c * bce_matrix / M)

    loss = Lambda * bce_matrix.mean() + (1 - Lambda) * ssl_matrix.mean()

    return loss
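
For reference, here is a minimal sketch of the same computation with the terms that look numerically fragile rewritten. This is only an assumption about where the NaN could come from, not a confirmed fix: torch.log(1 + torch.exp(-y_pred)) overflows to inf for large negative logits, torch.sqrt has an infinite gradient at zero local variance, and M can be zero when no pixel passes the threshold. The function ssl_loss_stable below is hypothetical and not the paper's reference code.

import torch
import torch.nn.functional as F
import torchvision.transforms as T

def ssl_loss_stable(y_real, y_pred, window_size=11, eps=0.01, beta=0.1, Lambda=0.5):
    # same structure as ssl_loss above, with the suspected NaN sources rewritten
    y_real, y_pred = y_real.squeeze(1), y_pred.squeeze(1)  # (B, 1, H, W) -> (B, H, W)

    # softplus(-x) == log(1 + exp(-x)) but does not overflow for large |x|
    bce_matrix = y_pred - y_real * y_pred + F.softplus(-y_pred)

    p = torch.sigmoid(y_pred)

    blurer = T.GaussianBlur(kernel_size=(window_size, window_size), sigma=(1.5, 1.5))
    mu_y, mu_p = blurer(y_real), blurer(p)
    sigma_y = blurer((y_real - mu_y) ** 2)
    sigma_p = blurer((p - mu_p) ** 2)

    # eps inside the sqrt keeps d/dx sqrt(x) = 1 / (2 * sqrt(x)) finite at x = 0
    errors = torch.abs((y_real - mu_y) / torch.sqrt(sigma_y + eps)
                       - (p - mu_p) / torch.sqrt(sigma_p + eps))

    f_n_c = (errors > beta * errors.max()).float()

    # clamp avoids 0 / 0 when no pixel passes the threshold (e.g. errors is all zeros)
    M = f_n_c.sum(dim=(1, 2), keepdim=True).clamp(min=1)

    # errors is detached here so only the BCE term carries gradients; this is a way to
    # check whether the gradients through the error map are what destabilises training
    ssl_matrix = errors.detach() * f_n_c * bce_matrix / M

    return Lambda * bce_matrix.mean() + (1 - Lambda) * ssl_matrix.mean()

If training still diverges with this variant, the instability is probably coming from the model itself rather than from the loss.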

And here is the meaningful part of my training function:

    for epoch in range(epochs):
        avg_loss = 0
        model.train()
        for X_batch, Y_batch in data_tr:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            opt.zero_grad()

            Y_pred = model(X_batch)

            loss = loss_fn(Y_batch, Y_pred)
            loss.backward()
            opt.step()

            # .item() detaches the scalar so the graph of every batch is not kept in memory
            avg_loss += loss.item() / len(data_tr)

        scheduler.step()
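
And a hedged sketch of the same loop with two standard PyTorch debugging aids added: torch.autograd.set_detect_anomaly(True) reports the backward op that first produced a NaN, and torch.nn.utils.clip_grad_norm_ caps the gradient norm so a single extreme batch cannot blow up the weights. The max_norm value below is just an illustrative guess.

    # slow, but reports the backward op that first produced NaN/inf; debugging only
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(epochs):
        avg_loss = 0
        model.train()
        for X_batch, Y_batch in data_tr:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

            opt.zero_grad()
            Y_pred = model(X_batch)
            loss = loss_fn(Y_batch, Y_pred)
            loss.backward()

            # cap the gradient norm so one bad batch cannot destroy the weights;
            # max_norm=1.0 is an arbitrary illustrative value
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            avg_loss += loss.item() / len(data_tr)

        scheduler.step()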
     
