为什么我的numba jit函数被识别为数组?

发布于 2025-01-19 14:38:56 字数 3773 浏览 0 评论 0 原文

因此,我试图加快代码速度,虽然它适用于大多数功能,但有一些功能不起作用。我指定了函数的签名,但它不起作用。如果我只写nb.njit,它可以工作,但根本没有加速,甚至有轻微的减慢。当按照我在下面发布的代码中指定签名时,出现以下错误:

类型错误:装饰对象不是函数(获取类型 )。


nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,;](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:])
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
    #Find vertices undergoing T1
    if T1_edge == len(l) - 1:
        T1_verts = cell_verts[i,[T1_edge,0]]
    else:
        T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
    
    #Identify the four cells are affected by transition
    dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
    T1cells = np.unique(dummy)

    #Identify cells that are neighbours prior to transition (that won't be afterwards)
    old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
    
    #Identify cells that will be neighbours after transition
    notneigh1 = T1cells[T1cells != old_neigh[0]]
    notneigh2 = T1cells[T1cells != old_neigh[1]]
    new_neigh = np.intersect1d(notneigh1,notneigh2)
    old_vert_neighs = vert_neighs[T1_verts,:]
    return T1_verts, old_neigh, new_neigh, old_vert_neighs

我检查了输入数组和数字的大小和数据类型,并确信我没有犯错误。我想补充一点,对于 int8 类型的数量,我必须使用 j = np.asarray([i],dtype='int8')[0] 将 int 更改为 int8,因为我没有找到类型int,但我为int8做了。我的代码中的输入数字 i 对应于 j 并且确实是 int8 类型。当我只在我的函数上使用 inform.isfunction 时,它会将其识别为函数。

下面是调用上述函数的代码:

def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
    for i in range(N):
        #Determine how many vertices are in cell i
        vert_inds = cell_verts[i,:] < 2*N
        j = np.array([i]).astype('int8')[0]
        #If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides) 
        if(len(vert_inds) == 3):
            continue 
        
        #Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
        #vertices will have changed due to previous T1 transitions)
        vert_inds = cell_verts[i,:] < 2*N
        cell_i = verts[cell_verts[i,vert_inds],:]
        
        #Shift vertex coords to account for periodicity
        rel_dists = cell_i - points[i,:]
        cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
        
        #Calculate the lengths, l, of each cell edge
        shifted_verts = np.roll(cell_i,-1,axis=0)
        l =  shifted_verts - cell_i
        l_mag = np.linalg.norm(l,axis=1)
        #Determine if any edges are below threshold length
        to_change = np.nonzero(l_mag < T1_thresh)
        #print('l = ',l)
        #print('T1_thresh = ', T1_thresh)
        if len(to_change[0]) == 0:
            continue
        else:
            T1_edge = to_change[0][0]
            #Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
            T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
            #Update vertex coordinates
            verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)    
            #Update vert_cells
            vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
            #Update cell_verts 
            cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
            #Update vert_neighs
            vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)

    return verts, vert_neighs, vert_cells, cell_verts

So I am trying to speed up a code, and while it works for most functions, a few do not work. I specify the function's signature, and it doesn't work. If I write only nb.njit, it works, but there is no speed up at all, there is even a slight slow down. When specifying the signature as in the code I posted below, I get the following error:

TypeError: The decorated object is not a function (got type <class 'numba.core.types.npytypes.Array'>).


nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,;](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:])
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
    #Find vertices undergoing T1
    if T1_edge == len(l) - 1:
        T1_verts = cell_verts[i,[T1_edge,0]]
    else:
        T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
    
    #Identify the four cells are affected by transition
    dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
    T1cells = np.unique(dummy)

    #Identify cells that are neighbours prior to transition (that won't be afterwards)
    old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
    
    #Identify cells that will be neighbours after transition
    notneigh1 = T1cells[T1cells != old_neigh[0]]
    notneigh2 = T1cells[T1cells != old_neigh[1]]
    new_neigh = np.intersect1d(notneigh1,notneigh2)
    old_vert_neighs = vert_neighs[T1_verts,:]
    return T1_verts, old_neigh, new_neigh, old_vert_neighs

I checked the sizes and data types of my input arrays and number and am sure I did not make a mistake there. I want to add that for the number of type int8, I had to change an int to an int8 using j = np.asarray([i],dtype='int8')[0] because I didn't find a type for int, but I did for int8. The input number i in my code corresponds to that j and is indeed of type int8. When I only use inspect.isfunction on my function, it recognizes it as a function.

Here is the code calling the above function:

def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
    for i in range(N):
        #Determine how many vertices are in cell i
        vert_inds = cell_verts[i,:] < 2*N
        j = np.array([i]).astype('int8')[0]
        #If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides) 
        if(len(vert_inds) == 3):
            continue 
        
        #Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
        #vertices will have changed due to previous T1 transitions)
        vert_inds = cell_verts[i,:] < 2*N
        cell_i = verts[cell_verts[i,vert_inds],:]
        
        #Shift vertex coords to account for periodicity
        rel_dists = cell_i - points[i,:]
        cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
        
        #Calculate the lengths, l, of each cell edge
        shifted_verts = np.roll(cell_i,-1,axis=0)
        l =  shifted_verts - cell_i
        l_mag = np.linalg.norm(l,axis=1)
        #Determine if any edges are below threshold length
        to_change = np.nonzero(l_mag < T1_thresh)
        #print('l = ',l)
        #print('T1_thresh = ', T1_thresh)
        if len(to_change[0]) == 0:
            continue
        else:
            T1_edge = to_change[0][0]
            #Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
            T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
            #Update vertex coordinates
            verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)    
            #Update vert_cells
            vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
            #Update cell_verts 
            cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
            #Update vert_neighs
            vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)

    return verts, vert_neighs, vert_cells, cell_verts

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

妄想挽回 2025-01-26 14:38:56

您收到的错误与您向 njit 装饰器提供的错误定义的签名有关。每当您的 jitted 函数返回多个值时,您都必须将返回类型定义为同构元组(如果所有返回类型都相同)或异构元组(如果返回类型不同)(请参阅 这个答案)。

关于加速,您不会通过此代码示例获得任何加速:相反,您会减速。我能认识到的主要原因有以下两个:

  1. 您只使用了已经高度优化的 numpy 的标准函数。根据经验,numba 在加速循环方面效果很好。如果您的代码已经矢量化,那么您可能不会通过njit调整您的函数来获得任何速度提升。
  2. 您使用了很多花哨的索引,例如 cell_verts[i,[T1_edge,T1_edge+1]],它生成需要内存分配的数组副本,这是 numba 执行的任务code> 并不是很擅长(请参阅这个答案)。

The error you're getting is related to the wrongly-defined signature that you're providing to the njit decorator. Whenever your jitted function returns several values, you have to define the return type as either an homogeneous tuple (if all the return types are the same) or heterogeneous tuple (if the return types are different) (see this answer).

Regarding the speedup, you won't get any with this code sample: rather you'll get a slowdown. The main reasons I can recognize are the following two:

  1. You are only resorting on numpy's standard functions, that are already highly optimized. As a rule of thumb, numba works well in speeding up loops. If you're code is already vectorized, you probably won't get any speed improvement by njitting your function.
  2. You're using a lot of fancy indexing, e.g. cell_verts[i,[T1_edge,T1_edge+1]], which produce array copies that require memory allocation, a task to which numba is not really good at (see this answer).
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文