Using NumPy to generate a kneighbors_graph like scikit-learn's?
I'm trying to brush up on my understanding of some basic ML methods. Can someone explain what is going on under the hood with kneighbors_graph? I would like to replicate this output with only NumPy.
from sklearn.neighbors import kneighbors_graph
X = [[0, 1], [3, 4], [7, 8]]
A = kneighbors_graph(X, 2, mode='distance', include_self=True)
A.toarray()
Output:
array([[0. , 4.24264069, 0. ],
[4.24264069, 0. , 0. ],
[0. , 5.65685425, 0. ]])
The resulting matrix represents the distance-weighted graph of n = 2 neighbours for each point in X, where you are including a point as its own neighbour (with a distance of zero). Note that distances to non-neighbours are also zero, so you might want to check the connectivity graph to know if you're looking at a zero-distance neighbour or a non-neighbour.

Let's start with the first row, representing the first point, [0, 1]. Here's what the numbers in that row mean:

0 is the distance to the nearest point, which is itself (because you specified include_self=True). If you specified mode='connectivity' this would be a 1, because it's a neighbour.

4.24 is the Euclidean distance (aka L2 norm) to the next point in X, which is [3, 4]. You get this distance because metric='minkowski', p=2 are defaults; if you want a different distance metric, you can have it. Again, if you specified mode='connectivity' this would also be a 1, because it's a neighbour.

0 is not really a distance; it's telling you that the third point, [7, 8], is not a neighbour when n_neighbors is 2. If you specified mode='connectivity' this would be a 0, because it's not a neighbour.

You can compute the distances between all pairs of points in an array with scipy.spatial.distance.cdist(X, X). There's also scipy.spatial.KDTree for neighbour lookup. If you really want to go pure NumPy, check out the numpy.linalg module.
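For the pure NumPy route, a minimal brute-force sketch might look like the following. It is not how scikit-learn implements kneighbors_graph (which goes through a nearest-neighbours estimator and returns a sparse matrix); the function name knn_graph_numpy is made up for illustration, the result is a dense array, only the Euclidean metric with include_self=True is covered, and tie-breaking between equidistant points may differ from scikit-learn's.

```python
import numpy as np

def knn_graph_numpy(X, n_neighbors, mode="distance"):
    """Dense k-nearest-neighbour graph (hypothetical helper).

    Assumes Euclidean distance and that each point counts as its own
    neighbour, i.e. the include_self=True case from the question.
    """
    X = np.asarray(X, dtype=float)

    # Pairwise Euclidean distances via broadcasting: D[i, j] = ||X[i] - X[j]||
    diff = X[:, None, :] - X[None, :, :]
    D = np.linalg.norm(diff, axis=-1)

    # For each row, the column indices of the n_neighbors smallest distances.
    idx = np.argsort(D, axis=1, kind="stable")[:, :n_neighbors]
    rows = np.arange(X.shape[0])[:, None]

    # Non-neighbours stay at zero, just as in the dense view of the sklearn output.
    A = np.zeros_like(D)
    if mode == "distance":
        A[rows, idx] = D[rows, idx]
    else:  # treat anything else as mode='connectivity'
        A[rows, idx] = 1.0
    return A

X = [[0, 1], [3, 4], [7, 8]]
print(knn_graph_numpy(X, 2))
print(knn_graph_numpy(X, 2, mode="connectivity"))
```

The second call prints the connectivity version, which is what lets you tell a zero-distance neighbour (a 1) apart from a non-neighbour (a 0). Building the full distance matrix costs O(n²) memory, so for larger inputs a tree-based lookup such as scipy.spatial.KDTree is the more practical choice.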