请 [注册] 或 [登录]  | 返回主站

量化交易吧 /  量化平台 帖子:3366815 新帖:0

【机器学习】入门第一步

我太难了发表于:5 月 10 日 06:59回复(1)

机器学习是当下的热门,而当下的量化投资领域里,智能投顾也越来越被众人所关注。最近一直在研究机器学习方面的知识,打算将个人的学习笔记公布出来。一来帮助一些想入门机器学习的朋友,二来为聚宽提供新的资料。

在这里先祝各位宽友日进斗金,祝聚宽越做越大!


机器学习前提介绍:

  • 使用python语言,最好使用python3

  • 使用Jupyter notebook

  • 熟练使用Numpy/SciPy/Pandas/matplotlib

  • 机器学习主要框架scikit-learn

另外,为了方便呈现数据,这里使用了mglarn模块。该模块的使用不必费脑学习,只需要知道它可以帮助美化图表、呈现数据即可。

# 在学习之前,先导入这些常用的模块import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport mglearn

构建一个简单的机器学习应用

假设这里已经收集了关于鸢尾花的测量数据:花瓣的长度和宽度;花萼的长度和宽度。这些花共有三个品种:setosa/versicolor/virginica。并且事先已经将所有鸢尾花的数据与分类做了对应关系。

如果现在又有一批新的关于鸢尾花的数据,但没有做出分类,是否可以根据其花瓣的长度和宽度、花萼的长度和宽度来预测出其类别呢?

以上问题,是一个分类问题,最终对于数据结果的输出叫做类别

由于事先对已有的鸢尾花数据做了分类处理,再从这些数据的经验中判断新数据的分类,这种学习方式被叫做监督式学习,即从给定好的输入与输出的对应关系中,得出新的数据可能的结果。

第一步,获得数据

鸢尾花数据集已经包含在 scikit-learn 的 datasets 模块中,可以直接调用 load_iris 函数来加载数据:

# 导入load_irisfrom sklearn.datasets import load_iris# 调用数据函数iris_dataset = load_iris()

laod_iris 返回的是一个 Buch 对象,与字典非常相似,里面包含键和值:

# 查看键iris_dataset.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
  • data 对应是的鸢尾花测量的数据集

  • target 对应的是分类

  • target_names 对应是的类别的名称

  • DESCR 对应的是数据集的说明

  • feature_names 对应的是数据特征列表

# 查看数据集iris_dataset.data
array([[5.1, 3.5, 1.4, 0.2],       [4.9, 3. , 1.4, 0.2],       [4.7, 3.2, 1.3, 0.2],       [4.6, 3.1, 1.5, 0.2],       [5. , 3.6, 1.4, 0.2],       [5.4, 3.9, 1.7, 0.4],       [4.6, 3.4, 1.4, 0.3],       [5. , 3.4, 1.5, 0.2],       [4.4, 2.9, 1.4, 0.2],       [4.9, 3.1, 1.5, 0.1],       [5.4, 3.7, 1.5, 0.2],       [4.8, 3.4, 1.6, 0.2],
       ……,       [6.3, 2.5, 5. , 1.9],       [6.5, 3. , 5.2, 2. ],       [6.2, 3.4, 5.4, 2.3],       [5.9, 3. , 5.1, 1.8]])

datak中的数据有四列,每列表示花萼的长度、花萼的宽度、花瓣的长度、花瓣的宽度,格式为Numpy数组。

# 查看数据的数量iris_dataset.data.shape
(150, 4)

可以看出数据一共有150行,4列。

# 查看数据对应的分类iris_dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

每一行代表一个花朵,也对应一个类别,这里的0,1,2分别代表三个品种。

要点知道:

数据集中的个体叫做样本,其属性叫作特征或标签

# 查看数据的特征iris_dataset.feature_names
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

第二步:训练数据

这里不会将所有的数据都用来训练,还要留出一部分用来预测。这里的预测指的是,通过训练后,形成一个模型,使用未经训练的数据去测试模型是否能准确预测出其分类。

