I am running into a gradient computation in-place error

Posted 2025-01-19 23:25:44


I am running this code (https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb) on a custom dataset, but I am running into this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Error Message

Please refer to the code link above for clarification of where the error is occurring.

I am running this model on a custom dataset; the data loader part is pasted below.

    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        # transforms.RandomResizedCrop(256),
        # transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
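As a quick sanity check (a minimal sketch in plain Python, not part of the notebook): `Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])` applies `(x - mean) / std` per channel, so it maps the `[0, 1]` range produced by `ToTensor` to `[-1, 1]`, the output range of a `tanh`-activated generator.

```python
# What Normalize(mean=0.5, std=0.5) does to one value already in [0, 1].
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black)
print(normalize(0.5))  #  0.0 (mid-gray)
print(normalize(1.0))  #  1.0 (white)
```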

    import os

    import torch
    from PIL import Image
    from torch.utils.data import Dataset

    class Flare(Dataset):
        def __init__(self, flare_dir, wf_dir, transform=None):
            self.flare_dir = flare_dir
            self.wf_dir = wf_dir
            self.transform = transform
            self.flare_img = os.listdir(flare_dir)
            self.wf_img = os.listdir(wf_dir)

        def __len__(self):
            return len(self.flare_img)

        def __getitem__(self, idx):
            f_img = Image.open(os.path.join(self.flare_dir, self.flare_img[idx])).convert("RGB")
            # Match the flare image to its flare-free counterpart by file name
            for i in self.wf_img:
                if self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]:
                    wf_img = Image.open(os.path.join(self.wf_dir, i)).convert("RGB")
                    break
            f_img = self.transform(f_img)
            wf_img = self.transform(wf_img)
            return f_img, wf_img

    flare_dir = '../input/flaredataset/Flare/Flare_img'
    wf_dir = '../input/flaredataset/Flare/Without_Flare_'
    flare_img = os.listdir(flare_dir)
    wf_img = os.listdir(wf_dir)
    wf_img.sort()
    flare_img.sort()
    print(wf_img[0])
    train_ds = Flare(flare_dir, wf_dir, train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_ds,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

To get a better idea of the dataset class, you can compare it with the one in the notebook linked above.
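As an aside (a sketch under the assumption that the file names really follow the scheme the loop above implies, i.e. a flare name is a 4-character prefix plus the flare-free base name): the linear scan in `__getitem__` runs for every sample on every epoch, so the pairing could instead be precomputed once, e.g. in `__init__`:

```python
# Hypothetical helper: build the flare -> without-flare file mapping once,
# assuming flare names are "<4-char prefix><base>.<ext>" and flare-free
# names are "<base>.<ext>" (same matching rule as the loop above).
def build_pairs(flare_files, wf_files):
    wf_by_base = {f.split('.')[0]: f for f in wf_files}
    return {f: wf_by_base[f.split('.')[0][4:]]
            for f in flare_files
            if f.split('.')[0][4:] in wf_by_base}

pairs = build_pairs(["img_cat.png", "img_dog.png"], ["cat.png", "dog.png"])
print(pairs)  # {'img_cat.png': 'cat.png', 'img_dog.png': 'dog.png'}
```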


Comments (1)

婴鹅 2025-01-26 23:25:44


Your code is failing during the backpropagation step of your GAN.

Your backward pass is currently defined like this:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph=True)
        self.dis_optimizer.step()

        unet_loss.backward()
        self.unet_optimizer.step()

So in your backward pass, you first propagate dis_loss, which combines the discriminator and adversarial losses, and then propagate unet_loss, which combines the UNet, SSIM, and content losses. However, unet_loss is connected to the discriminator's output. Because you take the optimizer step for dis_loss before unet_loss has run its backward pass, the discriminator's weights are modified in place while autograd still needs them, and PyTorch raises this error. I would recommend changing the code as follows:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph=True)
        unet_loss.backward()

        self.dis_optimizer.step()
        self.unet_optimizer.step()

This will get your training started, but you can still experiment with retain_graph=True.
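To illustrate the ordering issue outside the notebook (a minimal sketch with toy linear models standing in for BPPNet's actual networks): running both backward() calls before either optimizer step leaves every saved tensor untouched until all gradients exist, which is exactly what the corrected code above does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)   # stand-in for the UNet/generator
dis = nn.Linear(4, 1)   # stand-in for the discriminator
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
dis_opt = torch.optim.SGD(dis.parameters(), lr=0.1)

x = torch.randn(2, 4)
fake = gen(x)
dis_loss = dis(fake.detach()).mean()  # discriminator-only loss
gen_loss = -dis(fake).mean()          # generator loss flows through dis

# Safe ordering: finish both backward passes, then step both optimizers.
# Stepping dis_opt between the two backward() calls would modify
# dis.weight in place while gen_loss still needs it, triggering the
# "modified by an inplace operation" RuntimeError.
dis_loss.backward(retain_graph=True)
gen_loss.backward()
dis_opt.step()
gen_opt.step()
print(gen.weight.grad is not None)  # True
```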

And great work on BPPNet.
