I am attempting to create a tensor-like class that I wish to integrate with PyTorch. The class will have to store the tensor values and gradients in a custom format. I've been following the instructions at Extending torch with a Tensor-like type. I have some questions, especially pertaining to gradient storage and calculation:
- I want to initialize my class from a (float) tensor and be able to convert it back. I know I can retrieve the values from the tensor using the numpy() function, but how do I get the gradient data if I wish to store that too? Is directly accessing tensor.grad the correct way to do so? When I convert back to a tensor, how can I give it back the stored gradient data? (See the first sketch below for the round trip I have in mind.)
- In the above link I saw the instructions for making functions like torch.add work with my custom type, but that seems to only deal with the forward pass, and I will also need to modify how the gradient is calculated by autograd. How do I define both custom forward and backward versions of torch.add? Elsewhere, I saw the instructions for extending torch.autograd with a custom function, but I am not sure whether those instructions can be merged with this case. Is there a way I can create a custom add function and say that that function's apply implements torch.add for my class? (See the second sketch below.)
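Here is roughly what I mean for the first point. This is only a sketch: the MyTensor class and its field names are made up for illustration, and I am not sure this grad handling is the sanctioned approach:

```python
import torch

class MyTensor:
    """Hypothetical custom format storing values and gradients as NumPy arrays."""
    def __init__(self, values, grad=None):
        self.values = values  # NumPy array of the tensor's data
        self.grad = grad      # NumPy array of the gradient, or None

    @classmethod
    def from_torch(cls, t):
        # detach() so numpy() also works on tensors that require grad;
        # t.grad is None until backward() has populated it
        grad = t.grad.numpy().copy() if t.grad is not None else None
        return cls(t.detach().numpy().copy(), grad)

    def to_torch(self):
        t = torch.tensor(self.values, requires_grad=True)
        if self.grad is not None:
            # a leaf tensor's .grad can be assigned a tensor of matching shape/dtype
            t.grad = torch.tensor(self.grad)
        return t

x = torch.ones(3, requires_grad=True)
x.sum().backward()          # populates x.grad with ones
m = MyTensor.from_torch(x)  # stores both values and gradient
y = m.to_torch()            # y.grad now holds the stored gradient
```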
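And for the second point, here is how I imagine the two mechanisms might be combined: a __torch_function__ override that dispatches torch.add to a custom autograd.Function. I don't know whether this is the intended pattern, and all names here are illustrative:

```python
import torch

class CustomAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # custom forward pass (here just plain addition on the unwrapped values)
        return a + b

    @staticmethod
    def backward(ctx, grad_output):
        # custom backward pass: d(a + b)/da = d(a + b)/db = 1
        return grad_output, grad_output

class MyTensor:
    def __init__(self, t):
        self.t = t  # wrapped torch.Tensor (stand-in for the custom storage)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add and not kwargs:
            # unwrap, route through the custom Function, re-wrap
            a, b = (x.t if isinstance(x, MyTensor) else x for x in args)
            return cls(CustomAdd.apply(a, b))
        return NotImplemented

x = MyTensor(torch.ones(2, requires_grad=True))
y = MyTensor(torch.ones(2))
z = torch.add(x, y)   # dispatched to __torch_function__, runs CustomAdd.forward
z.t.sum().backward()  # runs CustomAdd.backward; x.t.grad is now ones
```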
Thank you
Comments (1)
You can access tensor.grad.data (check that grad is not None first!). Do the operations have to be performed on your custom type, or could they run on a torch.Tensor directly? In the latter case the tensor could carry the autograd information for most functions (you would only compute the gradient yourself for custom functions). You could also consider subclassing torch.Tensor; that would alleviate a lot of those problems.
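For illustration, a minimal sketch of the subclassing route (the class name is made up; any custom storage would be attached as extra state on the subclass):

```python
import torch

class MyTensor(torch.Tensor):
    """Minimal subclass: inherits all of torch.Tensor's autograd machinery,
    so most torch functions work and gradients are computed automatically."""
    pass

x = MyTensor([1.0, 2.0, 3.0]).requires_grad_()
y = torch.add(x, x)   # ordinary autograd handles forward and backward
y.sum().backward()
print(x.grad)         # each element's gradient is 2.0
```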