Python利用 SVM 算法實現(xiàn)識別手寫數(shù)字
前言
支持向量機(jī) (Support Vector Machine, SVM) 是一種監(jiān)督學(xué)習(xí)技術(shù),它通過根據(jù)指定的類對訓(xùn)練數(shù)據(jù)進(jìn)行最佳分離,從而在高維空間中構(gòu)建一個或一組超平面。在博文《OpenCV-Python實戰(zhàn)(13)——OpenCV與機(jī)器學(xué)習(xí)的碰撞》中,我們已經(jīng)學(xué)習(xí)了如何在 OpenCV 中實現(xiàn)和訓(xùn)練 SVM 算法,同時通過簡單的示例了解了如何使用 SVM 算法。在本文中,我們將學(xué)習(xí)如何使用 SVM 分類器執(zhí)行手寫數(shù)字識別,同時也將探索不同的參數(shù)對于模型性能的影響,以獲取具有最佳性能的 SVM 分類器。
使用 SVM 進(jìn)行手寫數(shù)字識別
我們已經(jīng)在《利用 KNN 算法識別手寫數(shù)字》中介紹了 MNIST 手寫數(shù)字?jǐn)?shù)據(jù)集,以及如何利用 KNN 算法識別手寫數(shù)字。并通過對數(shù)字圖像進(jìn)行預(yù)處理( desew() 函數(shù))并使用高級描述符( HOG 描述符)作為用于描述每個數(shù)字的特征向量來獲得最佳分類準(zhǔn)確率。因此,對于相同的內(nèi)容不再贅述,接下來將直接使用在《利用 KNN 算法識別手寫數(shù)字》中介紹預(yù)處理和 HOG 特征,利用 SVM 算法對數(shù)字圖像進(jìn)行分類。
首先加載數(shù)據(jù),并將其劃分為訓(xùn)練集和測試集:
# 加載數(shù)據(jù)
(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)
# 預(yù)處理函數(shù)
def deskew(img):
m = cv2.moments(img)
if abs(m['mu02']) < 1e-2:
return img.copy()
skew = m['mu11'] / m['mu02']
M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
return img
# HOG 高級描述符
def get_hog():
hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
print("hog descriptor size: {}".format(hog.getDescriptorSize()))
return hog
# 數(shù)據(jù)打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]
hog = get_hog()
hog_descriptors = []
for img in train_dataset:
hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)
results = defaultdict(list)
# 數(shù)據(jù)劃分
split_values = np.arange(0.1, 1, 0.1)
接下來,初始化 SVM,并進(jìn)行訓(xùn)練:
# 模型初始化函數(shù)
def svm_init(C=12.5, gamma=0.50625):
model = cv2.ml.SVM_create()
model.setGamma(gamma)
model.setC(C)
model.setKernel(cv2.ml.SVM_RBF)
model.setType(cv2.ml.SVM_C_SVC)
model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))
return model
# 模型訓(xùn)練函數(shù)
def svm_train(model, samples, responses):
model.train(samples, cv2.ml.ROW_SAMPLE, responses)
return model
# 模型預(yù)測函數(shù)
def svm_predict(model, samples):
return model.predict(samples)[1].ravel()
# 模型評估函數(shù)
def svm_evaluate(model, samples, labels):
predictions = svm_predict(model, samples)
acc = (labels == predictions).mean()
print('Percentage Accuracy: %.2f %%' % (acc * 100))
return acc *100
# 使用不同訓(xùn)練集、測試集劃分方法進(jìn)行訓(xùn)練和測試
for split_value in split_values:
partition = int(split_value * len(hog_descriptors))
hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
labels_train, labels_test = np.split(train_labels, [partition])
print('Training SVM model ...')
model = svm_init(C=12.5, gamma=0.50625)
svm_train(model, hog_descriptors_train, labels_train)
print('Evaluating model ... ')
acc = svm_evaluate(model, hog_descriptors_test, labels_test)
results['svm'].append(acc)

