Numba function works randomly with the same given input, is this a bug?
I wrote a function called not_test in Numba that takes a list of 2D arrays representing a drainage network and traces the route an imaginary water drop would follow, as in the figure below. The point of the code is to get the drop's path for every possible drainage stream.
Results
These are the results I am getting: the routes a water drop would take if it fell at the start of each stream. For example, a drop that falls at point 1 follows the route [16, 15, 2, 1].
[[16, 3],
[16, 15, 2, 0],
[16, 15, 2, 1],
[16, 15, 14, 13],
[16, 15, 14, 12, 4],
[16, 15, 14, 12, 11, 6],
[16, 15, 14, 12, 11, 10, 9],
[16, 15, 14, 12, 11, 10, 8, 5],
[16, 15, 14, 12, 11, 10, 8, 7]]
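To make the expected output concrete, here is a minimal pure-Python sketch of the same idea: a breadth-first walk from the outlet branch that records every route to a branch with nothing upstream. The `UPSTREAM` dictionary and the name `outlet_to_source_paths` are hypothetical; the adjacency is inferred from the paths listed above, whereas the real code derives it from the branch coordinates.

```python
from collections import deque

# Hypothetical adjacency (branch -> branches directly upstream of it),
# inferred from the result paths; not read from the real network file.
UPSTREAM = {
    16: [3, 15],
    15: [2, 14],
    2: [0, 1],
    14: [12, 13],
    12: [4, 11],
    11: [6, 10],
    10: [8, 9],
    8: [5, 7],
}


def outlet_to_source_paths(upstream, outlet):
    """Return every path from the outlet to a branch with no upstream branch."""
    finished = []
    queue = deque([[outlet]])
    while queue:
        path = queue.popleft()
        children = upstream.get(path[-1], [])
        if not children:
            # Nothing further upstream: this route is complete.
            finished.append(path)
            continue
        for child in children:
            # Build a new list for each extension instead of mutating in place.
            queue.append(path + [child])
    return finished


paths = outlet_to_source_paths(UPSTREAM, 16)
for p in paths:
    print(p)
```

With this assumed adjacency the sketch yields nine routes, starting with [16, 3] and ending with [16, 15, 14, 12, 11, 10, 8, 7].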
Problem
The code works in plain Python, and it also works when compiled with Numba. The problem appears when the Numba-compiled code is run several times: sometimes it crashes and sometimes it works.
I have not been able to debug the code in Numba, and it raises no error in Python mode. Nothing specific is shown in the Python console or in the PyCharm run window; the process just stops.
The commented-out code is certainly not part of the issue I am experiencing.
I would really like to be able to use Numba on this function because it gives a 653x speedup, and the function will run around 5k times, which would mean:
with Numba: 0.0015003681182861328s per run -> 7.5s total time
with Python: 0.9321613311767578s per run -> 1.3 hours total time
Using Numba is a big help for this particular problem, so I would appreciate any help, because plain Python is too slow for the intended application.
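The totals quoted above can be checked with a few lines of arithmetic. This is only a sketch using the two per-run timings from the post and an assumed round figure of 5000 runs.

```python
# Per-run timings copied from the post; `runs` is the approximate "5k" figure.
numba_per_run = 0.0015003681182861328   # seconds
python_per_run = 0.9321613311767578     # seconds
runs = 5000

numba_total = numba_per_run * runs                 # total seconds with Numba
python_total_hours = python_per_run * runs / 3600  # total hours in plain Python
speedup = python_per_run / numba_per_run           # ratio of these two timings

print(f"Numba:  {numba_total:.1f} s")
print(f"Python: {python_total_hours:.2f} h")
print(f"Speedup for this pair of timings: {speedup:.0f}x")
```

For this particular pair of timings the ratio comes out a little over 620x, slightly below the quoted 653x, which presumably came from a different pair of runs.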
'Error example'
in Pycharm error:
Now
0.0
0.2295396327972412
[16]
[ 3 15]
[ 3 15]
[ 2 14]
[0 1]
[12 13]
Process finished with exit code -1073740940 (0xC0000374)
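The negative exit code is easier to interpret as an unsigned 32-bit value. The short sketch below does the conversion; 0xC0000374 is the Windows status code STATUS_HEAP_CORRUPTION, which suggests the compiled function corrupted memory rather than raising a Python-level exception.

```python
# Exit code as reported by PyCharm (signed 32-bit integer).
exit_code = -1073740940

# Reinterpret it as an unsigned 32-bit value to get the usual NTSTATUS form.
unsigned = exit_code & 0xFFFFFFFF

print(hex(unsigned))  # 0xc0000374
```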
in Pycharm no error:
Now
0.0
0.2430422306060791
[16]
[ 3 15]
[ 3 15]
[ 2 14]
[0 1]
[12 13]
[ 4 11]
[ 4 11]
[ 4 11]
[ 6 10]
[ 6 10]
[8 9]
[5 7]
[[16, 3], [16, 15, 2, 0], [16, 15, 2, 1], [16, 15, 14, 13], [16, 15, 14, 12, 4], [16, 15, 14, 12, 11, 6], [16, 15, 14, 12, 11, 10, 9], [16, 15, 14, 12, 11, 10, 8, 5], [16, 15, 14, 12, 11, 10, 8, 7]]
0.0016527080535889
Process finished with exit code 0
Code
link to file: https://drive.google.com/file/d/1guAe1C2sKZyy2U2_qXAhMA1v46PfeKnN/view
import numpy as np
#from pypiper import RUT_5
import numba


def convert2(x, dtype=np.float64):
    try:
        # Try and convert x to a Numpy array. If this succeeds
        # then we have reached the end of the nesting-depth.
        y = np.asarray(x, dtype=dtype)
    except Exception:
        # If the conversion to a Numpy array fails, then it can
        # be because not all elements of x can be converted to
        # the given dtype. There is currently no way to distinguish
        # if this is because x is a nested list, or just a list
        # of simple elements with incompatible types.
        # Recursively call this function on all elements of x.
        y = [convert2(x_, dtype=dtype) for x_ in x]
        # Convert Python list to Numba list.
        y = numba.typed.List(y)
    return y
@numba.njit('(ListType(float64[:, ::1]), float64[:])')
def not_test(branches, outlet):
    # get len of branches
    _len_branches = len(branches)
    # # empty array
    # d_array = np.empty(shape=_len_branches, dtype=np.float64)
    # # set outlet coordinates as arrays
    # x_outlet, y_outlet = outlet
    # x_outlet, y_outlet = np.array([x_outlet]), np.array([y_outlet])
    #
    # # get min distance from branches
    # for pos in numba.prange(_len_branches):
    #     # get current branch
    #     branch = branches[pos]
    #     # get min distance from outlet point
    #     d_min = RUT_5.nb_cdist(branch, x_outlet, y_outlet).min()
    #     # add to array
    #     d_array[pos] = d_min
    #
    # # get index for minimum distance
    # index_branch = np.argmin(d_array)
    index_branch = 16

    # remove initial branch
    update_branches = branches.copy()
    del update_branches[index_branch]

    # define arrays
    # (np.int was removed in NumPy 1.24; np.int64 keeps the empty lists typed as integers)
    not_read = np.empty(shape=0, dtype=np.int64)
    paths_update = [[np.int64(x)] for x in range(0)]
    points = np.empty(shape=(2, 2))
    a_list = [np.int64(x) for x in range(0)]

    # branches already visited, so the loop skips them
    not_read = np.append(index_branch, not_read)
    # iterable in loop
    iterable = not_read.copy()
    # conditions
    cond = 0
    cont = 0

    while cond == 0:
        for pos_idx in iterable:
            print(iterable)
            if cont > 0:
                paths = paths_update.copy()

            # first and last vertex of the current branch
            branch = branches[pos_idx]
            points[0] = branch[0]
            points[1] = branch[-1]

            # collect unvisited branches that touch either end point
            for point in points:
                for pos_j in range(_len_branches):
                    if pos_j not in not_read:
                        diff = np.sum(point - branches[pos_j], axis=1)
                        if 0 in diff:
                            a_list.append(pos_j)

            if cont == 0:
                paths = [[pos_idx] + [i] for i in a_list]
                paths_update = paths.copy()
                cont = cont + 1
                not_read = np.append(not_read, a_list)
                iterable = np.array(a_list)
                a_list = [np.int64(x) for x in range(0)]
            else:
                if len(a_list):
                    path_arr = [_i for _i in paths if pos_idx in _i]
                    for path in path_arr:
                        for conexion in a_list:
                            temp_list = path.copy()
                            temp_list.append(conexion)
                            paths_update.append(temp_list)
                        paths_update.remove(path)
                    not_read = np.append(not_read, a_list)
                    iterable = np.array(a_list)
                    a_list = [np.int64(x) for x in range(0)]
                else:
                    continue

        # stop once every branch has been visited
        if len(branches) == len(np.unique(not_read)):
            cond = 1

    return paths


if __name__ == '__main__':
    print('Now')
    branches = np.load('test.npy', allow_pickle=True).item()
    x_snap, y_snap = 717110.7843995667, 9669749.115011858

    import time
    t0 = time.time()
    arr = []
    for pos, branch in enumerate(branches.features):
        arr.append(list(branch.geometry.coordinates))
    print(time.time() - t0)

    t0 = time.time()
    arr = convert2(arr)
    print(time.time() - t0)

    t0 = time.time()
    outlet = np.array([x_snap, y_snap])
    print(not_test(branches=arr, outlet=outlet))
    print(time.time() - t0)
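One detail in the listing that may be worth checking, independent of the crash: the connection test sums the x and y differences and looks for a zero, so a vertex whose offsets cancel out (dx = -dy) would also count as touching. The pure-Python sketch below contrasts that test with an exact coordinate comparison; `sum_test` and `exact_test` are hypothetical helper names, and the coordinates are made up for illustration.

```python
def sum_test(point, vertices):
    """Mirror of `0 in np.sum(point - branch, axis=1)` from the listing."""
    return any((point[0] - vx) + (point[1] - vy) == 0 for vx, vy in vertices)


def exact_test(point, vertices):
    """Match only when both coordinates are identical."""
    return any(point[0] == vx and point[1] == vy for vx, vy in vertices)


# Made-up data: the point is not a vertex of the branch,
# but (1 - 2) + (2 - 1) == 0, so the summed test still reports a match.
point = (1.0, 2.0)
branch_vertices = [(2.0, 1.0), (5.0, 5.0)]

print(sum_test(point, branch_vertices))    # True  (false positive)
print(exact_test(point, branch_vertices))  # False
```

Whether this matters for the real network depends on the coordinates; with large projected values an accidental cancellation is unlikely but not impossible.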
Comments (1)
This is not a real answer, as it does not address the actual problem of the potential bug in the Numba code, but it gets the job done.
There seems to be an issue with calling the list methods pop or remove inside code decorated with @numba.njit; the issue has been reported and the developers are debugging it.
I ended up avoiding these methods. That is certainly not ideal, since the code now iterates over some paths it should not, but it is still much faster than plain Python.
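As an illustration of what avoiding remove can look like (this is a sketch, not the answerer's actual code), the path update can build a fresh list instead of deleting entries from the existing one. The function name `extend_paths` is hypothetical, and the sketch is plain Python; whether the same pattern sidesteps the crash under @numba.njit is an assumption that would need testing.

```python
def extend_paths(paths, branch, connections):
    """Return a new list in which every path containing `branch` is
    replaced by one extended copy per entry in `connections`."""
    updated = []
    for path in paths:
        if connections and branch in path:
            # Replace the path by its extensions; nothing is removed in place.
            for conn in connections:
                updated.append(path + [conn])
        else:
            # Paths that do not pass through `branch` are kept unchanged.
            updated.append(path)
    return updated


paths = [[16, 3], [16, 15]]
paths = extend_paths(paths, 15, [2, 14])
print(paths)  # [[16, 3], [16, 15, 2], [16, 15, 14]]
```

Unlike the original loop, this version never mutates a list while other references to it are in use, at the cost of allocating a new outer list on every update.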
Code