Matlab利用分位數誤差和貝葉斯優化調整隨機森林

簡介

本示例說明如何使用分位數誤差實現貝葉斯優化以調整迴歸樹的隨機森林的超參數。 如果計劃使用模型來預測條件分位數而不是條件均值,則使用分位數誤差而不是均方誤差來調整模型是合適的。查找關於樹複雜性和要使用貝葉斯優化在每個節點上採樣的預測變量數量,實現最小,受罰的袋外分位數誤差的模型。 將期望的改進加功能指定爲獲取功能。

加載和預處理數據

加載carsmall數據集。 假設一個模型,該模型根據加速度、汽缸數、發動機排量、馬力、製造商、年和重量來預測汽車的平均燃油經濟性。 將Cylinders,Mfg和Model_Year視爲類別變量。

例子

clc
clear all
close all

load carsmall
% 將Cylinders、Mfg和Model_Year視爲類別變量
Cylinders = categorical(Cylinders);
Mfg = categorical(cellstr(Mfg));
Model_Year = categorical(Model_Year);
X = table(Acceleration,Cylinders,Displacement,Horsepower,Mfg,...
    Model_Year,Weight,MPG);
rng('default'); % For reproducibility

% 調整參數,考慮調整:
% 森林中樹木的複雜程度(深度)。 深樹傾向於過度擬合,而淺樹傾向於欠擬合。 因此,指定每片葉子的最少觀察數爲20% 生長樹木時,在每個節點上採樣的預測變量的數量。 指定從1到所有預測變量的採樣。
maxMinLS = 20;
minLS = optimizableVariable('minLS',[1,maxMinLS],'Type','integer');
numPTS = optimizableVariable('numPTS',[1,size(X,2)-1],'Type','integer');
hyperparametersRF = [minLS; numPTS];

% bayesopt是實現貝葉斯優化的函數,要求您將這些規範作爲optimizableVariable對象傳遞。
% hyperparametersRF是OptimizableVariable對象的21數組。
% 還應該考慮調整集合中的樹數。 bayesopt傾向於選擇包含許多樹木的隨機森林,因爲會更準確。 
% 如果考慮到可用的計算資源,並且您希望使用較少的樹,則可以考慮與其他參數分開調整樹的數量,或者對包含許多學習者的模型進行懲罰。
% 結果是一個BayesianOptimization對象,該對象除其他外包含目標函數的最小值和優化的超參數值。
% 顯示觀察到的目標函數最小值和優化的超參數值。
results = bayesopt(@(params)oobErrRF(params,X),hyperparametersRF,...
    'AcquisitionFunctionName','expected-improvement-plus','Verbose',0);

bestOOBErr = results.MinObjective
bestHyperparameters = results.XAtMinObjective
% 使用優化的超參數訓練模型
% 使用整個數據集和優化的超參數值訓練隨機森林。
% Mdl是針對中位數預測優化的TreeBagger對象。 
% 您可以通過將Mdl和新數據傳遞給QuantilePredict,在給定預測器數據的情況下預測平均燃油經濟性。
Mdl = TreeBagger(300,X,'MPG','Method','regression',...
    'MinLeafSize',bestHyperparameters.minLS,...
    'NumPredictorstoSample',bestHyperparameters.numPTS);

% 定義目標函數
% 爲貝葉斯優化算法定義目標函數以進行優化。 該功能應:
% 接受要調諧的參數作爲輸入。
% 使用TreeBagger訓練一個隨機森林。 在TreeBagger調用中,指定要調整的參數,並指定返回袋外索引。
% 根據中位數估算袋外分位數誤差。
% 返回袋外分位數誤差。
% oobErrRF訓練隨機森林並估計袋外分位數誤差oobErr使用X中的預測變量數據和參數中的參數指定來訓練300個迴歸樹的隨機森林,
% 然後根據中位數返回袋外分位數誤差 。 X是一個表,params是一個OptimizableVariable對象的數組,對應於最小葉子大小和要在每個節點上採樣的預測變量數量。
function oobErr = oobErrRF(params,X)
randomForest = TreeBagger(300,X,'MPG','Method','regression',...
    'OOBPrediction','on','MinLeafSize',params.minLS,...
    'NumPredictorstoSample',params.numPTS);
oobErr = oobQuantileError(randomForest);
end

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章