How do I handle a NotImplementedError when training a UNET model?

Posted 2025-02-13 11:24:56


import torch
import numpy as np
from tqdm import tqdm


def train_fn(data_loader, model, optimizer):
    model.train()
    total_loss = 0.0

    for images, masks in tqdm(data_loader):
        # move the batch to the training device
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        logits, loss = model(images, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(data_loader)


def eval_fn(data_loader, model):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for images, masks in tqdm(data_loader):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            logits, loss = model(images, masks)
            total_loss += loss.item()

    return total_loss / len(data_loader)


optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_valid_loss = np.inf

for i in range(EPOCHS):
    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)

    # keep the checkpoint with the lowest validation loss
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print("SAVED_MODEL")
        best_valid_loss = valid_loss

    print(f"Epoch : {i+1} Train_loss: {train_loss} Valid_loss: {valid_loss}")

I get the following error when I try to train the model:

0%| | 0/15 [00:00<?, ?it/s]

NotImplementedError Traceback (most recent call last)
in ()
4
5
----> 6 train_loss = train_fn(trainloader, model, optimizer)
7 valid_loss = eval_fn(validloader, model)
8

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    199         registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202 
    203 

NotImplementedError:

How do I deal with this?

Comments (1)

挽手叙旧 2025-02-20 11:24:56


Looking at the link you provided in the comment, your model definition looks like this:

class SegmentationModel(nn.Module):

  def __init__(self):
    super(SegmentationModel,self).__init__()

    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

    def forward(self, images, masks = None):
      logits = self.arc(images)

      if masks != None:
        loss1 = DiceLoss(mode = 'binary')(logits, masks)
        loss2 = nn.BCEWithLogitsLoss()(logits,masks)
        return logits, loss1 + loss2

      return logits

If you look closely, you'll see that forward() has an extra level of indentation, which makes it a nested function inside __init__() rather than a method of SegmentationModel. Because the class never defines its own forward(), calling model(images, masks) falls through to nn.Module's default _forward_unimplemented, which raises the NotImplementedError you're seeing. Dedent forward() so it sits at class level, and it should work fine:

class SegmentationModel(nn.Module):

  def __init__(self):
    super(SegmentationModel,self).__init__()

    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

  def forward(self, images, masks = None):
    logits = self.arc(images)

    if masks != None:
      loss1 = DiceLoss(mode = 'binary')(logits, masks)
      loss2 = nn.BCEWithLogitsLoss()(logits,masks)
      return logits, loss1 + loss2

    return logits
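
For a quick sanity check, here is a minimal sketch (assuming the corrected SegmentationModel above is defined, segmentation_models_pytorch is installed, and ENCODER/WEIGHTS are set to whatever your notebook uses; the batch shapes below are purely illustrative). With forward() at class level, calling the model dispatches to it instead of the default _forward_unimplemented:

import torch

# Dummy batch just to exercise the forward pass; 320x320 keeps the
# spatial size divisible by 32, as the U-Net encoder expects.
images = torch.randn(2, 3, 320, 320)
masks = torch.randint(0, 2, (2, 1, 320, 320)).float()

model = SegmentationModel()           # the corrected class from above
logits, loss = model(images, masks)   # no NotImplementedError anymore

print(logits.shape)   # torch.Size([2, 1, 320, 320])
print(loss.item())    # Dice + BCE loss on the random batch

Unrelated to the error: masks != None happens to work here, but masks is not None is the idiomatic way to test for None in Python.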