返回介绍

1.6.2 TensorFlow Serving介绍

发布于 2020-10-01 16:39:38 字数 4780 浏览 1070 评论 0 收藏 0

简介

TensorFlow的模型文件包含了深度学习模型的Graph和所有参数,其实就是checkpoint文件,用户可以加载模型文件继续训练或者对外提供Inference服务。

使用SavedModel导出模型

模型导出方式参考 https://tensorflow.github.io/serving/serving_basic

使用方法基本如下。

from tensorflow.python.saved_model import builder as saved_model_builder

export_path_base = sys.argv[-1]
export_path = os.path.join(
      compat.as_bytes(export_path_base),
      compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', export_path

builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
      sess, [tag_constants.SERVING],
      signature_def_map={
           'predict_images':
               prediction_signature,
           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
               classification_signature,
      },
      legacy_init_op=legacy_init_op)

builder.save()

可以参考 https://github.com/tobegit3hub/deep_recommend_system/ 提供的可运行代码示例。

./dense_classifier.py --mode savedmodel

使用exporter导出模型

这里有导出TensorFlow serving支持的模型文件例子,可以参考使用 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py

导出的代码也比较简单,用户在inputs和output中填入模型Inference时的输入和输出即可。

from tensorflow.contrib.session_bundle import exporter

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("model_path", "./model", "The path to export the model")
flags.DEFINE_integer("export_version", 1, "Version number of the model")

# Define the graph
keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
keys = tf.identity(keys_placeholder)

# Start the session

# Export the model
print("Exporting trained model to {}".format(FLAGS.model_path))
model_exporter = exporter.Exporter(saver)
model_exporter.init(
  sess.graph.as_graph_def(),
    named_graph_signatures={
      'inputs': exporter.generic_signature({"keys": keys_placeholder, "features": inference_features}),
      'outputs': exporter.generic_signature({"keys": keys, "softmax": inference_softmax, "prediction": inference_op})
    })
model_exporter.export(FLAGS.model_path, tf.constant(FLAGS.export_version), sess)
print 'Done exporting!'

与SavedModel方法相比,两者都可以直接用TensorFlow Serving加载,我们使用deep_recommend_system导出两种模型方式测试过预测结果一模一样,只是模型文件大小不同。

导入带assert的模型文件

在NLP等场景除了参数文件,还需要导入vocabulary等文件,可以在exporter中设置assets_collection,参考 https://github.com/tensorflow/serving/issues/264

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文