不同设备上的张量可以相加吗?

发布于 2025-01-14 01:20:19 字数 727 浏览 3 评论 0原文

最近我发现了一件奇怪的事情。据我所知,当你想对两个张量进行一些操作时,你应该确保它们位于同一设备上。但是,当我这样编写代码时,它会意外运行

import torch
a = torch.tensor(1, device='cuda')
print(a.device)
b = torch.tensor(2, device='cpu')
print(b.device)
torch(a+b)


cuda:0
cpu
tensor(3, device='cuda:0')

并且无法在我的代码中工作,如下所示:

pts_1_tile = torch.tensor([[0], [0]], dtype=torch.float32)
torch.add(pred_4pt_shift, pts_1_tile)

在此处输入图像描述

这里 pred_4pt_shift 是子网的中间结果,它是 GPU 上的张量。 我的问题是,为什么第一个代码可以工作,但第二个代码报告这个不同的设备错误?

I found a curious thing recently. As far as I know, when you want to do some operations on two tensors, you should make sure that they are on the same device. But when I write my code like this, it runs unexpectly

import torch
a = torch.tensor(1, device='cuda')
print(a.device)
b = torch.tensor(2, device='cpu')
print(b.device)
torch(a+b)


cuda:0
cpu
tensor(3, device='cuda:0')

And it can't work in my code like this:

pts_1_tile = torch.tensor([[0], [0]], dtype=torch.float32)
torch.add(pred_4pt_shift, pts_1_tile)

enter image description here

here pred_4pt_shift is an intermediate result of a sub-Net, and it is a tensor on GPU.
My question is that why the first code can work but the second one reports this different device error?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

剧终人散尽 2025-01-21 01:20:19

我猜你的意思是 print(a+b) 而不是 torch(a+b)

标量张量是一种特殊情况,可以自动移动到目标设备。
如果将 ab 定义为一维张量,则会出错:

import torch
a = torch.tensor([1], device='cuda')
b = torch.tensor([2], device='cpu')
print(a+b)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I guess you mean print(a+b) rather than torch(a+b).

Scalar tensor is a special case which could be automatically moved to target device.
If you define a and b as 1-d tensor, it will error out:

import torch
a = torch.tensor([1], device='cuda')
b = torch.tensor([2], device='cpu')
print(a+b)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文