如何从 PyTorch split() 获取张量
PyTorch 的 split 函数返回张量元组。但我需要批量矩阵乘以结果。有没有一种简单的方法来分割张量并返回张量?这就是我尝试过的:
m = [[2, 3, 5, 7],
[11, 13, 17, 19],
[23, 29, 31, 37],
[41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)
这给了我一个错误,因为 m_split 是张量的元组而不是张量。我可以进行 view
或 reshape
调用吗?
PyTorch's split
function returns back a tuple of tensors. But I need to batch matrix multiply the result. Is there an easy way to split a tensor and get back a tensor? This is what I tried:
m = [[2, 3, 5, 7],
[11, 13, 17, 19],
[23, 29, 31, 37],
[41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)
This gives me an error because m_split
is a tuple of tensors rather than being a tensor. Is there a view
or reshape
call I can make instead?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
我认为你可以这样做
i think you can do as following