文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
4.2 保存训练检查点
上文曾经提到,训练模型意味着通过许多个训练周期更新其参数(或者用TensorFlow的语言来说,变量)。由于变量都保存在内存中,所以若计算机经历长时间训练后突然断电,所有工作都将丢失。幸运的是,借助tf.train.Saver类可将数据流图中的变量保存到专门的二进制文件中。我们应当周期性地保存所有变量,创建检查点(checkpoint)文件,并在必要时从最近的检查点恢复训练。
为使用Saver类,需要对之前的训练闭环代码框架略做修改:
在上述代码中,在开启会话对象之前实例化了一个Saver对象,然后在训练闭环部分插入了几行代码,使的每完成1000次训练迭代便调用一次tf.train.Saver.save方法,并在训练结束后,再次调用该方法。每次调用tf.train.Saver.save方法都将创建一个遵循命名模板my-model-{step}的检查点文件,如my-model-1000、my-model-2000等。该文件会保存每个变量的当前值。默认情况下,Saver对象只会保留最近的5个文件,更早的文件都将被自动删除。
如果希望从某个检查点恢复训练,则应使用tf.train.get_checkpoint_state方法,以验证之前是否有检查点文件被保存下来,而tf.train.Saver.restore方法将负责恢复变量的值。
在上述代码中,首先检查是否有检查点文件存在,并在开始训练闭环前恢复各变量的值,还可依据检查点文件的名称恢复全局迭代次数。
既然已了解了有监督学习的一般原理,以及如何保存训练进度,接下来将对一些常见的推断模型进行讨论。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论