取出一批指数的矩阵 - python
我们如何提取给定一批索引(在Python)的矩阵的行?
i = [[0,1],[1,2],[2,3]]
a = jnp.array([[1,2,3,4],[2,3,4,5]])
def extract(A,idx):
A = A[:,idx]
return A
B = extract(a,i)
我希望能得到这个结果(矩阵堆叠):
B = [[[1,2],
[2,3]],
[[2,3],
[3,4]],
[3,4],
[4,5]]]
而不是:
B_ = [[1, 2],
[2, 3],
[3, 4]],
[[2, 3],
[3 ,4],
[4, 5]]]
在这种情况下,行被堆叠在一起,但我想堆叠不同的矩阵。
我尝试使用
jax.vmap(提取)(a,i),
但这给了我一个错误,因为A和我没有相同的维度。...是否有替代方案,不使用循环?
How can we extract the rows of a matrix given a batch of indices (in Python)?
i = [[0,1],[1,2],[2,3]]
a = jnp.array([[1,2,3,4],[2,3,4,5]])
def extract(A,idx):
A = A[:,idx]
return A
B = extract(a,i)
I expect to get this result (where the matrices are stacked):
B = [[[1,2],
[2,3]],
[[2,3],
[3,4]],
[3,4],
[4,5]]]
And NOT:
B_ = [[1, 2],
[2, 3],
[3, 4]],
[[2, 3],
[3 ,4],
[4, 5]]]
In this case, the rows are stacked, but I want to stack the different matrices.
I tried using
jax.vmap(extract)(a,i),
but this gives me an error since a and i don't have the same dimension.... Is there an alternative, without using loops?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(2)
如果您以正确的方式指定
in_axes
,可以使用vmap
进行此操作,然后将索引列表转换为索引阵列:当您说
in_axes =(none,none,none,none, 0)
,它指定您希望第一个参数是未绘制的,并且您希望第二个参数沿其领先轴映射。您需要将
i
从列表转换为数组或一般 pytree ,它试图在每个数组中的每个数组中的值映射,收藏。You can do this with
vmap
if you specifyin_axes
in the right way, and convert your index list into an index array:When you say
in_axes=(None, 0)
, it specifies that you want the first argument to be unmapped, and you want the second argument to be mapped along its leading axis.The reason you need to convert
i
from a list to an array is because JAX will only map over array arguments: ifvmap
encounters a collection like a list, tuple, dict, or a general pytree, it attempts to map over each array-like value within the collection.您可以立即在矩阵
a
上使用索引:You can use indexing right away on the matrix
a
transposed: