泡菜更改类型JAX

发布于 2025-01-28 04:20:17 字数 505 浏览 4 评论 0 原文

我有一个包含jax numpy数组的亚麻结构数据类。

当我腌制此对象并再次加载它时,数组不再是jax numpy数组,并且转换为numpy数组,这是重现它的代码:

import flax
import jax.numpy as jnp
import pickle

@flax.struct.dataclass
class A:
    data: jnp.ndarray

a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))



with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)
    
with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))

我不想要这种行为,我希望它能保持其原始类型,有可能吗?

I have a flax struct dataclass containing a jax numpy array.

When I pickle dump this object and load it again, the array is not anymore a jax numpy array and is converted to a numpy array, here is the code to reproduce it:

import flax
import jax.numpy as jnp
import pickle

@flax.struct.dataclass
class A:
    data: jnp.ndarray

a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))



with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)
    
with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))

I don't want this behavior and I'd like it to keep its original type, is it possible ?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

陪你到最终 2025-02-04 04:20:17

更新:此错误已在。从JAX(v。0.3.14)的下一个版本


开始在jax; https://github.com/google/google/jax/jax/issues/2632

请参阅 图书馆开发人员认为是一种不幸的行为,但尚未确定修复程序。如果您有兴趣,您可能会在这个问题上权衡。

Update: this bug has been fixed in https://github.com/google/jax/pull/10659. Starting in the next release of JAX (v. 0.3.14) pickle and deepcopy should no longer convert JAX arrays to device arrays.


This is a known behavior in JAX; see https://github.com/google/jax/issues/2632

It's something that the library developers recognize as an unfortunate behavior, but a fix has not yet been prioritized. If you're interested, you might weigh-in on that issue.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文