为什么我的numba jit函数被识别为数组?
因此,我试图加快代码速度,虽然它适用于大多数功能,但有一些功能不起作用。我指定了函数的签名,但它不起作用。如果我只写nb.njit,它可以工作,但根本没有加速,甚至有轻微的减慢。当按照我在下面发布的代码中指定签名时,出现以下错误:
类型错误:装饰对象不是函数(获取类型
)。
nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,;](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:])
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
#Find vertices undergoing T1
if T1_edge == len(l) - 1:
T1_verts = cell_verts[i,[T1_edge,0]]
else:
T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
#Identify the four cells are affected by transition
dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
T1cells = np.unique(dummy)
#Identify cells that are neighbours prior to transition (that won't be afterwards)
old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
#Identify cells that will be neighbours after transition
notneigh1 = T1cells[T1cells != old_neigh[0]]
notneigh2 = T1cells[T1cells != old_neigh[1]]
new_neigh = np.intersect1d(notneigh1,notneigh2)
old_vert_neighs = vert_neighs[T1_verts,:]
return T1_verts, old_neigh, new_neigh, old_vert_neighs
我检查了输入数组和数字的大小和数据类型,并确信我没有犯错误。我想补充一点,对于 int8 类型的数量,我必须使用 j = np.asarray([i],dtype='int8')[0] 将 int 更改为 int8,因为我没有找到类型int,但我为int8做了。我的代码中的输入数字 i 对应于 j 并且确实是 int8 类型。当我只在我的函数上使用 inform.isfunction 时,它会将其识别为函数。
下面是调用上述函数的代码:
def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
for i in range(N):
#Determine how many vertices are in cell i
vert_inds = cell_verts[i,:] < 2*N
j = np.array([i]).astype('int8')[0]
#If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides)
if(len(vert_inds) == 3):
continue
#Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
#vertices will have changed due to previous T1 transitions)
vert_inds = cell_verts[i,:] < 2*N
cell_i = verts[cell_verts[i,vert_inds],:]
#Shift vertex coords to account for periodicity
rel_dists = cell_i - points[i,:]
cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
#Calculate the lengths, l, of each cell edge
shifted_verts = np.roll(cell_i,-1,axis=0)
l = shifted_verts - cell_i
l_mag = np.linalg.norm(l,axis=1)
#Determine if any edges are below threshold length
to_change = np.nonzero(l_mag < T1_thresh)
#print('l = ',l)
#print('T1_thresh = ', T1_thresh)
if len(to_change[0]) == 0:
continue
else:
T1_edge = to_change[0][0]
#Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
#Update vertex coordinates
verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)
#Update vert_cells
vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
#Update cell_verts
cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
#Update vert_neighs
vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)
return verts, vert_neighs, vert_cells, cell_verts
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
您收到的错误与您向
njit
装饰器提供的错误定义的签名有关。每当您的 jitted 函数返回多个值时,您都必须将返回类型定义为同构元组(如果所有返回类型都相同)或异构元组(如果返回类型不同)(请参阅 这个答案)。关于加速,您不会通过此代码示例获得任何加速:相反,您会减速。我能认识到的主要原因有以下两个:
numba
在加速循环方面效果很好。如果您的代码已经矢量化,那么您可能不会通过njit
调整您的函数来获得任何速度提升。The error you're getting is related to the wrongly-defined signature that you're providing to the
njit
decorator. Whenever your jitted function returns several values, you have to define the return type as either an homogeneous tuple (if all the return types are the same) or heterogeneous tuple (if the return types are different) (see this answer).Regarding the speedup, you won't get any with this code sample: rather you'll get a slowdown. The main reasons I can recognize are the following two:
numpy
's standard functions, that are already highly optimized. As a rule of thumb,numba
works well in speeding up loops. If you're code is already vectorized, you probably won't get any speed improvement bynjit
ting your function.cell_verts[i,[T1_edge,T1_edge+1]]
, which produce array copies that require memory allocation, a task to whichnumba
is not really good at (see this answer).