返回介绍

1.5.6.1 分布式训练

发布于 2020-10-01 16:39:38 字数 6986 浏览 1249 评论 0 收藏 0

简介

TensorFlow只是library,分布式TensorFlow应用需要我们在多个节点启动Python脚本组成分布式计算集群。

Xiaomi Cloud-ML支持标准的分布式TensorFlow应用,用户只需编写对应的Python脚本即可提交运行,用法与单机版类似。

代码规范

由于分布式TensorFlow应用需要启动多节点,每个节点需要知道自己的角色,一般都是通过命令行参数传入,而用户自定义的命令行参数名和个数可能不同。Cloud-ML要求用户通过DISTRIBUTED_CONFIG或TF_CONFIG(Cloud-ML原先只支持tensorflow分布式时,使用TF_CONFIG这个环境变量传递分布式参数,当前仍保留,后期会统一为DISTRIBUTED_CONFIG)这个环境变量传入集群和节点的信息。

如1个master、1个ps、1个worker的情况,传入的参数如下:

DISTRIBUTED_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}, "environment": "cloud"}'
TF_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}, "environment": "cloud"}'

注: 其中 environment 赋值为 cloud 表明为云上分布式训练,tensorflow 框架会根据这个变量来判断 is_chief

然后用户Python代码中可以直接读取环境变量,获取cluster spec和type、index信息。

if os.environ.get('DISTRIBUTED_CONFIG', ""):
  env = json.loads(os.environ.get('DISTRIBUTED_CONFIG', '{}'))
  task_data = env.get('task', None)
  cluster_spec = env["cluster"]
  task_type = task_data["type"]
  task_index = task_data["index"]

代码实例

我们也实现了标准的分布式TensorFlow应用,代码地址 https://github.com/XiaoMi/cloud-ml-sdk/blob/master/cloud_ml_samples/tensorflow/linear_regression/trainer/task.py

本地运行

本地启动分布式TensorFlow应用,以samples代码为例,可以先打开3个终端,然后分别运行下面的命令。

CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}}' python -m trainer.task

CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "worker"}}' python -m trainer.task

使用Xiaomi Cloud-ML

如果使用Xiaomi Cloud-ML,只需要把Python代码打包,然后运行时传入 -D,之后根据提示输入task type的名称、数量、资源及定制参数等信息:

cloudml jobs submit -n distributed -m trainer.task -u fds://cloud-ml/linear/trainer-1.0.tar.gz -D

分布式训练任务提交后,可以通过命令行看到多个任务的启动,查看具体某个worker日志发现分布式训练任务也正常完成。

cloudml jobs list

cloudml jobs logs distributed-worker-0

cloudml jobs logs distributed-ps-0

参数介绍

  • -D 表示使用分布式,按照提示输入分布式相关的信息即可,支持通用分布式

旧版本分布式参数:

  • -p 表示集群的ps的个数,暂时只支持TensorFlow深度学习框架。
  • -w 表示集群的worker个数,暂时只支持TensorFlow深度学习框架。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文