Numba 错误地解释了生成器函数
尽管Numba 据称支持生成器函数。具体来说,我正在尝试实现 Heap 的算法来生成数组的排列,因为 Numba 不能与 itertools 包一起使用(基本上只是从 Wikipedia):
import numpy as np
from numba import njit
@njit()
def permutations(arr):
c = np.zeros(arr.shape[0], dtype=np.int64)
yield arr
i = 0
while i < arr.shape[0]:
if c[i] < i:
if i % 2 == 0:
temp = arr[1]
arr[1] = arr[i]
arr[i] = temp
else:
temp = arr[c[i]]
arr[c[i]] = arr[i]
arr[i] = temp
yield arr
c[i] += 1
i = 0
else:
c[i] = 0
i += 1
def main():
arr = np.array(['A', 'B', 'C', 'D'])
for permutation in permutations(arr):
print(permutation)
if __name__ == '__main__':
main()
如果没有 njit()
,输出是正确的:
['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
['B' 'C' 'A' 'D']
['C' 'B' 'A' 'D']
['C' 'A' 'B' 'D']
['A' 'C' 'B' 'D']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
['D' 'A' 'C' 'B']
['A' 'D' 'C' 'B']
['A' 'C' 'D' 'B']
['C' 'A' 'D' 'B']
['C' 'D' 'A' 'B']
['D' 'C' 'A' 'B']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
但是,使用njit()
,我明白了:
['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
我需要做一些特殊的事情来 JIT 生成器函数吗?我是否错误地使用了生成器的输出来理解 Numba 的解释方式?这是生成器函数的 Numba 实现中的错误吗?
I'm having some trouble JIT-ing a generator function with Numba, despite Numba supposedly having support for generator functions. Specifically, I'm trying to implement Heap's algorithm for generating permutations of an array given that Numba doesn't work with the itertools package (basically just copying pseudocode from Wikipedia):
import numpy as np
from numba import njit
@njit()
def permutations(arr):
c = np.zeros(arr.shape[0], dtype=np.int64)
yield arr
i = 0
while i < arr.shape[0]:
if c[i] < i:
if i % 2 == 0:
temp = arr[1]
arr[1] = arr[i]
arr[i] = temp
else:
temp = arr[c[i]]
arr[c[i]] = arr[i]
arr[i] = temp
yield arr
c[i] += 1
i = 0
else:
c[i] = 0
i += 1
def main():
arr = np.array(['A', 'B', 'C', 'D'])
for permutation in permutations(arr):
print(permutation)
if __name__ == '__main__':
main()
Without njit()
, the output is correct:
['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
['B' 'C' 'A' 'D']
['C' 'B' 'A' 'D']
['C' 'A' 'B' 'D']
['A' 'C' 'B' 'D']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
['D' 'A' 'C' 'B']
['A' 'D' 'C' 'B']
['A' 'C' 'D' 'B']
['C' 'A' 'D' 'B']
['C' 'D' 'A' 'B']
['D' 'C' 'A' 'B']
['D' 'C' 'B' 'A']
['C' 'D' 'B' 'A']
['C' 'B' 'D' 'A']
['B' 'C' 'D' 'A']
['B' 'D' 'C' 'A']
['D' 'B' 'C' 'A']
However, with njit()
, I get this:
['A' 'B' 'C' 'D']
['B' 'A' 'C' 'D']
Is there something special I need to do to JIT a generator function? Am I using the output of the generator incorrectly for how Numba interprets it? Is this a bug in the Numba implementation of generator functions?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论