Matlab TreeBagger隨機森林迴歸實例

簡介

這裏是一個在Matlab使用隨機森林(TreeBagger)的例子。隨機森林迴歸是一種機器學習和數據分析領域常用且有效的算法。本文介紹在Matlab平臺如何使用自帶函數和測試數據實現迴歸森林,對於隨機森林和決策樹的相關理論原理將不做太深入的描述。

算法流程

(1)加載Matlab測試數據集;
(2)獲取計算機性能,以便最好地利用其性能;
(3)訓練TreeBagger(隨機森林);
(4)創建散點圖;
(5)估計輸入變量的相對重要性;
(6)檢查需要多少棵樹。

TreeBagger介紹

TreeBagger集成了一組決策樹,用於分類或迴歸。集成中的每棵樹都生長在獨立繪製的輸入數據的引導程序副本上。該副本中未包含的觀察結果對於該樹而言是“無用之物”。

TreeBagger將決策樹用於分類或迴歸。TreeBagger依靠ClassificationTree和 RegressionTree功能來生長單個樹。ClassificationTree和RegressionTree接受爲每個決策拆分隨機選擇的特徵數作爲可選輸入參數。也就是說, TreeBagger實現了隨機森林算法。
對於迴歸問題,TreeBagger支持均值和分位數迴歸(即分位數迴歸森林)。

默認情況下,TreeBagger爲分類樹。要使用迴歸樹,請指定 ‘Method’,‘regression’。

語法

Mdl = TreeBagger(NumTrees,Tbl,ResponseVarName)
Mdl = TreeBagger(NumTrees,Tbl,formula)
Mdl = TreeBagger(NumTrees,Tbl,Y)
B = TreeBagger(NumTrees,X,Y)
B = TreeBagger(NumTrees,X,Y,Name,Value)

描述

Y是響應數據的數組,對於分類問題, Y是一組類標籤。標籤可以是數字或邏輯向量等。對於迴歸問題,Y是一個數值向量。要增長迴歸樹,必須指定名稱-值對 ‘Method’,‘regression’。

若要預測均值響應或估計給定數據的均方誤差,請分別傳遞TreeBagger模型和數據分析。要對袋外觀測數據執行類似的操作,請使用oobPredict或oobError。

要估計給定數據的響應分佈的分位數或分位數誤差,請將TreeBagger模型和數據分別傳遞給quantilePredict或quantileError。要對袋外觀察執行類似的操作,請使用oobQuantilePredict或oobError。

測試數據集下載

https://download.csdn.net/download/wokaowokaowokao12345/12243422

例子1

例子2

%--------------------------------------------------------------------------
clear;clc;close all

%--------------------------------------------------------------------------
% 加載Matlab提供的測試數據,備註:house_dataset數據集新版Matlab沒有
% 如果需要這個數據集可以在我csdn資源中下載:https://download.csdn.net/download/wokaowokaowokao12345/12243422

% load house_dataset
% In = houseInputs';
% Out = houseTargets';
% house_dataset.csv數據可以在我csdn資源中心下載
load house_dataset.csv
In = house_dataset(:,2:end);
Out = house_dataset(:,1);

%下面測試數據可以直接在2018版本Matlab中加載
% load imports-85;
% Out = X(:,1);
% In = X(:,2:end);

%--------------------------------------------------------------------------
% Find capabilities of computer so we can best utilize them.
% 獲取計算機性能,這部分內容可以註釋掉
% Find if gpu is present
ngpus=gpuDeviceCount;
disp([num2str(ngpus) ' GPUs found'])
if ngpus>0
    lgpu=1;
    disp('GPU found')
    useGPU='yes';
else
    lgpu=0;
    disp('No GPU found')
    useGPU='no';
end

% Find number of cores
ncores=feature('numCores');
disp([num2str(ncores) ' cores found'])

% Find number of cpus
import java.lang.*;
r=Runtime.getRuntime;
ncpus=r.availableProcessors;
disp([num2str(ncpus) ' cpus found'])

if ncpus>1
    useParallel='yes';
else
    useParallel='no';
end

[archstr,maxsize,endian]=computer;
disp([...
    'This is a ' archstr ...
    ' computer that can have up to ' num2str(maxsize) ...
    ' elements in a matlab array and uses ' endian ...
    ' byte ordering.'...
    ])

% Set up the size of the parallel pool if necessary
npool=ncores;

