返回介绍

6.5 模型定义

发布于 2024-01-28 10:44:37 字数 1933 浏览 0 评论 0 收藏 0

模型的定义主要保存在 models/ 目录下,其中 BasicModule 是对 nn.Module 的简易封装,提供快速加载和保存模型的接口。

class BasicModule(t.nn.Module):
    '''
    封装了 nn.Module,主要提供 save 和 load 两个方法
    '''

    def __init__(self):
        super(BasicModule,self).__init__()
        self.model_name = str(type(self)) # 模型的默认名字

    def load(self, path):
        '''
        可加载指定路径的模型
        '''
        self.load_state_dict(t.load(path))

    def save(self, name=None):
        '''
        保存模型,默认使用“模型名字+时间”作为文件名,
        如 AlexNet_0710_23:57:29.pth
        '''
        if name is None:
            prefix = 'checkpoints/' + self.model_name + '_'
            name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
        t.save(self.state_dict(), name)
        return name

在实际使用中,直接调用 model.save()model.load(opt.load_path) 即可。

其它自定义模型一般继承 BasicModule ,然后实现自己的模型。其中 AlexNet.py 实现了 AlexNet, ResNet34 实现了 ResNet34。在 models/__init__py 中,代码如下:

from .AlexNet import AlexNet
from .ResNet34 import ResNet34

这样在主函数中就可以写成:

from models import AlexNet
或
import models
model = models.AlexNet()
或
import models
model = getattr('models', 'AlexNet')()

其中最后一种写法最为关键,这意味着我们可以通过字符串直接指定使用的模型,而不必使用判断语句,也不必在每次新增加模型后都修改代码。新增模型后只需要在 models/__init__.py 中加上 from .new_module import new_module 即可。

其它关于模型定义的注意事项,在上一章中已详细讲解,这里就不再赘述,总结起来就是:

  • 尽量使用 nn.Sequential (比如 AlexNet)
  • 将经常使用的结构封装成子 Module(比如 GoogLeNet 的 Inception 结构,ResNet 的 Residual Block 结构)
  • 将重复且有规律性的结构,用函数生成(比如 VGG 的多种变体,ResNet 多种变体都是由多个重复卷积层组成)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文