Avoiding memory copies of tensors when using torch.multiprocessing with CUDA
I need to parse several datasets in parallel using the same network. The network is on CUDA, and I call share_memory() before passing it to the parsing function. I spawn multiple processes with torch.multiprocessing.Pool to parse in parallel.
GPU usage grows linearly with the number of processes I spawn. I am afraid this is expected, because sharing CUDA models requires the spawn start method. My model is used only for evaluation and runs under torch.no_grad() in the spawned function.
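For context, this is roughly what the evaluation call inside each worker looks like with autograd disabled (a minimal sketch of my own; model and batch are placeholder names, not part of the MWE below). torch.no_grad() keeps a single process from building a computation graph, but it says nothing about the memory each additional process needs on the device.

import torch

def evaluate(model, batch):
    # No computation graph is retained, so only the activations of this
    # forward pass and the (shared) parameters live on the GPU.
    with torch.no_grad():
        return model(batch.to('cuda')).cpu().numpy()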
Can I prevent this? The same question was asked here but never got a reply.
Here is an MWE. With M = 100 I get a CUDA OOM; with 10 I don't. If I move the model to the CPU and use fork, memory stays constant, while using spawn on the CPU still increases memory.
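Before the full MWE, here is a minimal self-contained sketch of the CPU/fork variant just described (my own illustration, assuming a Unix system where the fork start method is available; the worker function and sizes are placeholders). With fork, workers inherit the parent's address space copy-on-write, so the shared CPU model does not get re-materialised in every process.

import torch
from torch import multiprocessing as mp

def worker(model, x):
    # Evaluation only, no graph is kept
    with torch.no_grad():
        return model(x).numpy()

if __name__ == '__main__':
    mp.set_start_method('fork')   # copy-on-write: memory is shared with the parent
    model = torch.nn.Linear(1000, 1000)
    model.share_memory()          # move parameters to shared memory before handing them out
    with mp.Pool(4) as pool:
        batches = [torch.rand(1000) for _ in range(8)]
        results = pool.starmap(worker, [(model, x) for x in batches])
    print(len(results), 'batches evaluated')

The full MWE follows.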
import signal
import numpy as np
from torch import multiprocessing as mp
import torch
import time

N = 1000
M = 100
DEVICE = 'cuda'
STOP = 'STOP'

data_in = {m: np.random.rand(N) for m in range(M)}
data_out = {m: np.random.rand(N) for m in range(M)}


def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def online_test(queue, model, shared_stats):
    while True:  # keep process alive for testing
        # print(f'... {data_id} waiting ...')
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        shared_stats.update({data_id: {k: [] for k in ['prediction', 'error']}})
        # print(f'... {data_id} evaluation ...')
        # time.sleep(np.random.randint(1, 10))
        pred = model(torch.Tensor(data_in[data_id]).to(device=DEVICE)).cpu().detach().numpy()
        err = pred - data_out[data_id]
        shared_stats.update({data_id: {'prediction': epoch, 'error': -epoch}})
        # shared_stats.update({data_id: {'prediction': list(pred), 'error': list(err)}})
        queue.task_done()  # notify parent that testing is done for requested epoch


if __name__ == '__main__':
    stats = {**{'epoch': []},
             **{data_id: {k: [] for k in ['prediction', 'error']} for data_id in data_in.keys()}}
    train_model = torch.nn.Sequential(torch.nn.Linear(N, N)).to(device=DEVICE)
    test_model = torch.nn.Sequential(torch.nn.Linear(N, N)).to(device=DEVICE)
    test_model.share_memory()

    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    shared_stats = manager.dict()
    pool = mp.Pool(initializer=initializer)
    for data_id in data_in.keys():
        pool.apply_async(online_test,
                         args=(test_queue, test_model, shared_stats))
        test_queue.put((0, data_id))  # testing can start

    try:  # wrap all in a try-except to handle KeyboardInterrupt
        for epoch in range(5):
            print('training epoch', epoch)
            # time.sleep(3)
            # ... here I do some training and then copy my parameters to test_model
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            stats['epoch'].append(epoch + 1)
            test_model.load_state_dict(train_model.state_dict())
            print(f'... epoch {epoch} testing is done, stats are')
            for data_id in shared_stats.keys():  # but first copy stats here
                for k in stats[data_id].keys():
                    mu = np.mean(shared_stats[data_id][k])
                    stats[data_id][k].append(mu)
                    # print(' ', data_id, k, mu)
                test_queue.put((epoch + 1, data_id))
        for data_id in shared_stats.keys():  # notify all procs to end
            test_queue.put((-1, STOP))
        print(stats)
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
        pool.join()
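As a small diagnostic (my own addition, not part of the original MWE), each worker can report the CUDA memory PyTorch has allocated in its own process. This separates tensors actually allocated by that process from the fixed per-process CUDA context overhead, which only shows up in nvidia-smi.

import os
import torch

def report_gpu_usage(tag=''):
    # Per-process view of PyTorch's CUDA caching allocator; the CUDA context
    # itself is not counted here, only tensors allocated by this process.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 2**20
        reserved = torch.cuda.memory_reserved() / 2**20
        print(f'[pid {os.getpid()}] {tag} allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB')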