% Opening parallel pool
if ncpus>1
    tic
    disp('Opening parallel pool')
    
    % first check if there is a current pool
    poolobj=gcp('nocreate');
    
    % If there is no pool create one
    if isempty(poolobj)
        command=['parpool(' num2str(npool) ');'];
        disp(command);
        eval(command);
    else
        poolsize=poolobj.NumWorkers;
        disp(['A pool of ' poolsize ' workers already exists.'])
    end
    
    % Set parallel options
    paroptions = statset('UseParallel',true);
    toc
    
end

%--------------------------------------------------------------------------
%訓練隨機森林,TreeBagger使用內容,以及設置隨機森林參數
tic
leaf=5;
ntrees=200;
fboot=1;
surrogate='on';
disp('Training the tree bagger')
b = TreeBagger(...
        ntrees,...
        In,Out,... 
        'Method','regression',...
        'oobvarimp','on',...
        'surrogate',surrogate,...
        'minleaf',leaf,...
        'FBoot',fboot,...
        'Options',paroptions...
    );
toc

%--------------------------------------------------------------------------
% Estimate Output using tree bagger
%使用訓練好的模型進行預測
disp('Estimate Output using tree bagger')
x=Out;
y=predict(b, In);
name='Bagged Decision Trees Model';
toc

%--------------------------------------------------------------------------
% calculate the training data correlation coefficient
%計算相關係數
cct=corrcoef(x,y);
cct=cct(2,1);

%--------------------------------------------------------------------------
% Create a scatter Diagram
disp('Create a scatter Diagram')

% plot the 1:1 line
plot(x,x,'LineWidth',3);

hold on
scatter(x,y,'filled');
hold off
grid on

set(gca,'FontSize',18)
xlabel('Actual','FontSize',25)
ylabel('Estimated','FontSize',25)
title(['Training Dataset, R^2=' num2str(cct^2,2)],'FontSize',30)

drawnow

fn='ScatterDiagram';
fnpng=[fn,'.png'];
print('-dpng',fnpng);

%--------------------------------------------------------------------------
% Calculate the relative importance of the input variables
tic
disp('Sorting importance into descending order')
weights=b.OOBPermutedVarDeltaError;
[B,iranked] = sort(weights,'descend');
toc

%--------------------------------------------------------------------------
disp(['Plotting a horizontal bar graph of sorted labeled weights.']) 

%--------------------------------------------------------------------------
figure
barh(weights(iranked),'g');
xlabel('Variable Importance','FontSize',30,'Interpreter','latex');
ylabel('Variable Rank','FontSize',30,'Interpreter','latex');
title(...
    ['Relative Importance of Inputs in estimating Redshift'],...
    'FontSize',17,'Interpreter','latex'...
    );
hold on
barh(weights(iranked(1:10)),'y');
barh(weights(iranked(1:5)),'r');

%--------------------------------------------------------------------------
grid on 
xt = get(gca,'XTick');    
xt_spacing=unique(diff(xt));
xt_spacing=xt_spacing(1);    
yt = get(gca,'YTick');    
ylim([0.25 length(weights)+0.75]);
xl=xlim;
xlim([0 2.5*max(weights)]);

%--------------------------------------------------------------------------
% Add text labels to each bar
for ii=1:length(weights)
    text(...
        max([0 weights(iranked(ii))+0.02*max(weights)]),ii,...
        ['Column ' num2str(iranked(ii))],'Interpreter','latex','FontSize',11);
end

%--------------------------------------------------------------------------
set(gca,'FontSize',16)
set(gca,'XTick',0:2*xt_spacing:1.1*max(xl));
set(gca,'YTick',yt);
set(gca,'TickDir','out');
set(gca, 'ydir', 'reverse' )
set(gca,'LineWidth',2);   
drawnow

%--------------------------------------------------------------------------
fn='RelativeImportanceInputs';
fnpng=[fn,'.png'];
print('-dpng',fnpng);

%--------------------------------------------------------------------------
% Ploting how weights change with variable rank
disp('Ploting out of bag error versus the number of grown trees')

figure
plot(b.oobError,'LineWidth',2);
xlabel('Number of Trees','FontSize',30)
ylabel('Out of Bag Error','FontSize',30)
title('Out of Bag Error','FontSize',30)
set(gca,'FontSize',16)
set(gca,'LineWidth',2);   
grid on
drawnow
fn='EroorAsFunctionOfForestSize';
fnpng=[fn,'.png'];
print('-dpng',fnpng);


實驗結果

模型的相關係數
輸入變量的重要性
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章