XGBoost模型保存與讀取(多分類問題)

 在用XGBClassifier做多分類問題模型存取時,採用save_model與load_model函數發現並不是很好用,因此通過pickle進行模型的存取工作,在此記錄,以備後用。

import pickle
from xgboost import XGBClassifier

#train

model_xg = XGBClassifier(
        n_estimators=20,
        learning_rate=0.1,
        max_depth=8,
        subsample=0.8,
        early_stopping_rounds = 50,
        objective='multi:softmax',
        eval_metric = 'mlogloss')
model_xg.fit(x_train, y_train,verbose=True)

# save
pickle.dump(model_xg, open("xgb.pkl", "wb"))

# load
xgb_model_loaded = pickle.load(open("xgb.pkl", "rb"))

# test
xgb_model_loaded.predict(test)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章