获取给定批处理值的字典键 - python
我定义了一个字典 a
,并想找到给定值的键 a
:
def dictionary(r):
return dict(enumerate(r))
def get_key(val, my_dict):
for key, value in my_dict.items():
if np.array_equal(val,value):
return key
# dictionary
A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
A = dictionary(A)
a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)
预期输出应为:
键= [[1,2,3],[0,3,2]]
为什么我得到 none
作为输出?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
jax像
vmap
通过跟踪的功能一样转换,这意味着它们用该值的抽象表示替换值,以提取函数中编码的操作顺序(请参见如何在jax中思考 </这意味着要与
vmap
正确工作,一个函数只能使用jax方法,而不是numpy方法,因此您使用np.array_equal_equal
破坏了抽象。不幸的是,它实际上没有任何替代品,因为没有机制可以在混凝土python词典中查找抽象的jax值。如果您想对jax值进行确定查找,则应避免转换并使用python循环:
另一方面,我怀疑这更多的是 xy问题;您的目标不是在JAX变换中查找字典值,而是要解决一些问题。也许您应该问一个有关如何解决问题的问题,而不是如何解决您尝试过的解决方案的问题。
但是,如果您愿意不直接使用DICE,则与JAX兼容的替代
get_key
实现可能看起来像这样:JAX transforms like
vmap
work by tracing the function, meaning they replace the value with an abstract representation of the value to extract the sequence of operations encoded in the function (See How to think in JAX for a good intro to this concept).What this means is that to work correctly with
vmap
, a function can only use JAX methods, not numpy methods, so your use ofnp.array_equal
breaks the abstraction.Unfortunately, there's not really any replacement for it, because there's no mechanism to look up an abstract JAX value in a concrete Python dictionary. If you want to do dict lookups of JAX values, you should avoid transforms and just use Python loops:
On the other hand, I suspect this is more of an XY problem; your goal is not to look up dictionary values within a jax transform, but rather to solve some problem. Perhaps you should ask a question about how to solve the problem, rather than how to get around an issue with the solution you have tried.
But if you're willing to not directly use the dict, an alternative
get_key
implementation that is compatible with JAX might look something like this: