總地址:GitHub: machine-learning-python
源地址:分類法/範例一: Recognizing hand-written digits
1.代碼
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2017-10-20 15:19:37
# @Author : VeeL ([email protected])
# @Link : http://blog.csdn.net/ml_1019
# @Version : $Id$
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_predict
from sklearn import datasets, svm, metrics
# Load the digits dataset and pair each image with its target label.
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images, digits.target))

# Data exploration (disabled): preview the last four samples with labels.
# for index, (image, label) in enumerate(images_and_labels[-4:]):
#     plt.subplot(2, 4, index + 1)
#     plt.axis('off')
#     plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
#     plt.title('Training: %i' % label)
# plt.show()

n_samples = digits.images.shape[0]
# NOTE(review): this prints the sample count plus one, not the count itself
# -- confirm that is intended.
print(n_samples + 1)

# Flatten every image into a single feature row: (1797, 8, 8) -> (1797, 64).
data = digits.images.reshape(n_samples, -1)

# Create the SVC classifier.
classifier = svm.SVC(gamma=0.001)

# Fit on the first half of the samples.
# NOTE(review): cross_val_predict below presumably refits its own copy of the
# estimator for every fold, so this fit likely has no effect on `predicted`
# -- confirm against the scikit-learn documentation.
mid = (n_samples + 1) // 2
classifier.fit(data[:mid], digits.target[:mid])

# The ground truth spans the whole dataset because the predictions below are
# produced for every sample via cross-validation.
# Hold-out alternative (disabled), scoring only the second half of the data:
# expected = digits.target[mid:]
# predicted = classifier.predict(data[mid:])
expected = digits.target
predicted = cross_val_predict(classifier, data, digits.target, cv=10)  # 10-fold

print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as a colour-coded image on the current figure.

    Args:
        cm: Confusion matrix to display; the y axis is labelled as the true
            label and the x axis as the predicted label.
        title: Title placed above the plot.
        cmap: Matplotlib colormap (object or registered name) used to colour
            the cells.

    Tick labels are taken from the module-level ``digits.target_names``.
    """
    import numpy as np

    # imshow reference:
    # http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow
    # Honour the caller-supplied colormap; it was previously ignored in
    # favour of a hard-coded 'jet'.
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # One tick per class, labelled with the class names.
    tick_marks = np.arange(len(digits.target_names))
    plt.xticks(tick_marks, digits.target_names, rotation=45)
    plt.yticks(tick_marks, digits.target_names)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Tighten the layout last so the axis labels are part of the fit.
    plt.tight_layout()
# Open a fresh figure, render the confusion matrix of the cross-validated
# predictions on it, then display the result.
plt.figure()
confusion = metrics.confusion_matrix(expected, predicted)
plot_confusion_matrix(confusion)
plt.show()