sklearn-learn 中的 train_split 函数可以打乱数据集并进行拆分,默认会将75%的数据用作训练集,25%的数据集用作测试集。

在书写上,数据通常用大写的X表示,标签则用小写的y表示。一般大写用来表示二维矩阵,小写表示一维的向量。

# 导入 train_split函数from sklearn.model_selection import train_test_split# 拆分数据X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)

train_test_split 中需要三个参数,第一个是数据集,第二个是标签集,第三个是随机种子数。

由于iris_dataset是一个Buch对象,因为既可以使用属性的方式也可以使用中括号的方式获得对应的值。

random_state是指利用伪随机数生成器将数据集打乱。

# 查看训练数据集X_train
array([[5.9, 3. , 4.2, 1.5],       [5.8, 2.6, 4. , 1.2],       [6.8, 3. , 5.5, 2.1],       [4.7, 3.2, 1.3, 0.2],       [6.9, 3.1, 5.1, 2.3],       [5. , 3.5, 1.6, 0.6],       [5.4, 3.7, 1.5, 0.2],       [5. , 2. , 3.5, 1. ],       [6.5, 3. , 5.5, 1.8],       [6.7, 3.3, 5.7, 2.5],
       ……,       [5.6, 3. , 4.1, 1.3],       [5.9, 3.2, 4.8, 1.8],       [6.3, 2.3, 4.4, 1.3],       [5.5, 3.5, 1.3, 0.2],       [5.1, 3.7, 1.5, 0.4],       [4.9, 3.1, 1.5, 0.1],       [6.3, 2.9, 5.6, 1.8],       [5.8, 2.7, 4.1, 1. ],       [7.7, 3.8, 6.7, 2.2],       [4.6, 3.2, 1.4, 0.2]])
# 查看训练标签集y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,       1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,       0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,       2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,       2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,       2, 0])
# 查看测试数据集X_test
array([[5.8, 2.8, 5.1, 2.4],       [6. , 2.2, 4. , 1. ],       [5.5, 4.2, 1.4, 0.2],       [7.3, 2.9, 6.3, 1.8],       [5. , 3.4, 1.5, 0.2],       [6.3, 3.3, 6. , 2.5],       [5. , 3.5, 1.3, 0.3],       [6.7, 3.1, 4.7, 1.5],       [6.8, 2.8, 4.8, 1.4],       [6.1, 2.8, 4. , 1.3],       [6.1, 2.6, 5.6, 1.4],       [6.4, 3.2, 4.5, 1.5],       [6.1, 2.8, 4.7, 1.2],       [6.5, 2.8, 4.6, 1.5],       [6.1, 2.9, 4.7, 1.4],       [4.9, 3.1, 1.5, 0.1],       [6. , 2.9, 4.5, 1.5],       [5.5, 2.6, 4.4, 1.2],       [4.8, 3. , 1.4, 0.3],       [5.4, 3.9, 1.3, 0.4],       [5.6, 2.8, 4.9, 2. ],       [5.6, 3. , 4.5, 1.5],       [4.8, 3.4, 1.9, 0.2],       [4.4, 2.9, 1.4, 0.2],       [6.2, 2.8, 4.8, 1.8],       [4.6, 3.6, 1. , 0.2],       [5.1, 3.8, 1.9, 0.4],       [6.2, 2.9, 4.3, 1.3],       [5. , 2.3, 3.3, 1. ],       [5. , 3.4, 1.6, 0.4],       [6.4, 3.1, 5.5, 1.8],       [5.4, 3. , 4.5, 1.5],       [5.2, 3.5, 1.5, 0.2],       [6.1, 3. , 4.9, 1.8],       [6.4, 2.8, 5.6, 2.2],       [5.2, 2.7, 3.9, 1.4],       [5.7, 3.8, 1.7, 0.3],       [6. , 2.7, 5.1, 1.6]])
# 查看测试标签集y_test
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1])

观察数据

下图可以通过四个标签值两两对应的关系,查看其表现。(这里不做深究其原理)

# 将训练数据转换成DataFrameiris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)# 通过scatter_matrix绘制出矩阵图grr = pd.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
C:\Users\Administrator\Anaconda3\lib\site-packages\ipykernel_launcher.py:4: FutureWarning: pandas.scatter_matrix is deprecated, use pandas.plotting.scatter_matrix instead
  after removing the cwd from sys.path.

构建第一个模型:k近邻算法

想要训练数据,则需要一个算法模型。这里选择使用k近邻分类算法。

k近邻分类器中k的含义,新数据与训练集中最近的任意k个邻居,也就是说,新数据与k个某标签离得最近,则归类为该标签

scikit_lean 中所有的机器学习模型都在各自的类中实现,k近邻算法实在 neighors 模块的 KNei*orsClassifier 类中实现的,我们需要将这个列实例化为一个对象,然后才能使用这个模型

# 导入KNei*orsClassifier模块from sklearn.nei*ors import KNei*orsClassifier# 实例化对象knn = KNei*orsClassifier(n_nei*ors=1)

n_nei*ors 参数表示k的个数,1一表示按与它相邻最近的那1个进行分类。

想要基于训练集来构建模型,需要调用knn对象的fit方法,输入参数X_train和y_train。

# 训练数据,并返回模型knn.fit(X_train,y_train)
KNei*orsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_nei*ors=1, p=2,
           weights='uniform')

fit方法返回的是knn对象,所以这里得到了一个表示该对象的字符串

第三步:做出预测

# 假设这里有一个新的花瓣数据X_new = np.array([[5,2.9,1,0.2]])

需要注意的是,这里的数据一定要是二维的数据才可以

调用 knn 的 predict 方法来进行预测

# 调用 predict 函数进行预测prediction = knn.predict(X_new)# 查看返回的类型prediction
array([0])
iris_dataset['target_names'][prediction]
array(['setosa'], dtype='U10')

predict 方法会返回一个标签值,通过标签值,则可获得其对应的品种名称

第四步:评估模型

调用测试集,对测试数据中的每朵鸢尾花进行预测,并将预测结果与标签(一直的品种)进行对比。我们可以通过计算精度来衡量模型的优劣,精度就是品种预测正确的花所占的比例

y_pred = knn.predict(X_test)
y_pred
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])

那么,测试返回的分类集合,与原始的分类是否一致呢?这里需要将 y_pred 与 y_test 进行对比

np.mean(y_pred==y_test)
0.9736842105263158

或者直接调用knn的score方法来计算精度

knn.score(X_test,y_test)
0.9736842105263158

可以看出,测试返回的结果中,与原始分类集合具有97%的相似度。

以上便是机器学习的基本流程。O(∩_∩)

机器学习是当下的热门,而当下的量化投资领域里,智能投顾也越来越被众人所关注。最近一直在研究机器学习方面的知识,打算将个人的学习笔记公布出来。一来帮助一些想入门机器学习的朋友,二来为聚宽提供新的资料。

在这里先祝各位宽友日进斗金,祝聚宽越做越大!


机器学习前提介绍:

  • 使用python语言,最好使用python3

  • 使用Jupyter notebook

  • 熟练使用Numpy/SciPy/Pandas/matplotlib

  • 机器学习主要框架scikit-learn

另外,为了方便呈现数据,这里使用了mglarn模块。该模块的使用不必费脑学习,只需要知道它可以帮助美化图表、呈现数据即可。

# 在学习之前,先导入这些常用的模块import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport mglearn

构建一个简单的机器学习应用¶

假设这里已经收集了关于鸢尾花的测量数据:花瓣的长度和宽度;花萼的长度和宽度。这些花共有三个品种:setosa/versicolor/virginica。并且事先已经将所有鸢尾花的数据与分类做了对应关系。

如果现在又有一批新的关于鸢尾花的数据,但没有做出分类,是否可以根据其花瓣的长度和宽度、花萼的长度和宽度来预测出其类别呢?

以上问题,是一个分类问题,最终对于数据结果的输出叫做类别

由于事先对已有的鸢尾花数据做了分类处理,再从这些数据的经验中判断新数据的分类,这种学习方式被叫做监督式学习,即从给定好的输入与输出的对应关系中,得出新的数据可能的结果。

第一步,获得数据¶

鸢尾花数据集已经包含在 scikit-learn 的 datasets 模块中,可以直接调用 load_iris 函数来加载数据:

# 导入load_irisfrom sklearn.datasets import load_iris# 调用数据函数iris_dataset = load_iris()# 展示结果iris_dataset
{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-\nData Set Characteristics:\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n\n   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Nei*orhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Nei*or Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...\n',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)']}

laod_iris 返回的是一个 Buch 对象,与字典非常相似,里面包含键和值:

# 查看键iris_dataset.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
  • data 对应是的鸢尾花测量的数据集

  • target 对应的是分类

  • target_names 对应是的类别的名称

  • DESCR 对应的是数据集的说明

  • feature_names 对应的是数据特征列表

# 查看数据集iris_dataset.data
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])

datak中的数据有四列,每列表示花萼的长度、花萼的宽度、花瓣的长度、花瓣的宽度,格式为Numpy数组。

# 查看数据的数量iris_dataset.data.shape
(150, 4)

可以看出数据一共有150行,4列。

# 查看数据对应的分类iris_dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

每一行代表一个花朵,也对应一个类别,这里的0,1,2分别代表三个品种。

要点知道:

数据集中的个体叫做样本,其属性叫作特征或标签

# 查看数据的特征iris_dataset.feature_names
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

第二步:训练数据¶

这里不会将所有的数据都用来训练,还要留出一部分用来预测。这里的预测指的是,通过训练后,形成一个模型,使用未经训练的数据去测试模型是否能准确预测出其分类。

sklearn-learn 中的 train_split 函数可以打乱数据集并进行拆分,默认会将75%的数据用作训练集,25%的数据集用作测试集。

在书写上,数据通常用大写的X表示,标签则用小写的y表示。一般大写用来表示二维矩阵,小写表示一维的向量。

# 导入 train_split函数from sklearn.model_selection import train_test_split# 拆分数据X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)

train_test_split 中需要三个参数,第一个是数据集,第二个是标签集,第三个是随机种子数。

由于iris_dataset是一个Buch对象,因为既可以使用属性的方式也可以使用中括号的方式获得对应的值。

random_state是指利用伪随机数生成器将数据集打乱。

