【scikit-learn基礎】--模型持久化

模型持久化(模型保存與加載)是機器學習完成的最後一步。
因爲,在實際情況中,訓練一個模型可能會非常耗時,如果每次需要使用模型時都要重新訓練,這無疑會浪費大量的計算資源和時間。

通過將訓練好的模型持久化到磁盤,我們可以在需要使用模型時直接從磁盤加載到內存,而無需重新訓練。這樣不僅可以節省時間,還可以提高模型的使用效率。

本篇介紹scikit-learn中幾種常用的模型持久化方法。

1. 訓練模型

首先,訓練一個模型,這裏用scikit-learn自帶的手寫數字數據集作爲樣本。

import matplotlib.pyplot as plt
from sklearn import datasets

# 加載手寫數據集
data = datasets.load_digits()

# 調整數據格式
n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target

# 用支持向量機訓練模型
from sklearn.svm import SVC

# 定義
reg = SVC()

# 訓練模型
reg.fit(X, y)

最後的得到的 reg 就是我們訓練之後的模型,使用這個模型,就可以預測一些手寫數字圖片。

但是這個 reg 是代碼中的一個變量,如果不能保存下來,那麼,每次需要使用的時候,
還要重新執行一次上面的模型訓練代碼,樣本數據量大的話,每次重複訓練會浪費大量時間和計算資源。

所以,要將上面的 reg 模型保存下來,下次使用的時候,直接加載,不用重新訓練。

2. 模型持久化

2.1. pickle 序列化

pickle格式是python中常用的序列化方式,它通過將python對象及其所擁有的層次結構轉化爲一個字節流來實現序列化。

將上面的模型保存到磁盤文件model.pkl中。

import pickle

with open("./model.pkl", "wb") as f:
    pickle.dump(reg, f)

需要使用模型時,從磁盤加載的方式:

with open("./model.pkl", "rb") as f:
    reg_pkl = pickle.load(f)

驗證加載之後的模型reg_pkl是否可以正常使用。

y_pred = reg_pkl.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png
從混淆矩陣來看,模型可以正常加載和使用。
關於混淆矩陣具體內容,可以參考:【scikit-learn基礎】--『分類模型評估』之評估報告

2.2. joblib 序列化

相比於pickle,保存機器學習模型時,更推薦使用joblib
因爲joblib針對大數據進行了優化,使其在處理大型數據集時性能更佳。

序列化的方式也很簡單:

import joblib

joblib.dump(reg, "model.jlib")

從磁盤加載模型並驗證:

reg_jlib = joblib.load("model.jlib")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png

2.3. skops 格式

skops是比較新的一種格式,它是專門爲了共享基於 scikit-learn 的模型而開發的。
目前還在積極的開發中,github上的地址是:github-skops

相比於picklejoblib,它提供了更加安全的序列化格式,
但使用上和它們差別不大。

import skops.io as sio

# 保存到文件 model.sio
sio.dump(reg, "model.sio")

從文件中讀取模型並驗證:

reg_sio = sio.load("model.sio")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png

3. 總結

scikit-learn中,模型持久化是一個重要且實用的技術,它允許我們將訓練好的模型保存到磁盤上,以便在不同的時間點或不同的環境中重新加載和使用。
通過模型持久化,我們能夠避免每次需要使用時重新訓練模型,從而節省大量的時間和計算資源。

本篇介紹的三種方法可以方便的序列化和反序列化模型對象,使其可以輕鬆地保存到磁盤上,並能夠在需要時恢復出原始模型對象。

總而言之,模型持久化不僅使得我們能夠在不同的運行會話之間重用模型,還方便了模型的共享和部署。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章