返回介绍

4.2 保存训练检查点

发布于 2024-02-05 23:12:36 字数 979 浏览 0 评论 0 收藏 0

上文曾经提到,训练模型意味着通过许多个训练周期更新其参数(或者用TensorFlow的语言来说,变量)。由于变量都保存在内存中,所以若计算机经历长时间训练后突然断电,所有工作都将丢失。幸运的是,借助tf.train.Saver类可将数据流图中的变量保存到专门的二进制文件中。我们应当周期性地保存所有变量,创建检查点(checkpoint)文件,并在必要时从最近的检查点恢复训练。

为使用Saver类,需要对之前的训练闭环代码框架略做修改:

在上述代码中,在开启会话对象之前实例化了一个Saver对象,然后在训练闭环部分插入了几行代码,使的每完成1000次训练迭代便调用一次tf.train.Saver.save方法,并在训练结束后,再次调用该方法。每次调用tf.train.Saver.save方法都将创建一个遵循命名模板my-model-{step}的检查点文件,如my-model-1000、my-model-2000等。该文件会保存每个变量的当前值。默认情况下,Saver对象只会保留最近的5个文件,更早的文件都将被自动删除。

如果希望从某个检查点恢复训练,则应使用tf.train.get_checkpoint_state方法,以验证之前是否有检查点文件被保存下来,而tf.train.Saver.restore方法将负责恢复变量的值。

在上述代码中,首先检查是否有检查点文件存在,并在开始训练闭环前恢复各变量的值,还可依据检查点文件的名称恢复全局迭代次数。

既然已了解了有监督学习的一般原理,以及如何保存训练进度,接下来将对一些常见的推断模型进行讨论。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文