Generating a kneighbors_graph with NumPy, as scikit-learn does?

Posted 2025-02-06 02:52:51, 439 characters, 1 view, 0 comments

I'm trying to brush up on my understanding of some basic ML methods. Can someone explain what is going on under the hood with kneighbors_graph? I would like to replicate this output with only NumPy.

from sklearn.neighbors import kneighbors_graph

X = [[0, 1], [3, 4], [7, 8]]
A = kneighbors_graph(X, 2, mode='distance', include_self=True)
A.toarray()

Output:

array([[0.        , 4.24264069, 0.        ],
       [4.24264069, 0.        , 0.        ],
       [0.        , 5.65685425, 0.        ]])

Comments (1)

無處可尋 2025-02-13 02:52:51

The resulting matrix represents the distance-weighted graph of the n_neighbors = 2 nearest neighbours of each point in X, where each point counts as its own neighbour (at a distance of zero) because of include_self=True. Note that non-neighbours also show up as zero in the dense output, so you might want to check the connectivity graph (mode='connectivity') to know whether you're looking at a zero-distance neighbour or a non-neighbour.

Let's start with the first row, representing the first point, [0, 1]. Here's what the numbers in that row mean:

  • The first 0 is the distance to the nearest point, which is itself (because you specified include_self=True). If you specified mode='connectivity' this would be a 1, because it's a neighbour.
  • The second element, 4.24, is the Euclidean distance (aka L2 norm) to the next point in X, which is [3, 4]. You get this distance because metric='minkowski', p=2 are the defaults; pass a different metric or p if you want another distance. Again, if you specified mode='connectivity' this would also be a 1, because it's a neighbour.
  • The third element, another 0, is not really a distance; it's telling you that the third point, [7, 8], is not a neighbour when n_neighbors is 2. If you specified mode='connectivity' this would be a 0, because it's not a neighbour.
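
The first row can be reproduced by hand in NumPy. This is only a sketch of the idea described in the bullets, not scikit-learn's actual code path; the variable names (d0, nearest, row_distance, row_connectivity) are made up for illustration.

```python
import numpy as np

X = np.array([[0, 1], [3, 4], [7, 8]], dtype=float)

# Euclidean (L2) distance from the first point to every point, itself included.
d0 = np.sqrt(((X - X[0]) ** 2).sum(axis=1))

# Indices of the 2 smallest distances. The point itself comes first because
# its distance is 0, which is how include_self=True is mimicked here.
nearest = np.argsort(d0, kind="stable")[:2]

# mode='distance': keep the distances of the chosen neighbours, zero elsewhere.
row_distance = np.zeros(len(X))
row_distance[nearest] = d0[nearest]

# mode='connectivity': mark the chosen neighbours with 1, zero elsewhere.
row_connectivity = np.zeros(len(X))
row_connectivity[nearest] = 1.0

print(row_distance)
print(row_connectivity)
```

Comparing the two rows shows why the connectivity graph resolves the ambiguity: the self entry is 0 in the distance row but 1 in the connectivity row, while the non-neighbour is 0 in both.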

You can compute the distances between all pairs of points in an array with scipy.spatial.distance.cdist(X, X). There's also scipy.spatial.KDTree for neighbour lookup. If you really want to go pure NumPy, check out the linalg module (numpy.linalg.norm computes the L2 norm).
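
For the pure NumPy route, a brute-force version might look like the sketch below. The helper name knn_graph_dense is hypothetical, it returns a dense array rather than the sparse matrix scikit-learn returns, and it assumes Euclidean distance with include_self=True. It has only been checked against the example in the question; with duplicate points the tie-breaking may differ from scikit-learn.

```python
import numpy as np


def knn_graph_dense(X, n_neighbors, mode="distance"):
    """Hypothetical dense stand-in for kneighbors_graph(..., include_self=True)."""
    X = np.asarray(X, dtype=float)
    n = X.shape[0]

    # Pairwise differences via broadcasting, shape (n, n, n_features),
    # then the L2 norm over the feature axis gives an (n, n) distance matrix.
    diff = X[:, None, :] - X[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)

    # For each row, the column indices of the n_neighbors smallest distances.
    # Each point is at distance 0 from itself, so it is picked as a neighbour.
    idx = np.argsort(dist, axis=1, kind="stable")[:, :n_neighbors]
    rows = np.arange(n)[:, None]

    graph = np.zeros((n, n))
    if mode == "distance":
        graph[rows, idx] = dist[rows, idx]
    elif mode == "connectivity":
        graph[rows, idx] = 1.0
    else:
        raise ValueError("mode must be 'distance' or 'connectivity'")
    return graph


X = [[0, 1], [3, 4], [7, 8]]
A = knn_graph_dense(X, 2, mode="distance")
C = knn_graph_dense(X, 2, mode="connectivity")
print(A)
print(C)
```

Building the full distance matrix takes O(n²) memory, which is fine for small inputs; for larger data a tree-based lookup such as scipy.spatial.KDTree is the more practical choice.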
