numba njit 的混合数据类型输入

发布于 2025-01-11 04:23:12 字数 1916 浏览 0 评论 0原文

我有一个大数组用于操作,例如矩阵转置。 numba 更快:

#test_transpose.py
import numpy as np
import numba as nb
import time


@nb.njit('float64[:,:](float64[:,:])', parallel=True)
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
   return x2


if __name__ == "__main__":
    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = x.transpose().copy()
    print(f"numpy transpose: {round(time.time() - t, 4)} secs")

    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = transpose(x)
    print(f"numba paralleled transpose: {round(time.time() - t, 4)} secs")

在 Windows 命令提示符中运行

D:\data\test>python test_transpose.py
numpy transpose: 2.0961 secs
numba paralleled transpose: 0.8584 secs

但是,我想输入另一个大矩阵,它们是整数,使用 x 因为

x = np.random.randint(int(3e6), size=(int(3e6), 50), dtype=np.int64)

引发异常,因为

Traceback (most recent call last):
  File "test_transpose.py", line 39, in <module>
    x = transpose(x)
  File "C:\Program Files\Python38\lib\site-packages\numba\core\dispatcher.py", line 703, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) array(int64, 2d, C)

它无法识别输入数据矩阵为整数。如果我释放整数矩阵的数据类型检查,因为

@nb.njit(parallel=True) # 'float64[:,:](float64[:,:])'
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2

它速度较慢:

D:\Data\test>python test_transpose.py
numba paralleled transpose: 1.6653 secs

使用 @nb.njit('int64[:,:](int64[:,:])', parallel=True)正如预期的那样,整数数据矩阵更快。

那么,我如何仍然允许混合数据类型输入但保持速度,而不是为不同类型分别创建函数?

I have a large array for operation, for example, matrix transpose. numba is much faster:

#test_transpose.py
import numpy as np
import numba as nb
import time


@nb.njit('float64[:,:](float64[:,:])', parallel=True)
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
   return x2


if __name__ == "__main__":
    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = x.transpose().copy()
    print(f"numpy transpose: {round(time.time() - t, 4)} secs")

    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = transpose(x)
    print(f"numba paralleled transpose: {round(time.time() - t, 4)} secs")

Run in Windows command prompt

D:\data\test>python test_transpose.py
numpy transpose: 2.0961 secs
numba paralleled transpose: 0.8584 secs

However, I want to input another large matrix, which are integers, using x as

x = np.random.randint(int(3e6), size=(int(3e6), 50), dtype=np.int64)

Exception is raised as

Traceback (most recent call last):
  File "test_transpose.py", line 39, in <module>
    x = transpose(x)
  File "C:\Program Files\Python38\lib\site-packages\numba\core\dispatcher.py", line 703, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) array(int64, 2d, C)

It does not recognize the input data matrix as integer. If I release the data type check for the integer matrix as

@nb.njit(parallel=True) # 'float64[:,:](float64[:,:])'
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2

It is slower:

D:\Data\test>python test_transpose.py
numba paralleled transpose: 1.6653 secs

Using @nb.njit('int64[:,:](int64[:,:])', parallel=True) for the integer data matrix is faster, as expected.

So, how can I still allow mixed data type intputs but keep the speed, instead of creating functions each for different types?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

暗地喜欢 2025-01-18 04:23:12

那么,我如何仍然允许混合数据类型输入但保持速度,而不是为不同类型分别创建函数?

问题在于 Numba 函数仅针对 float64 类型定义,而不是针对 int64 类型。类型规范是必需的,因为 Numba 将 Python 代码编译为具有明确定义的类型的本机代码。您可以向 Numba 函数添​​加多个签名

@nb.njit(['float64[:,:](float64[:,:])', 'int64[:,:](int64[:,:])'], parallel=True)
def transpose(x):
    r, c = x.shape
    # Specifying the dtype is very important here.
    # This is a good habit to take to avoid numerical issues and 
    # slower performance in Numpy too.
    x2 = np.zeros((c, r), dtype=x.dtype)
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
   return x2

比较慢

这是因为延迟编译。第一次执行包括编译时间。当指定签名时,情况并非如此,因为使用了即时编译。

numba 更快

考虑到使用了很多核心,所以这里没有太多。事实上,朴素转置在大矩阵上效率非常低(在这种情况下,在大数组上浪费了大约 90% 的内存吞吐量)。有更快的算法。有关更多信息,请阅读这篇文章 (它只考虑就地 2D 平方转置,这要简单得多,但想法是相同的)。另请注意,类型越宽,数组越大。数组越大,转置速度越慢。

So, how can I still allow mixed data type intputs but keep the speed, instead of creating functions each for different types?

The problem is that the Numba function is defined only for float64 types and not int64. The specification of the types is required because Numba compile the Python code to a native code with well-defined types. You can add multiple signatures to a Numba function:

@nb.njit(['float64[:,:](float64[:,:])', 'int64[:,:](int64[:,:])'], parallel=True)
def transpose(x):
    r, c = x.shape
    # Specifying the dtype is very important here.
    # This is a good habit to take to avoid numerical issues and 
    # slower performance in Numpy too.
    x2 = np.zeros((c, r), dtype=x.dtype)
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
   return x2

It is slower

This is because of lazy compilation. The first execution include the compilation time. THis is not the case when the signature is specified because of eager compilation is used instead.

numba is much faster

Well, not to much here considering many cores are used. In fact, the naive transposition is very inefficient on big matrices (is wast about 90% of the memory throughput in this case on large arrays). There are faster algorithms. For more information, please read this post (it only consider in-place 2D square transposition which is much simpler but the idea is the same). Also note that the wider the type, the bigger the array. The bigger the array the slower the transposition.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文