文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
6.9 争议
以上的程序设计规范带有作者强烈的个人喜好,并不想作为一个标准,而是作为一个提议和一种参考。上述设计在很多地方还有待商榷,例如对于训练过程是否应该封装成一个 trainer
对象,或者直接封装到 BaiscModule
的 train
方法之中。对命令行参数的处理也有不少值得讨论之处。因此不要将本文中的观点作为一个必须遵守的规范,而应该看作一个参考。
本章中的设计可能会引起不少争议,其中比较值得商榷的部分主要有以下两个方面:
- 命令行参数的设置。目前大多数程序都是使用 Python 标准库中的
argparse
来处理命令行参数,也有些使用比较轻量级的click
。这种处理相对来说对命令行的支持更完备,但根据作者的经验来看,这种做法不够直观,并且代码量相对来说也较多。比如argparse
,每次增加一个命令行参数,都必须写如下代码:
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')
在读者眼中,这种实现方式远不如一个专门的 config.py
来的直观和易用。尤其是对于使用 Jupyter notebook 或 IPython 等交互式调试的用户来说, argparse
较难使用。
- 模型训练。有不少人喜欢将模型的训练过程集成于模型的定义之中,代码结构如下所示:
class MyModel(nn.Module):
def __init__(self,opt):
self.dataloader = Dataloader(opt)
self.optimizer = optim.Adam(self.parameters(),lr=0.001)
self.lr = opt.lr
self.model = make_model()
def forward(self,input):
pass
def train_(self):
# 训练模型
for epoch in range(opt.max_epoch)
for ii,data in enumerate(self.dataloader):
train_epoch()
model.save()
def train_epoch(self):
pass
抑或是专门设计一个 Trainer
对象,形如:
'''
code simplified from:
https://github.com/pytorch/pytorch/blob/master/torch/utils/trainer/trainer.py
'''
import heapq
from torch.autograd import Variable
class Trainer(object):
def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.dataset = dataset
self.iterations = 0
def run(self, epochs=1):
for i in range(1, epochs + 1):
self.train()
def train(self):
for i, data in enumerate(self.dataset, self.iterations + 1):
batch_input, batch_target = data
self.call_plugins('batch', i, batch_input, batch_target)
input_var = Variable(batch_input)
target_var = Variable(batch_target)
plugin_data = [None, None]
def closure():
batch_output = self.model(input_var)
loss = self.criterion(batch_output, target_var)
loss.backward()
if plugin_data[0] is None:
plugin_data[0] = batch_output.data
plugin_data[1] = loss.data
return loss
self.optimizer.zero_grad()
self.optimizer.step(closure)
self.iterations += i
还有一些人喜欢模仿 keras 和 scikit-learn 的设计,设计一个 fit
接口。对读者来说,这些处理方式很难说哪个更好或更差,找到最适合自己的方法才是最好的。
BasicModule
的封装,可多可少。训练过程中的很多操作都可以移到BasicModule
之中,比如get_optimizer
方法用来获取优化器,比如train_step
用来执行单歩训练。对于不同的模型,如果对应的优化器定义不一样,或者是训练方法不一样,可以复写这些函数自定义相应的方法,取决于自己的喜好和项目的实际需求。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论