How to restructure code to avoid nested for loops in a training loop?

Posted on 2025-01-16 19:30:54


I am defining a train function to which I pass a
data_loader as a dict.

  • data_loader['train']: consists of the training data
  • data_loader['val']: consists of the validation data

I created a loop that iterates over the phase I am in (either train or val) and sets the model to model.train() or model.eval() accordingly. However, I feel I have too many nested for loops here, making it computationally expensive. Could anyone recommend a better way of constructing my train function? Should I create a separate function for validation instead?

Below is what I have so far:

# Make train function (simple at first)
# Assumes: from tqdm import notebook, and that `device` is already defined,
# e.g. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def train_network(model, optimizer, data_loader, no_epochs):

    total_epochs = notebook.tqdm(range(no_epochs))

    for epoch in total_epochs:

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            for i, (images, g_truth) in enumerate(data_loader[phase]):
                images = images.to(device)
                g_truth = g_truth.to(device)


Comments (1)

缪败 2025-01-23 19:30:54


Outermost and innermost for loops like these are common when writing training scripts.

The most common pattern I see is:

total_epochs = notebook.tqdm(range(no_epochs))

for epoch in total_epochs:
    # Training: set the mode once per phase rather than once per batch
    model.train()
    for i, (images, g_truth) in enumerate(train_data_loader):
        images = images.to(device)
        g_truth = g_truth.to(device)
        ...

    # Validating (consider also wrapping this loop in torch.no_grad())
    model.eval()
    for i, (images, g_truth) in enumerate(val_data_loader):
        images = images.to(device)
        g_truth = g_truth.to(device)
        ...

If you need to keep your previous data_loader dict, you can replace train_data_loader with data_loader["train"] and val_data_loader with data_loader["val"].

This layout is common because we generally want to do some things differently when validating as opposed to training. It also structures the code better and avoids a lot of if phase == "train" checks that you might otherwise need at different points in your innermost loop. It does, however, mean that you might need to duplicate some code. This trade-off is generally accepted; your original layout might be worth considering if you had three or more phases, such as multiple validation phases or an additional evaluation phase.
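If you do want a separate function for validation, as you suggested, a minimal sketch might look like the following. The function names (train_one_epoch, validate) and the explicit loss_fn and device parameters are my own choices, and the loss computation is an assumption since the question doesn't show one:

```python
import torch


def train_one_epoch(model, optimizer, loader, loss_fn, device):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, g_truth in loader:
        images, g_truth = images.to(device), g_truth.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), g_truth)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def validate(model, loader, loss_fn, device):
    """Evaluate the model without tracking gradients; return mean batch loss."""
    model.eval()
    running_loss = 0.0
    with torch.no_grad():  # no gradients needed during validation
        for images, g_truth in loader:
            images, g_truth = images.to(device), g_truth.to(device)
            running_loss += loss_fn(model(images), g_truth).item()
    return running_loss / len(loader)


def train_network(model, optimizer, data_loader, no_epochs, loss_fn, device):
    """Top-level loop: one training pass and one validation pass per epoch."""
    for epoch in range(no_epochs):
        train_loss = train_one_epoch(model, optimizer, data_loader["train"], loss_fn, device)
        val_loss = validate(model, data_loader["val"], loss_fn, device)
        print(f"epoch {epoch}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
```

Each function now owns exactly one phase, so there is no phase branching inside the batch loop, and validate can be reused on its own (e.g. for a final test set).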