# 查看训练数据集X_train
array([[5.9, 3. , 4.2, 1.5],
       [5.8, 2.6, 4. , 1.2],
       [6.8, 3. , 5.5, 2.1],
       [4.7, 3.2, 1.3, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [5. , 3.5, 1.6, 0.6],
       [5.4, 3.7, 1.5, 0.2],
       [5. , 2. , 3.5, 1. ],
       [6.5, 3. , 5.5, 1.8],
       [6.7, 3.3, 5.7, 2.5],
       [6. , 2.2, 5. , 1.5],
       [6.7, 2.5, 5.8, 1.8],
       [5.6, 2.5, 3.9, 1.1],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.3, 4.7, 1.6],
       [5.5, 2.4, 3.8, 1.1],
       [6.3, 2.7, 4.9, 1.8],
       [6.3, 2.8, 5.1, 1.5],
       [4.9, 2.5, 4.5, 1.7],
       [6.3, 2.5, 5. , 1.9],
       [7. , 3.2, 4.7, 1.4],
       [6.5, 3. , 5.2, 2. ],
       [6. , 3.4, 4.5, 1.6],
       [4.8, 3.1, 1.6, 0.2],
       [5.8, 2.7, 5.1, 1.9],
       [5.6, 2.7, 4.2, 1.3],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [6.1, 3. , 4.6, 1.4],
       [7.2, 3.2, 6. , 1.8],
       [5.3, 3.7, 1.5, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [6.4, 2.7, 5.3, 1.9],
       [5.7, 3. , 4.2, 1.2],
       [5.4, 3.4, 1.7, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [6.9, 3.1, 4.9, 1.5],
       [4.6, 3.1, 1.5, 0.2],
       [5.9, 3. , 5.1, 1.8],
       [5.1, 2.5, 3. , 1.1],
       [4.6, 3.4, 1.4, 0.3],
       [6.2, 2.2, 4.5, 1.5],
       [7.2, 3.6, 6.1, 2.5],
       [5.7, 2.9, 4.2, 1.3],
       [4.8, 3. , 1.4, 0.1],
       [7.1, 3. , 5.9, 2.1],
       [6.9, 3.2, 5.7, 2.3],
       [6.5, 3. , 5.8, 2.2],
       [6.4, 2.8, 5.6, 2.1],
       [5.1, 3.8, 1.6, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [6.5, 3.2, 5.1, 2. ],
       [6.7, 3.3, 5.7, 2.1],
       [4.5, 2.3, 1.3, 0.3],
       [6.2, 3.4, 5.4, 2.3],
       [4.9, 3. , 1.4, 0.2],
       [5.7, 2.5, 5. , 2. ],
       [6.9, 3.1, 5.4, 2.1],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [7.2, 3. , 5.8, 1.6],
       [5.1, 3.5, 1.4, 0.3],
       [4.4, 3. , 1.3, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [5.5, 2.3, 4. , 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [7.6, 3. , 6.6, 2.1],
       [5.1, 3.5, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.2, 3.4, 1.4, 0.2],
       [5.7, 2.8, 4.5, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [5. , 3.2, 1.2, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [6.4, 2.9, 4.3, 1.3],
       [5.4, 3.4, 1.5, 0.4],
       [7.7, 2.6, 6.9, 2.3],
       [4.9, 2.4, 3.3, 1. ],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3.1, 4.4, 1.4],
       [5.2, 4.1, 1.5, 0.1],
       [6. , 3. , 4.8, 1.8],
       [5.8, 4. , 1.2, 0.2],
       [7.7, 2.8, 6.7, 2. ],
       [5.1, 3.8, 1.5, 0.3],
       [4.7, 3.2, 1.6, 0.2],
       [7.4, 2.8, 6.1, 1.9],
       [5. , 3.3, 1.4, 0.2],
       [6.3, 3.4, 5.6, 2.4],
       [5.7, 2.8, 4.1, 1.3],
       [5.8, 2.7, 3.9, 1.2],
       [5.7, 2.6, 3.5, 1. ],
       [6.4, 3.2, 5.3, 2.3],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.7, 3. , 5. , 1.7],
       [5. , 3. , 1.6, 0.2],
       [5.5, 2.4, 3.7, 1. ],
       [6.7, 3.1, 5.6, 2.4],
       [5.8, 2.7, 5.1, 1.9],
       [5.1, 3.4, 1.5, 0.2],
       [6.6, 2.9, 4.6, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.9, 3.2, 4.8, 1.8],
       [6.3, 2.3, 4.4, 1.3],
       [5.5, 3.5, 1.3, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.9, 3.1, 1.5, 0.1],
       [6.3, 2.9, 5.6, 1.8],
       [5.8, 2.7, 4.1, 1. ],
       [7.7, 3.8, 6.7, 2.2],
       [4.6, 3.2, 1.4, 0.2]])
# 查看训练标签集y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,
       1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,
       0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,
       2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,
       2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,
       2, 0])
# 查看测试数据集X_test
array([[5.8, 2.8, 5.1, 2.4],
       [6. , 2.2, 4. , 1. ],
       [5.5, 4.2, 1.4, 0.2],
       [7.3, 2.9, 6.3, 1.8],
       [5. , 3.4, 1.5, 0.2],
       [6.3, 3.3, 6. , 2.5],
       [5. , 3.5, 1.3, 0.3],
       [6.7, 3.1, 4.7, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [6.1, 2.8, 4. , 1.3],
       [6.1, 2.6, 5.6, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.5, 2.8, 4.6, 1.5],
       [6.1, 2.9, 4.7, 1.4],
       [4.9, 3.1, 1.5, 0.1],
       [6. , 2.9, 4.5, 1.5],
       [5.5, 2.6, 4.4, 1.2],
       [4.8, 3. , 1.4, 0.3],
       [5.4, 3.9, 1.3, 0.4],
       [5.6, 2.8, 4.9, 2. ],
       [5.6, 3. , 4.5, 1.5],
       [4.8, 3.4, 1.9, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [6.2, 2.8, 4.8, 1.8],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.8, 1.9, 0.4],
       [6.2, 2.9, 4.3, 1.3],
       [5. , 2.3, 3.3, 1. ],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 3.1, 5.5, 1.8],
       [5.4, 3. , 4.5, 1.5],
       [5.2, 3.5, 1.5, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.2],
       [5.2, 2.7, 3.9, 1.4],
       [5.7, 3.8, 1.7, 0.3],
       [6. , 2.7, 5.1, 1.6]])
# 查看测试标签集y_test
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1])

观察数据¶

下图可以通过四个标签值两两对应的关系,查看其表现。(这里不做深究其原理)

# 将训练数据转换成DataFrameiris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)# 通过scatter_matrix绘制出矩阵图grr = pd.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
C:\Users\Administrator\Anaconda3\lib\site-packages\ipykernel_launcher.py:4: FutureWarning: pandas.scatter_matrix is deprecated, use pandas.plotting.scatter_matrix instead
  after removing the cwd from sys.path.

构建第一个模型:k近邻算法¶

想要训练数据,则需要一个算法模型。这里选择使用k近邻分类算法。

k近邻分类器中k的含义,新数据与训练集中最近的任意k个邻居,也就是说,新数据与k个某标签离得最近,则归类为该标签

scikit_lean 中所有的机器学习模型都在各自的类中实现,k近邻算法实在 neighors 模块的 KNei*orsClassifier 类中实现的,我们需要将这个列实例化为一个对象,然后才能使用这个模型

# 导入KNei*orsClassifier模块from sklearn.nei*ors import KNei*orsClassifier# 实例化对象knn = KNei*orsClassifier(n_nei*ors=1)

n_nei*ors 参数表示k的个数,1一表示按与它相邻最近的那1个进行分类。

想要基于训练集来构建模型,需要调用knn对象的fit方法,输入参数X_train和y_train。

# 训练数据,并返回模型knn.fit(X_train,y_train)
KNei*orsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_nei*ors=1, p=2,
           weights='uniform')

