返回介绍

概述

文章

用户指南

NumPy 参考手册

其他文档

编写自定义数组容器

发布于 2023-06-23 19:27:33 字数 11141 浏览 0 评论 0 收藏 0

NumPy 的分派机制(在numpy版本v1.16中引入)是编写与numpy API兼容并提供numpy功能的自定义实现的自定义N维数组容器的推荐方法。 应用包括 daskopen in new window 数组(分布在多个节点上的N维数组) 和 cupyopen in new window 数组(GPU上的N维数组)。

为了获得编写自定义数组容器的感觉,我们将从一个简单的示例开始,该示例具有相当狭窄的实用程序,但说明了所涉及的概念。

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...

我们的自定义数组可以实例化,如下所示:

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

我们可以使用 numpy.arrayopen in new windownumpy.asarrayopen in new window, 转换为numpy数组,这将调用它的 __array__ 方法来获得标准 numpy.ndarray

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

如果我们使用 numpy 函数对 arr 进行操作,numpy 将再次使用 __array__接口将其转换为数组,然后以通常的方式应用该函数。

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

注意,返回类型是标准 numpy.ndarray

>>> type(arr)
numpy.ndarray

我们如何通过此函数传递我们的自定义数组类型?Numpy允许类指示它希望通过交互 __array_ufunc____array_function__ 以自定义方式处理计算。 让我们一次拿一个,从 __array_ufunc__ 开始。 此方法涵盖 Universal functions (ufunc), 这是一类函数,包括例如 numpy.multiplynumpy.sin

_array_ufunc_ 获得:

  • ufunc, 一个类似 numpy.multiply 的函数
  • method,一个字符串,区分 numpy.multiply(...)。 以及numpy.multiy.outernumpy.multiy.accumate等变体。对于常见情况,numpy.multiply(...)method='__call__'
  • inputs, 可能是不同类型的混合
  • kwargs, 传递给函数的关键字参数

对于这个例子,我们将只处理方法 '__call__

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...

现在让我们的自定义数组类型通过numpy的函数。

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)

此时 arr + 3 不起作用。

>>> arr + 3
TypeError: unsupported operand type(s) for *: 'DiagonalArray' and 'int'

为了支持它,我们需要定义Python接口 __add____lt__ 等,以便调度到相应的ufunc。 我们可以通过继承mixin NDArrayOperatorsMixinopen in new window 来方便地实现这一点。

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)

现在让我们来解决 __array_function__。 我们将创建将 numpy 函数映射到我们的自定义变体的 dict。

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...    def __array_function__(self, func, types, args, kwargs):
...        if func not in HANDLED_FUNCTIONS:
...            return NotImplemented
...        # Note: this allows subclasses that don't override
...        # __array_function__ to handle DiagonalArray objects.
...        if not all(issubclass(t, self.__class__) for t in types):
...            return NotImplemented
...        return HANDLED_FUNCTIONS[func](*args, **kwargs)
...

一个便捷的模式是定义一个可用于向 HANDLED_FUNCTIONS 添加函数的装饰器 实现

>>> def implements(np_function):
...    "Register an __array_function__ implementation for DiagonalArray objects."
...    def decorator(func):
...        HANDLED_FUNCTIONS[np_function] = func
...        return func
...    return decorator
...

现在我们为 DiagonalArray 编写numpy函数的实现。 为了完整性,为了支持使用 arr.sum(), 添加一个调用 numpy.sum(self) 的方法 sum,对于 mean 来说也是一样的。

>>> @implements(np.sum)
... def sum(a):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def sum(a):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2

如果用户尝试使用 HANDLED_FUNCTIONS 中未包含的任何numpy函数, 则numpy将引发 TypeError,表示不支持此操作。 例如,连接两个 DiagonalArrays 不会产生另一个对角线数组,因此不支持它。

>>> np.concatenate([arr, arr])
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

另外,我们的 summean 实现不接受numpy实现的可选参数。

>>> np.sum(arr, axis=0)
TypeError: sum() got an unexpected keyword argument 'axis'

用户总是可以选择使用 numpy.asarray 转换为普通的 numpy.asarray,并使用标准的numpy。

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

有关自定义数组容器的更完整工作示例,请参阅dask源代码open in new windowcupy源代码open in new window

另外可以看一下 NEP 18open in new window

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文