PyTorch dist.all_gather_object hangs
I'm using dist.all_gather_object (PyTorch version 1.8) to collect sample IDs from all GPUs:
for batch in dataloader:
    video_sns = batch["video_ids"]
    logits = model(batch)
    group_gather_vdnames = [None for _ in range(envs['nGPU'])]
    group_gather_logits = [torch.zeros_like(logits) for _ in range(envs['nGPU'])]
    dist.all_gather(group_gather_logits, logits)
    dist.all_gather_object(group_gather_vdnames, video_sns)
The line dist.all_gather(group_gather_logits, logits) works properly, but the program hangs at dist.all_gather_object(group_gather_vdnames, video_sns).
Why does the program hang at dist.all_gather_object(), and how can I fix it?
EXTRA INFO:
I run my DDP code on a local machine with multiple GPUs. The launch script is:
export NUM_NODES=1
export NUM_GPUS_PER_NODE=2
export NODE_RANK=0
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank $NODE_RANK \
    main.py \
    --my_args
ANSWER:
It turns out we need to set the device ID manually, as mentioned in the docstring of the dist.all_gather_object() API. Adding a torch.cuda.set_device() call with the process's local rank before the collectives makes the code work.
I always thought the GPU ID was set automatically by PyTorch dist; it turns out it's not.