7.4 kNN 算法
kNN算法是k-Nearest Neighbor Classification的简称,即k-近邻分类算法。它的思想很简单:一个样本在特征空间中,总会有k个最相似(即特征空间中最邻近)的样本。其中,大多数样本属于某一个类别,则该样本也属于这个类别。用一句不准确的俗语来描述会更直白:“近朱者赤,近墨者黑”。
如图7-10所示,我们有两类数据:方块和三角形。它们分布在二维特征空间中。假设有一个新数据(用圆表示)需要预测其所属的类别,根据“物以类聚”的直觉,我们找到离圆圈最近的几个数据点,以它们中的大多数的特点(所属类别)来决定新数据所属的类别,这便是一次预测。
图7-10 k–近邻算法例子
如果k=3,由于三角形所占比例为2/3,k–近邻算法更倾向于认为:圆属于三角形对应的类别。如果k=5,由于方块所占比例为3/5,k–近邻算法更倾向于认为:圆属于方块对应的类别。
读者需注意区分“分类”与“聚类”的区别。分类属于有监督学习问题的范畴,而聚类属于无监督学习。举例阐释:原始人尚未发明文字,用符号来表示事物。他们能够通过观察到的大量事实(相当于收集数据集),发现人与人之间有天然的生理差别。原始人能将族人聚成两个大类,用无意义的符号“♂”和“♀”表示,这个例子能说明“无监督”的内涵,也暗示这被看作是一个聚类问题。
在现代,我们会把“性别”这一概念牢牢地灌输给每一个儿童。我们不断地训练他们,并及时纠正错误,以逐渐形成准确的判断。这个“教育”的过程,便是“有监督”的含义。但即使是成人,我们也可能会在“男扮女装”或“女扮男装”的场景中被欺骗。这个过程与我们训练某一种算法的流程极为相似。或者说,机器学习算法的灵感正是源于生活中的点滴。
另外,k–近邻算法是一种非参数模型 。简单来说,参数模型(如线性回归、逻辑回归等)都包含待确定的参数。训练过程的主要目的是寻找代价最小的最优参数。参数一旦确定,模型就完全固定了,进行预测时完全不依赖于训练数据。非参数模型则相反,在每次预测中都需要重新考虑部分或全部训练(已知的)数据。在下面的算法流程中,请读者仔细体会二者的区别。
1.算法流程
1)计算已知类别数据集中的点与当前点之间的距离。
2)按照距离递增次序排序。
3)选取与当前点距离最小的k个点。
4)确定前k个点所在类别对应的出现频率。
5)返回前k个点出现频率最高的类别作为当前点的预测分类。
2.算法实现
代码清单7-5是KNN算法的一个具体实例,其输出效果图如图7-11所示。
代码清单7-5 kNN算法示例
# -*- coding:utf-8 -*- import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn.neighbors import KNeighborsClassifier from sklearn.datasets import load_iris iris = load_iris() # 加载数据 X = iris.data[:, :2] # 为方便画图,仅采用数据的其中两个特征 y = iris.target print iris.DESCR print iris.feature_names cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) clf = KNeighborsClassifier(n_neighbors=15, weights='uniform') # 初始化分类器对象 clf.fit(X, y) # 画出决策边界,用不同颜色表示 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02)) Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape) plt.figure() plt.pcolormesh(xx, yy, Z, cmap=cmap_light) # 绘制预测结果图 plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold) # 补充训练数据点 plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.title("3-Class classification (k = 15, weights = 'uniform')") plt.show()
*代码详见:示例程序/code/7-4.py
图7-11 代码输出结果图
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论