取出一批指数的矩阵 - python

发布于 2025-01-26 13:53:12 字数 596 浏览 3 评论 0原文

我们如何提取给定一批索引(在Python)的矩阵的行?

i = [[0,1],[1,2],[2,3]]
a = jnp.array([[1,2,3,4],[2,3,4,5]])


def extract(A,idx):
    A = A[:,idx]
    return A

B = extract(a,i)

我希望能得到这个结果(矩阵堆叠):

B = [[[1,2],
      [2,3]],

      [[2,3],
       [3,4]],

      [3,4],
      [4,5]]]

而不是:

  B_ = [[1, 2],
     [2, 3],
     [3, 4]],

     [[2, 3],
     [3 ,4],
     [4, 5]]]

在这种情况下,行被堆叠在一起,但我想堆叠不同的矩阵。

我尝试使用

jax.vmap(提取)(a,i),

但这给了我一个错误,因为A和我没有相同的维度。...是否有替代方案,不使用循环?

How can we extract the rows of a matrix given a batch of indices (in Python)?

i = [[0,1],[1,2],[2,3]]
a = jnp.array([[1,2,3,4],[2,3,4,5]])


def extract(A,idx):
    A = A[:,idx]
    return A

B = extract(a,i)

I expect to get this result (where the matrices are stacked):

B = [[[1,2],
      [2,3]],

      [[2,3],
       [3,4]],

      [3,4],
      [4,5]]]

And NOT:

  B_ = [[1, 2],
     [2, 3],
     [3, 4]],

     [[2, 3],
     [3 ,4],
     [4, 5]]]

In this case, the rows are stacked, but I want to stack the different matrices.

I tried using

jax.vmap(extract)(a,i),

but this gives me an error since a and i don't have the same dimension.... Is there an alternative, without using loops?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(2

半仙 2025-02-02 13:53:12

如果您以正确的方式指定in_axes,可以使用vmap进行此操作,然后将索引列表转换为索引阵列:

vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
# DeviceArray([[[1, 2],
#               [2, 3]],
# 
#              [[2, 3],
#               [3, 4]],
# 
#              [[3, 4],
#               [4, 5]]], dtype=int32)

当您说in_axes =(none,none,none,none, 0),它指定您希望第一个参数是未绘制的,并且您希望第二个参数沿其领先轴映射。

您需要将i从列表转换为数组或一般 pytree ,它试图在每个数组中的每个数组中的值映射,收藏。

You can do this with vmap if you specify in_axes in the right way, and convert your index list into an index array:

vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
# DeviceArray([[[1, 2],
#               [2, 3]],
# 
#              [[2, 3],
#               [3, 4]],
# 
#              [[3, 4],
#               [4, 5]]], dtype=int32)

When you say in_axes=(None, 0), it specifies that you want the first argument to be unmapped, and you want the second argument to be mapped along its leading axis.

The reason you need to convert i from a list to an array is because JAX will only map over array arguments: if vmap encounters a collection like a list, tuple, dict, or a general pytree, it attempts to map over each array-like value within the collection.

放飞的风筝 2025-02-02 13:53:12

您可以立即在矩阵a上使用索引:

a.T[i,:]

You can use indexing right away on the matrix a transposed:

a.T[i,:]
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文