泡菜更改类型JAX
我有一个包含jax numpy数组的亚麻结构数据类。
当我腌制此对象并再次加载它时,数组不再是jax numpy数组,并且转换为numpy数组,这是重现它的代码:
import flax
import jax.numpy as jnp
import pickle
@flax.struct.dataclass
class A:
data: jnp.ndarray
a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))
with open('file.pickle', 'wb') as handle:
pickle.dump(a, handle)
with open('file.pickle', 'rb') as handle:
loaded_a = pickle.load(handle)
print(loaded_a, type(loaded_a.data))
我不想要这种行为,我希望它能保持其原始类型,有可能吗?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
更新:此错误已在。从JAX(v。0.3.14)的下一个版本
开始在jax; https://github.com/google/google/jax/jax/issues/2632
请参阅 图书馆开发人员认为是一种不幸的行为,但尚未确定修复程序。如果您有兴趣,您可能会在这个问题上权衡。
Update: this bug has been fixed in https://github.com/google/jax/pull/10659. Starting in the next release of JAX (v. 0.3.14)
pickle
anddeepcopy
should no longer convert JAX arrays to device arrays.This is a known behavior in JAX; see https://github.com/google/jax/issues/2632
It's something that the library developers recognize as an unfortunate behavior, but a fix has not yet been prioritized. If you're interested, you might weigh-in on that issue.