文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
5.5 持久化
在 PyTorch 中,以下对象可以持久化到硬盘,并能通过相应的方法加载到内存中:
- Tensor
- Variable
- nn.Module
- Optimizer
本质上上述这些信息最终都是保存成 Tensor。Tensor 的保存和加载十分的简单,使用 t.save 和 t.load 即可完成相应的功能。在 save/load 时可指定使用的 pickle 模块,在 load 时还可将 GPU tensor 映射到 CPU 或其它 GPU 上。
我们可以通过 t.save(obj, file_name)
等方法保存任意可序列化的对象,然后通过 obj = t.load(file_name)
方法加载保存的数据。对于 Module 和 Optimizer 对象,这里建议保存对应的 state_dict
,而不是直接保存整个 Module/Optimizer 对象。Optimizer 对象保存的主要是参数,以及动量信息,通过加载之前的动量信息,能够有效地减少模型震荡,下面举例说明。
a = t.Tensor(3, 4)
if t.cuda.is_available():
a = a.cuda(1) # 把 a 转为 GPU1 上的 tensor,
t.save(a,'a.pth')
# 加载为 b, 存储于 GPU1 上(因为保存时 tensor 就在 GPU1 上)
b = t.load('a.pth')
# 加载为 c, 存储于 CPU
c = t.load('a.pth', map_location=lambda storage, loc: storage)
# 加载为 d, 存储于 GPU0 上
d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})
t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import SqueezeNet
model = SqueezeNet()
# module 的 state_dict 是一个字典
model.state_dict().keys()
odict_keys(['features.0.weight', 'features.0.bias', 'features.3.squeeze.weight', 'features.3.squeeze.bias', 'features.3.expand1x1.weight', 'features.3.expand1x1.bias', 'features.3.expand3x3.weight', 'features.3.expand3x3.bias', 'features.4.squeeze.weight', 'features.4.squeeze.bias', 'features.4.expand1x1.weight', 'features.4.expand1x1.bias', 'features.4.expand3x3.weight', 'features.4.expand3x3.bias', 'features.5.squeeze.weight', 'features.5.squeeze.bias', 'features.5.expand1x1.weight', 'features.5.expand1x1.bias', 'features.5.expand3x3.weight', 'features.5.expand3x3.bias', 'features.7.squeeze.weight', 'features.7.squeeze.bias', 'features.7.expand1x1.weight', 'features.7.expand1x1.bias', 'features.7.expand3x3.weight', 'features.7.expand3x3.bias', 'features.8.squeeze.weight', 'features.8.squeeze.bias', 'features.8.expand1x1.weight', 'features.8.expand1x1.bias', 'features.8.expand3x3.weight', 'features.8.expand3x3.bias', 'features.9.squeeze.weight', 'features.9.squeeze.bias', 'features.9.expand1x1.weight', 'features.9.expand1x1.bias', 'features.9.expand3x3.weight', 'features.9.expand3x3.bias', 'features.10.squeeze.weight', 'features.10.squeeze.bias', 'features.10.expand1x1.weight', 'features.10.expand1x1.bias', 'features.10.expand3x3.weight', 'features.10.expand3x3.bias', 'features.12.squeeze.weight', 'features.12.squeeze.bias', 'features.12.expand1x1.weight', 'features.12.expand1x1.bias', 'features.12.expand3x3.weight', 'features.12.expand3x3.bias', 'classifier.1.weight', 'classifier.1.bias'])
# Module 对象的保存与加载
t.save(model.state_dict(), 'squeezenet.pth')
model.load_state_dict(t.load('squeezenet.pth'))
optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))
all_data = dict(
optimizer = optimizer.state_dict(),
model = model.state_dict(),
info = u'模型和优化器的所有参数'
)
t.save(all_data, 'all.pth')
all_data = t.load('all.pth')
all_data.keys()
dict_keys(['model', 'optimizer', 'info'])
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论