pytorch用argmax索引

发布于 2025-02-07 12:19:30 字数 986 浏览 3 评论 0 原文

亲爱的社区,我在pytorch的张量索引方面面临着挑战。问题很简单。给定张量创建索引张量,以索引每列的最大值。

x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0], 
              [0, 4, 9, 6, 7, 9, 1, 0]])

鉴于此张量,我想构建一个布尔面膜,以索引其每个柱面的最大值。要具体,我不需要它的最大值, torch.max(x,dim = 0),也不需要其索引, torch.argmax(x,dim = 0) ,但根据此张量最大值对其他张量索引的布尔掩膜。我理想的输出是:

# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
        [0, 4, 9, 6, 7, 9, 1, 0]])

# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])

我知道 values_max = x [idx] and values_max = x.max(dim = 0)是等效的,但我不是在寻找 values_max 但对于 idx

我已经围绕它建立了一个解决方案,但似乎很复杂,我敢肯定 torch 有一种优化的方法来做到这一点。我尝试使用 torch.index_select x.argmax(dim = 0)的输出,但失败了,因此我构建了一个自定义解决方案,对我来说似乎很麻烦我正在寻求帮助以矢量化 /张力 /火炬的方式进行此操作。

Dear community I have a challenge with regard to tensor indexing in PyTorch. The problem is very simple. Given a tensor create an index tensor to index its maximum values per column.

x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0], 
              [0, 4, 9, 6, 7, 9, 1, 0]])

Given this tensor I would like to build a boolean mask for indexing its maximum values per colum. To be specific I do not need its maximum values, torch.max(x, dim=0), nor its indices, torch.argmax(x, dim=0), but a boolean mask for indexing other tensor based on this tensor max values. My ideal output would be:

# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
        [0, 4, 9, 6, 7, 9, 1, 0]])

# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])

I know that values_max = x[idx] and values_max = x.max(dim=0) are equivalent but I am not looking for values_max but for idx.

I have built a solution around it but it just seem to complex and I am sure torch have an optimized way to do this. I have tried to use torch.index_select with the output of x.argmax(dim=0) but failed so I built a custom solution that seems to cumbersome to me so I am asking for help to do this in a vectorized / tensorial / torch way.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

萝莉病 2025-02-14 12:19:30

您可以通过首先提取张量的最大值列的索引来执行此操作noreferrer“> torch.argmax ,设置 keepdim to true

>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])

,您可以使用

>>> torch.zeros_like(x).scatter(0, x.argmax(0,True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])

You can perform this operation by first extracting the index of the maximum value column-wise of your tensor with torch.argmax, setting keepdim to True

>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])

Then you can use torch.scatter to place 1s in a zero tensor at the designated indices:

>>> torch.zeros_like(x).scatter(0, x.argmax(0,True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文