文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
6.5 模型定义
模型的定义主要保存在 models/
目录下,其中 BasicModule
是对 nn.Module
的简易封装,提供快速加载和保存模型的接口。
class BasicModule(t.nn.Module):
'''
封装了 nn.Module,主要提供 save 和 load 两个方法
'''
def __init__(self):
super(BasicModule,self).__init__()
self.model_name = str(type(self)) # 模型的默认名字
def load(self, path):
'''
可加载指定路径的模型
'''
self.load_state_dict(t.load(path))
def save(self, name=None):
'''
保存模型,默认使用“模型名字+时间”作为文件名,
如 AlexNet_0710_23:57:29.pth
'''
if name is None:
prefix = 'checkpoints/' + self.model_name + '_'
name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
t.save(self.state_dict(), name)
return name
在实际使用中,直接调用 model.save()
及 model.load(opt.load_path)
即可。
其它自定义模型一般继承 BasicModule
,然后实现自己的模型。其中 AlexNet.py
实现了 AlexNet, ResNet34
实现了 ResNet34。在 models/__init__py
中,代码如下:
from .AlexNet import AlexNet
from .ResNet34 import ResNet34
这样在主函数中就可以写成:
from models import AlexNet
或
import models
model = models.AlexNet()
或
import models
model = getattr('models', 'AlexNet')()
其中最后一种写法最为关键,这意味着我们可以通过字符串直接指定使用的模型,而不必使用判断语句,也不必在每次新增加模型后都修改代码。新增模型后只需要在 models/__init__.py
中加上 from .new_module import new_module
即可。
其它关于模型定义的注意事项,在上一章中已详细讲解,这里就不再赘述,总结起来就是:
- 尽量使用
nn.Sequential
(比如 AlexNet) - 将经常使用的结构封装成子 Module(比如 GoogLeNet 的 Inception 结构,ResNet 的 Residual Block 结构)
- 将重复且有规律性的结构,用函数生成(比如 VGG 的多种变体,ResNet 多种变体都是由多个重复卷积层组成)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论