fit方法返回的是knn对象,所以这里得到了一个表示该对象的字符串

第三步:做出预测¶

# 假设这里有一个新的花瓣数据X_new = np.array([[5,2.9,1,0.2]])

需要注意的是,这里的数据一定要是二维的数据才可以

调用 knn 的 predict 方法来进行预测

# 调用 predict 函数进行预测prediction = knn.predict(X_new)# 查看返回的类型prediction
array([0])
iris_dataset['target_names'][prediction]
array(['setosa'], dtype='<U10')

predict 方法会返回一个标签值,通过标签值,则可获得其对应的品种名称

第四步:评估模型¶

调用测试集,对测试数据中的每朵鸢尾花进行预测,并将预测结果与标签(一直的品种)进行对比。我们可以通过计算精度来衡量模型的优劣,精度就是品种预测正确的花所占的比例

y_pred = knn.predict(X_test)y_pred
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])

那么,测试返回的分类集合,与原始的分类是否一致呢?这里需要将 y_pred 与 y_test 进行对比

np.mean(y_pred==y_test)
0.9736842105263158

或者直接调用knn的score方法来计算精度

knn.score(X_test,y_test)
0.9736842105263158

可以看出,测试返回的结果中,与原始分类集合具有97%的相似度。

以上便是机器学习的基本流程。O(∩_∩)

全部回复

0/140

量化课程

    移动端课程