获取给定批处理值的字典键 - python

发布于 2025-02-01 23:42:35 字数 565 浏览 5 评论 0 原文

我定义了一个字典 a ,并想找到给定值的键 a

def dictionary(r):
 return dict(enumerate(r))

def get_key(val, my_dict):
   for key, value in my_dict.items():
      if np.array_equal(val,value):
          return key
    

 # dictionary
 A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
 A = dictionary(A)

 a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
 keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)

预期输出应为: 键= [[1,2,3],[0,3,2]]

为什么我得到 none 作为输出?

I defined a dictionary A and would like to find the keys given a batch of values a:

def dictionary(r):
 return dict(enumerate(r))

def get_key(val, my_dict):
   for key, value in my_dict.items():
      if np.array_equal(val,value):
          return key
    

 # dictionary
 A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
 A = dictionary(A)

 a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
 keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)

The expected output should be:
keys = [[1,2,3],[0,3,2]]

Why am I getting Noneas an output?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

浊酒尽余欢 2025-02-08 23:42:35

jax像 vmap 通过跟踪的功能一样转换,这意味着它们用该值的抽象表示替换值,以提取函数中编码的操作顺序(请参见如何在jax中思考 </

这意味着要与 vmap 正确工作,一个函数只能使用jax方法,而不是numpy方法,因此您使用 np.array_equal_equal 破坏了抽象。

不幸的是,它实际上没有任何替代品,因为没有机制可以在混凝土python词典中查找抽象的jax值。如果您想对jax值进行确定查找,则应避免转换并使用python循环:

keys = jnp.array([[get_key(x, A) for x in row] for row in a])

另一方面,我怀疑这更多的是 xy问题;您的目标不是在JAX变换中查找字典值,而是要解决一些问题。也许您应该问一个有关如何解决问题的问题,而不是如何解决您尝试过的解决方案的问题。

但是,如果您愿意不直接使用DICE,则与JAX兼容的替代 get_key 实现可能看起来像这样:

def get_key(val, my_dict):
  keys = jnp.array(list(my_dict.keys()))
  values = jnp.array(list(my_dict.values()))
  return keys[jnp.where((values == val).all(-1), size=1)]

JAX transforms like vmap work by tracing the function, meaning they replace the value with an abstract representation of the value to extract the sequence of operations encoded in the function (See How to think in JAX for a good intro to this concept).

What this means is that to work correctly with vmap, a function can only use JAX methods, not numpy methods, so your use of np.array_equal breaks the abstraction.

Unfortunately, there's not really any replacement for it, because there's no mechanism to look up an abstract JAX value in a concrete Python dictionary. If you want to do dict lookups of JAX values, you should avoid transforms and just use Python loops:

keys = jnp.array([[get_key(x, A) for x in row] for row in a])

On the other hand, I suspect this is more of an XY problem; your goal is not to look up dictionary values within a jax transform, but rather to solve some problem. Perhaps you should ask a question about how to solve the problem, rather than how to get around an issue with the solution you have tried.

But if you're willing to not directly use the dict, an alternative get_key implementation that is compatible with JAX might look something like this:

def get_key(val, my_dict):
  keys = jnp.array(list(my_dict.keys()))
  values = jnp.array(list(my_dict.values()))
  return keys[jnp.where((values == val).all(-1), size=1)]
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文