Dear community, I have a challenge with regard to tensor indexing in PyTorch. The problem is very simple: given a tensor, create an index tensor that indexes its maximum value in each column.

import torch as T

x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
              [0, 4, 9, 6, 7, 9, 1, 0]])
Given this tensor, I would like to build a boolean mask for indexing its maximum values per column. To be specific, I do not need the maximum values themselves, torch.max(x, dim=0), nor their indices, torch.argmax(x, dim=0), but a boolean mask for indexing another tensor based on this tensor's maximum values. My ideal output would be:
# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
[0, 4, 9, 6, 7, 9, 1, 0]])
# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 0, 0]])
I know that values_max = x[idx] and values_max = x.max(dim=0).values yield the same values, but I am not looking for values_max, only for idx.
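To make the equivalence above concrete, here is a quick sketch. One caveat: boolean indexing returns the selected elements in row-major order, so the two results agree as a collection of values rather than element-for-element:

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])

idx = torch.tensor([[1, 0, 0, 0, 1, 0, 1, 1],
                    [0, 1, 1, 1, 0, 1, 0, 0]], dtype=torch.bool)

# Boolean indexing flattens the selected entries in row-major order
via_mask = x[idx]
# Column-wise maxima, in column order
via_max = x.max(dim=0).values

# Same multiset of values, possibly in a different order
assert torch.equal(via_mask.sort().values, via_max.sort().values)
```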
I have built a solution around it, but it just seems too complex, and I am sure torch has an optimized way to do this. I tried to use torch.index_select with the output of x.argmax(dim=0) but failed, so I built a custom solution that seems cumbersome to me. I am asking for help to do this in a vectorized / tensorial / torch way.
You can perform this operation by first extracting the index of the column-wise maximum of your tensor with torch.argmax, setting keepdim to True. Then you can use torch.scatter to place 1s in a zero tensor at the designated indices:
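The original code snippet did not survive the scrape; a minimal sketch of the two steps just described would be:

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])

# Step 1: column-wise argmax with keepdim=True, so the result has
# shape (1, 8) and can be used as the index tensor for scatter along dim 0
idx_max = torch.argmax(x, dim=0, keepdim=True)

# Step 2: scatter 1s into a zero tensor at the argmax positions
idx = torch.zeros_like(x).scatter_(0, idx_max, 1)
print(idx)
# tensor([[1, 0, 0, 0, 1, 0, 1, 1],
#         [0, 1, 1, 1, 0, 1, 0, 0]])
```

Note that on ties (columns 0 and 7 here) argmax returns the first index, so the 1 lands in row 0, which matches the ideal output above. If you need a boolean dtype for indexing, follow with idx.bool().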