返回介绍

6.9 争议

发布于 2024-01-28 10:46:53 字数 3489 浏览 0 评论 0 收藏 0

以上的程序设计规范带有作者强烈的个人喜好,并不想作为一个标准,而是作为一个提议和一种参考。上述设计在很多地方还有待商榷,例如对于训练过程是否应该封装成一个 trainer 对象,或者直接封装到 BaiscModuletrain 方法之中。对命令行参数的处理也有不少值得讨论之处。因此不要将本文中的观点作为一个必须遵守的规范,而应该看作一个参考。

本章中的设计可能会引起不少争议,其中比较值得商榷的部分主要有以下两个方面:

  • 命令行参数的设置。目前大多数程序都是使用 Python 标准库中的 argparse 来处理命令行参数,也有些使用比较轻量级的 click 。这种处理相对来说对命令行的支持更完备,但根据作者的经验来看,这种做法不够直观,并且代码量相对来说也较多。比如 argparse ,每次增加一个命令行参数,都必须写如下代码:
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')

在读者眼中,这种实现方式远不如一个专门的 config.py 来的直观和易用。尤其是对于使用 Jupyter notebook 或 IPython 等交互式调试的用户来说, argparse 较难使用。

  • 模型训练。有不少人喜欢将模型的训练过程集成于模型的定义之中,代码结构如下所示:
  class MyModel(nn.Module):
  	
      def __init__(self,opt):
          self.dataloader = Dataloader(opt)
          self.optimizer  = optim.Adam(self.parameters(),lr=0.001)
          self.lr = opt.lr
          self.model = make_model()
      
      def forward(self,input):
          pass
      
      def train_(self):
          # 训练模型
          for epoch in range(opt.max_epoch)
          	for ii,data in enumerate(self.dataloader):
              	train_epoch()
              
          	model.save()
  	
      def train_epoch(self):
          pass

抑或是专门设计一个 Trainer 对象,形如:

    '''
  code simplified from:
   https://github.com/pytorch/pytorch/blob/master/torch/utils/trainer/trainer.py 
  '''
  import heapq
  from torch.autograd import Variable

  class Trainer(object):

      def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
          self.model = model
          self.criterion = criterion
          self.optimizer = optimizer
          self.dataset = dataset
          self.iterations = 0

      def run(self, epochs=1):
          for i in range(1, epochs + 1):
              self.train()

      def train(self):
          for i, data in enumerate(self.dataset, self.iterations + 1):
              batch_input, batch_target = data
              self.call_plugins('batch', i, batch_input, batch_target)
              input_var = Variable(batch_input)
              target_var = Variable(batch_target)
    
              plugin_data = [None, None]
    
              def closure():
                  batch_output = self.model(input_var)
                  loss = self.criterion(batch_output, target_var)
                  loss.backward()
                  if plugin_data[0] is None:
                      plugin_data[0] = batch_output.data
                      plugin_data[1] = loss.data
                  return loss
    
              self.optimizer.zero_grad()
              self.optimizer.step(closure)
    
          self.iterations += i

还有一些人喜欢模仿 keras 和 scikit-learn 的设计,设计一个 fit 接口。对读者来说,这些处理方式很难说哪个更好或更差,找到最适合自己的方法才是最好的。

  • BasicModule 的封装,可多可少。训练过程中的很多操作都可以移到 BasicModule 之中,比如 get_optimizer 方法用来获取优化器,比如 train_step 用来执行单歩训练。对于不同的模型,如果对应的优化器定义不一样,或者是训练方法不一样,可以复写这些函数自定义相应的方法,取决于自己的喜好和项目的实际需求。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文