检查是否订购了2D子阵列-Pyhthon jax

发布于 2025-02-01 05:01:59 字数 656 浏览 6 评论 0原文

让我们假设我们有一个数组订购的。我们要检查子阵列tt_inv是否遵循与order> order> order array中的订单相同的顺序。

从左到右读取:第一个元素是[0,0]等,直到[0,3]t_inv被颠倒了,因为第一个到元素被互换,它们不像订购中的订购一样。

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])

我期望以下内容:

 result: ordered(t) = 1, because "ordered"  
and ordered(t_inv) = -1, because "swapped/not ordered"

您如何检查子阵列是否确实是订购数组的一部分,并且该订单是否正确?

Let us suppose that we have an array ordered. We want to check if the sub-arrays t and t_inv are following the same order as the imposed order inorder array.

Reading from left to right: the first element is [0,0] and so on until [0,3].
t_inv is inversed because the first to elements are swapped, they do not follow the ordering as in ordered.

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])

I expect the following:

 result: ordered(t) = 1, because "ordered"  
and ordered(t_inv) = -1, because "swapped/not ordered"

How can you check that the sub arrays are indeed part of the ordered array and ouput whether the order is correct or not?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

稚然 2025-02-08 05:01:59

您可以做类似的事情:

import jax.numpy as jnp

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])


def is_sorted(t, ordered):
  index = jnp.where((t[:, None] == ordered).all(-1))[1]
  return jnp.where((index == jnp.sort(index)).all(), 1, -1)

print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1

从缩放方面,使用基于searchSorted的解决方案可能会更快,但是JAX中的JNP.SearchSorted的当前实现相对较慢。由于XLA没有任何本机二进制搜索算法,因此在实践中,整个成对比较通常可以更具性能。

You could do something like this:

import jax.numpy as jnp

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])


def is_sorted(t, ordered):
  index = jnp.where((t[:, None] == ordered).all(-1))[1]
  return jnp.where((index == jnp.sort(index)).all(), 1, -1)

print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1

Scaling-wise, it might be faster to use a solution based on searchsorted, but the current implementation of jnp.searchsorted in JAX is relatively slow because XLA doesn't have any native binary search algorithm, so in practice the full pairwise comparison can often be more performant.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文