模型持久化(模型保存與加載)是機器學習完成的最後一步。
因爲,在實際情況中,訓練一個模型可能會非常耗時,如果每次需要使用模型時都要重新訓練,這無疑會浪費大量的計算資源和時間。
通過將訓練好的模型持久化到磁盤,我們可以在需要使用模型時直接從磁盤加載到內存,而無需重新訓練。這樣不僅可以節省時間,還可以提高模型的使用效率。
本篇介紹scikit-learn
中幾種常用的模型持久化方法。
1. 訓練模型
首先,訓練一個模型,這裏用scikit-learn
自帶的手寫數字數據集作爲樣本。
import matplotlib.pyplot as plt
from sklearn import datasets
# 加載手寫數據集
data = datasets.load_digits()
# 調整數據格式
n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target
# 用支持向量機訓練模型
from sklearn.svm import SVC
# 定義
reg = SVC()
# 訓練模型
reg.fit(X, y)
最後的得到的 reg
就是我們訓練之後的模型,使用這個模型,就可以預測一些手寫數字圖片。
但是這個 reg
是代碼中的一個變量,如果不能保存下來,那麼,每次需要使用的時候,
還要重新執行一次上面的模型訓練代碼,樣本數據量大的話,每次重複訓練會浪費大量時間和計算資源。
所以,要將上面的 reg
模型保存下來,下次使用的時候,直接加載,不用重新訓練。
2. 模型持久化
2.1. pickle 序列化
pickle
格式是python
中常用的序列化方式,它通過將python對象及其所擁有的層次結構轉化爲一個字節流來實現序列化。
將上面的模型保存到磁盤文件model.pkl
中。
import pickle
with open("./model.pkl", "wb") as f:
pickle.dump(reg, f)
需要使用模型時,從磁盤加載的方式:
with open("./model.pkl", "rb") as f:
reg_pkl = pickle.load(f)
驗證加載之後的模型reg_pkl
是否可以正常使用。
y_pred = reg_pkl.predict(X)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()
plt.show()
從混淆矩陣來看,模型可以正常加載和使用。
關於混淆矩陣具體內容,可以參考:【scikit-learn基礎】--『分類模型評估』之評估報告
2.2. joblib 序列化
相比於pickle
,保存機器學習模型時,更推薦使用joblib
。
因爲joblib
針對大數據進行了優化,使其在處理大型數據集時性能更佳。
序列化的方式也很簡單:
import joblib
joblib.dump(reg, "model.jlib")
從磁盤加載模型並驗證:
reg_jlib = joblib.load("model.jlib")
y_pred = reg_jlib.predict(X)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()
plt.show()
2.3. skops 格式
skops是比較新的一種格式,它是專門爲了共享基於 scikit-learn
的模型而開發的。
目前還在積極的開發中,github上的地址是:github-skops。
相比於pickle
和joblib
,它提供了更加安全的序列化格式,
但使用上和它們差別不大。
import skops.io as sio
# 保存到文件 model.sio
sio.dump(reg, "model.sio")
從文件中讀取模型並驗證:
reg_sio = sio.load("model.sio")
y_pred = reg_jlib.predict(X)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()
plt.show()
3. 總結
在scikit-learn
中,模型持久化是一個重要且實用的技術,它允許我們將訓練好的模型保存到磁盤上,以便在不同的時間點或不同的環境中重新加載和使用。
通過模型持久化,我們能夠避免每次需要使用時重新訓練模型,從而節省大量的時間和計算資源。
本篇介紹的三種方法可以方便的序列化和反序列化模型對象,使其可以輕鬆地保存到磁盤上,並能夠在需要時恢復出原始模型對象。
總而言之,模型持久化不僅使得我們能夠在不同的運行會話之間重用模型,還方便了模型的共享和部署。