Numba 错误地解释了生成器函数

发布于 2025-01-10 00:10:32 字数 1802 浏览 0 评论 0原文

尽管Numba 据称支持生成器函数。具体来说,我正在尝试实现 Heap 的算法来生成数组的排列,因为 Numba 不能与 itertools 包一起使用(基本上只是从 Wikipedia):

import numpy as np
from numba import njit


@njit()
def permutations(arr):

    c = np.zeros(arr.shape[0], dtype=np.int64)
    yield arr
    i = 0
    while i < arr.shape[0]:
        if c[i] < i:
            if i % 2 == 0:
                temp = arr[1]
                arr[1] = arr[i]
                arr[i] = temp
            else:
                temp = arr[c[i]]
                arr[c[i]] = arr[i]
                arr[i] = temp
            yield arr
            c[i] += 1
            i = 0
        else:
            c[i] = 0
            i += 1


def main():

    arr = np.array(['A', 'B', 'C', 'D'])
    for permutation in permutations(arr):
        print(permutation)


if __name__ == '__main__':
    main()

如果没有 njit(),输出是正确的:

['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
['B' 'C' 'A' 'D']
['C' 'B' 'A' 'D']
['C' 'A' 'B' 'D']
['A' 'C' 'B' 'D']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
['D' 'A' 'C' 'B']
['A' 'D' 'C' 'B']
['A' 'C' 'D' 'B']
['C' 'A' 'D' 'B']
['C' 'D' 'A' 'B']
['D' 'C' 'A' 'B']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']

但是,使用njit(),我明白了:

['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']

我需要做一些特殊的事情来 JIT 生成器函数吗?我是否错误地使用了生成器的输出来理解 Numba 的解释方式?这是生成器函数的 Numba 实现中的错误吗?

I'm having some trouble JIT-ing a generator function with Numba, despite Numba supposedly having support for generator functions. Specifically, I'm trying to implement Heap's algorithm for generating permutations of an array given that Numba doesn't work with the itertools package (basically just copying pseudocode from Wikipedia):

import numpy as np
from numba import njit


@njit()
def permutations(arr):

    c = np.zeros(arr.shape[0], dtype=np.int64)
    yield arr
    i = 0
    while i < arr.shape[0]:
        if c[i] < i:
            if i % 2 == 0:
                temp = arr[1]
                arr[1] = arr[i]
                arr[i] = temp
            else:
                temp = arr[c[i]]
                arr[c[i]] = arr[i]
                arr[i] = temp
            yield arr
            c[i] += 1
            i = 0
        else:
            c[i] = 0
            i += 1


def main():

    arr = np.array(['A', 'B', 'C', 'D'])
    for permutation in permutations(arr):
        print(permutation)


if __name__ == '__main__':
    main()

Without njit(), the output is correct:

['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
['B' 'C' 'A' 'D']
['C' 'B' 'A' 'D']
['C' 'A' 'B' 'D']
['A' 'C' 'B' 'D']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
['D' 'A' 'C' 'B']
['A' 'D' 'C' 'B']
['A' 'C' 'D' 'B']
['C' 'A' 'D' 'B']
['C' 'D' 'A' 'B']
['D' 'C' 'A' 'B']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']

However, with njit(), I get this:

['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']

Is there something special I need to do to JIT a generator function? Am I using the output of the generator incorrectly for how Numba interprets it? Is this a bug in the Numba implementation of generator functions?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文