5.5 scikit-learn
本节介绍的是Python在机器学习方面一个非常强力的模块——scikit-learn。scikit-learn是在NumPy、SciPy和Matplotlib三个模块上编写的,是数据挖掘和数据分析的一个简单而有效的工具。在其官方网站上我们可以看到scikit-learn有6大功能:分类(Classification),回归(Regression),聚类(Clustering),降维(Dimensionality Reduction),模型选择(Model Selection)和预处理(Preprocessing)。下面先将简单介绍机器学习和scikit-learn的应用。
1.机器学习的问题
一般来说,我们可以这样理解机器学习的问题:我们有n个样本(sample)的数据集,想要预测未知数据的属性。如果样本的数据是多维的,那么我们就说样本具有多个属性或特征。
我们可以将学习问题分为以下两类:
1)有监督学习 (Supervised Learning)是指数据中包括了我们想要预测的属性,即目标变量,而有监督学习问题有以下两类:
·分类 (Classification):样本属于两个或多个类别,我们希望通过从已标记类别的数据学习,来预测未标记数据的分类。例如,识别手写数字就是一个分类问题,其目标是将每个输入向量对应到有穷的数字类别。从另一种角度来思考,分类是一种有监督学习的离散(相对于连续)形式,对于n个样本,一方有对应的有限个类别数量,另一方则试图标记样本并分配到正确的类别。
·回归 (Regression):如果希望的输出是一个或多个连续的变量,那么这个问题称为回归,比如用三文鱼的年龄和体重去预测其长度。
2)无监督学习 (Unsupervised Learning):无监督学习的训练数据包括了输入向量X的集合,但没有相应的目标变量。这类问题的目标可以是发掘数据中相似样本的分组,被称作聚类(Clustering);也可以是确定输入样本空间中的数据分布,被称作密度估计(Density Estimation);还可以是将数据从高维空间投射到两维或三维空间,以便进行数据可视化。
2.scikit-learn的数据集
scikit-learn有一些标准数据集,比如分类的iris和digits数据集和用于回归的波士顿房价(Boston House Prices)数据集。针对digits数据集的任务是给定一个8*8像素数组,程序能够预测这64个像素代表哪个数字,图5-2所示。
图5-2 手写数字识别示意图
下面,我们尝试用Python加载digits数据集,如代码清单5-16所示。
代码清单5-16 加载digits数据集
from sklearn import datasets # 数据集类似字典对象,包括了所有的数据和关于数据的元数据( metadata)。 # 数据被存储在 .data成员内,是一个 n_samples*n_features的数组。 # 在有监督问题的情形下,一个或多个因变量( response variables)被储存在 .target成员中 digits = datasets.load_digits() # 例如在 digits数据集中, digits.data是可以用来分类数字样本的特征 print digits.data # result: # [[ 0. 0. 5. ..., 0. 0. 0.] # [ 0. 0. 0. ..., 10. 0. 0.] # [ 0. 0. 0. ..., 16. 9. 0.] # ..., # [ 0. 0. 1. ..., 6. 0. 0.] # [ 0. 0. 2. ..., 12. 0. 0.] # [ 0. 0. 10. ..., 12. 1. 0.]] #digits.target给出了 digits数据集的目标变量,即每个数字图案对应的我们想预测的真实数字 print digits.target # result: # [0 1 2, ..., 8 9 8]
*代码详见:示例程序/code/5-5.py
3.scikit-learn的训练和预测
接着上面的例子,我们的任务是给定一幅像素图案,预测其表示的数字。这是一个有监督学习的分类问题,总共有10个可能的分类(数字0~9)。我们将训练一个预测器(Estimator)来预测(Predict)未知样本所属分类。
在scikit-learn中,分类的预测器是一个Python对象,具有方法fit(X,y)和predict(test)方法。下面这个预测器的例子是sklearn.svm.SVC,实现了支持向量机分类。创建分类器需要模型参数,但现在我们暂时先将分类器看作是一个黑盒。代码清单5-17展现了整个训练和预测的过程:
代码清单5-17 训练和预测
from sklearn import svm # 选择模型参数 clf = svm.SVC(gamma=0.0001,C=100) # 我们的预测器的名字叫做 clf。现在 clf必须通过 fit方法来从模型中学习。 # 这个过程是通过将训练集传递给 fit方法来实现的。我们将除了最后一个样本的数据全部作为训练集。 # 进行训练 clf.fit(digits.data[:-1], digits.target[:-1]) # 进行预测 print clf.predict(digits.data[-1]) # result: 8
*代码详见:示例程序/code/5-5.py
如图5-3所示,最后一个像素图案显示出数字8,和我们的预测结果一致。关于scikit-learn我们暂时介绍到这里,在后面的第8和第9章我们也会用到此模块。
图5-3 手写数字识别的测试图片
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论