從上圖所示,使用默認(rèn)參數(shù)的 SVM 模型在使用 70% 的數(shù)字圖像訓(xùn)練算法時準(zhǔn)確率可以達(dá)到 98.60%,接下來我們通過修改 SVM 模型的參數(shù) C 和 γ 來測試模型是否還有提升空間。
參數(shù) C 和 γ 對識別手寫數(shù)字精確度的影響
SVM 模型在使用 RBF 核時,有兩個重要參數(shù)——C 和 γ,上例中我們使用 C=12.5 和 γ=0.50625 作為參數(shù)值,C 和 γ 的設(shè)定依賴于特定的數(shù)據(jù)集。因此,必須使用某種方法進(jìn)行參數(shù)搜索,本例中使用網(wǎng)格搜索合適的參數(shù) C 和 γ。
for C in [1, 10, 100, 1000]:
for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
model = svm_init(C, gamma)
svm_train(model, hog_descriptors_train, labels_train)
acc = svm_evaluate(model, hog_descriptors_test, labels_test)
print(" {}".format("%.2f" % acc))
results[C].append(acc)
最后,可視化結(jié)果:
fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]
for key in results:
ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))
plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show()
程序的運(yùn)行結(jié)果如下所示:

如圖所示,通過使用不同參數(shù),準(zhǔn)確率可以達(dá)到 99.25% 左右。通過比較 KNN 分類器和 SVM 分類器在手寫數(shù)字識別任務(wù)中的表現(xiàn),我們可以得出在手寫數(shù)字識別任務(wù)中 SVM 優(yōu)于 KNN 分類器的結(jié)論。
完整代碼
程序的完整代碼如下所示:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import keras
(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)
def deskew(img):
m = cv2.moments(img)
if abs(m['mu02']) < 1e-2:
return img.copy()
skew = m['mu11'] / m['mu02']
M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
return img
def get_hog():
hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
print("hog descriptor size: {}".format(hog.getDescriptorSize()))
return hog
def svm_init(C=12.5, gamma=0.50625):
model = cv2.ml.SVM_create()
model.setGamma(gamma)
model.setC(C)
model.setKernel(cv2.ml.SVM_RBF)
model.setType(cv2.ml.SVM_C_SVC)
model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))
return model
def svm_train(model, samples, responses):
model.train(samples, cv2.ml.ROW_SAMPLE, responses)
return model
def svm_predict(model, samples):
return model.predict(samples)[1].ravel()
def svm_evaluate(model, samples, labels):
predictions = svm_predict(model, samples)
acc = (labels == predictions).mean()
return acc * 100
# 數(shù)據(jù)打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]
# 使用 HOG 描述符
hog = get_hog()
hog_descriptors = []
for img in train_dataset:
hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)
# 訓(xùn)練數(shù)據(jù)與測試數(shù)據(jù)劃分
partition = int(0.9 * len(hog_descriptors))
hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
labels_train, labels_test = np.split(train_labels, [partition])
print('Training SVM model ...')
results = defaultdict(list)
for C in [1, 10, 100, 1000]:
for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
model = svm_init(C, gamma)
svm_train(model, hog_descriptors_train, labels_train)
acc = svm_evaluate(model, hog_descriptors_test, labels_test)
print(" {}".format("%.2f" % acc))
results[C].append(acc)
fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]
for key in results:
ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))
plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show()
以上就是Python利用 SVM 算法實現(xiàn)識別手寫數(shù)字的詳細(xì)內(nèi)容,更多關(guān)于Python SVM算法識別手寫數(shù)字的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
基于Python實現(xiàn)虛假評論檢測可視化系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了如何基于Python實現(xiàn)一個簡單的虛假評論檢測可視化系統(tǒng),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2023-04-04
Django后臺管理系統(tǒng)的圖文使用教學(xué)
在本篇文章里小編給大家整理的是一篇關(guān)于Django后臺管理系統(tǒng)的圖文使用教學(xué)內(nèi)容,需要的朋友們參考下。2020-01-01
python?tkinter庫的Text記錄點(diǎn)擊路經(jīng)和刪除記錄詳情
這篇文章主要介紹了python?tkinter庫的Text記錄點(diǎn)擊路經(jīng)和刪除記錄詳情,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價值,感興趣的小伙伴可以參考一下2022-06-06
Python+wxPython實現(xiàn)個人鏈接收藏夾
這篇文章主要介紹了如何使用wxPython和XML數(shù)據(jù)源創(chuàng)建一個具有按鈕和Web視圖的應(yīng)用程序窗口,以便輕松管理和訪問各種網(wǎng)頁鏈接,感興趣的可以了解下2023-08-08

