決策樹算法 MATLAB 簡單實現

決策樹算法

前言

最近在數據挖掘與機器學習的課程上剛剛學到了決策樹算法,於是,想自己用 MATLAB 簡單實現一下。雖然拿其中最簡單算法的進行實現,但是,從構思–編寫–初步完成,也花費了不少時間,畢竟只有動手編寫,才能真正體會到算法的內涵。

1 算法流程

通過閱讀機器學習的書籍首先了解決策樹算法的基本思想:通過遞歸的方式構建一棵樹,子樹是通過選取某一屬性,按照其屬性值進行劃分產生的。其算法僞代碼如下:

在這裏插入圖片描述

2 程序設計

程序設計必須對算法的每個細節都要搞清楚,有時可能要實現一個健全完善的算法很困難,我們可以對算法進行簡化,忽略複雜的情況,比如,在上面的構建決策樹算法的步驟中,子樹的劃分可能有多個輸出,連續屬性和無序離散屬性的劃分的方法也有所不同,如果都要將這些考慮進去程序的設計難度會很大。作爲初學者,可以對問題進行簡化:

  • 假設無序離散屬性都只是二元屬性,屬性值用0或1表示
  • 類別只有兩類,用0或1表示
  • 每個節點只有兩個輸出

在明確了細節之後,還需考慮另外一個問題:數據結構。在程序中用什麼數據結構來描述所構建的“樹”?這一步很關鍵,因爲在對訓練集之外的記錄進行測試的時候要用到該數據結構。

由於自己實現決策樹算法的目的只是加深對算法的理解,並不是實際開發,因此,只是將“樹”的結構和參數打印出來。

function build_tree(x, y, L, level, parent_y, sig, p_value)
% 自編的用於構建決策樹的簡單程序,適用於屬性爲二元屬性,二分類情況。(也可對程序進行修改以適用連續屬性)。
% 輸入:
% x:數值矩陣,樣本屬性記錄(每一行爲一個樣本)
% y:數值向量,樣本對應的labels
% 其它參數調用時可以忽略,在遞歸時起作用。
% 輸出:打印決策樹。
    if nargin == 2
       level = 0; 
       parent_y = -1;
       L = 1:size(x, 2);
       sig = -1;
       p_value = [];
%        bin_f = zeros(size(x, 2), 1);
%        for k=1:size(x, 2)
%            if length(unique(x(:,k))) == 2
%               bin_f(k) = 1; 
%            end
%        end
    end
    class = [0, 1];
    [r, label] = is_leaf(x, y, parent_y); % 判斷是否是葉子節點
    if r   
        if sig ==-1
            disp([repmat('     ', 1, level), 'leaf (', num2str(label), ')']);
        elseif sig ==0
            disp([repmat('     ', 1, level), '<', num2str(p_value),' leaf (', num2str(label), ')']);
        else
            disp([repmat('     ', 1, level), '>', num2str(p_value),' leaf (', num2str(label), ')']);
        end
    else
        [ind, value, i_] = find_best_test(x, y, L); % 找出最佳的測試值
%         
%         if ind ==1
%            keyboard; 
%         end
        
        [x1, y1, x2, y2] = split_(x, y, i_, value); % 實施劃分
        if sig ==-1
            disp([repmat('     ', 1, level), 'node (', num2str(ind), ', ', num2str(value), ')']);
        elseif sig ==0
            disp([repmat('     ', 1, level), '<', num2str(p_value),' node (', num2str(ind), ', ', num2str(value), ')']);
        else
            disp([repmat('     ', 1, level), '>', num2str(p_value),' node (', num2str(ind), ', ', num2str(value), ')']);
        end
%         if bin_f(i_) == 1
            x1(:,i_) = []; 
            x2(:,i_) = [];
            L(:,i_) = [];
%             bin_f(i_) = [];
%         end
        build_tree(x1, y1, L, level+1, y, 0, value); % 地櫃調用
        build_tree(x2, y2, L, level+1, y, 1, value);
    end

    function [ind, value, i_] = find_best_test(xx, yy, LL) % 子函數:找出最佳測試值(可以對連續屬性適用)
        imp_min = inf;
        i_ = 1;
        ind = LL(i_);
        for i=1:size(xx,2);
            if length(unique(xx(:,i))) ==1
                continue;
            end
%            [xx_sorted, ii] = sortrows(xx, i); 
%            yy_sorted = yy(ii, :);
           vv = unique(xx(:,i));
           imp_min_i = inf;
           best_point = mean([vv(1), vv(2)]);
           value = best_point;
           for j = 1:length(vv)-1
               point = mean([vv(j), vv(j+1)]);               
               [xx1, yy1, xx2, yy2] = split_(xx, yy, i, point);
               imp = calc_imp(yy1, yy2);
               if imp<imp_min_i
                   best_point = point;
                   imp_min_i = imp;
               end
           end
           if imp_min_i < imp_min
              value = best_point;
              imp_min = imp_min_i;
              i_ = i;
              ind = LL(i_);
           end
        end
    end
    
    function imp = calc_imp(y1, y2) % 子函數:計算熵
        p11 = sum(y1==class(1))/length(y1);
        p12 = sum(y1==class(2))/length(y1);
        p21 = sum(y2==class(1))/length(y2);
        p22 = sum(y2==class(2))/length(y2);
        if p11==0
            t11 = 0;
        else
           t11 = p11*log2(p11); 
        end
        if p12==0
            t12 = 0;
        else
           t12 = p12*log2(p12); 
        end
        if p21==0
            t21 = 0;
        else
           t21 = p21*log2(p21); 
        end
        if p22==0
            t22 = 0;
        else
           t22 = p22*log2(p22); 
        end
        
        imp = -t11-t12-t21-t22;
    end

    function [x1, y1, x2, y2] = split_(x, y, i, point) % 子函數:實施劃分
       index = (x(:,i)<point);
       x1 = x(index,:);
       y1 = y(index,:);
       x2 = x(~index,:);
       y2 = y(~index,:);
    end
    
    function [r, label] = is_leaf(xx, yy, parent_yy) % 子函數:判斷是否是葉子節點
        if isempty(xx)
            r = true;
            label = mode(parent_yy);
        elseif length(unique(yy)) == 1
            r = true;
            label = unique(yy);
        else
            t = xx - repmat(xx(1,:),size(xx, 1), 1);
            if all(all(t ==0))
                r = true;
                label = mode(yy);
            else
                r = false;
                label = [];
            end
        end
    end
end

利用MATLAB提供的數據集進行測試,並與 MATLAB 自身提供的決策樹分類的函數進行對比。

clc
clear all
load ionosphere % contains X and Y variables
x = X(:,1:3);
ind = x(:,3)>0;
x(ind,3) = 1;
x(~ind,3) = 0;

y = zeros(size(Y));
y(ismember(Y, 'b')) = 1;

ctree = fitctree(x, y);
view(ctree,'mode','graph') % graphic description
% [label, score] = predict(ctree, X(5,:))

build_tree(x, y);

自編程序運行結果

含義說明:

node(屬性序號, 劃分點)

leaf(類別)

在這裏插入圖片描述

MATLAB提供的函數的運行結果

在這裏插入圖片描述

結果與MATLAB中自己實現的函數運行結果相同。

3 MATLAB 中的調用

自己對算法的實現的目的主要還是用於加深對算法的理解,但是在實際應用時,還得藉助成熟的機器學習工具包,比如MATLAB或Python提供的機器學習工具包。下面介紹一下MATLAB中決策樹算法的相關函數的調用方法。

tree = fitctree(x,y) 
tree = fitctree(x,y,Name,Value)

根據給定的記錄的屬性x,對應類別y,構造決策樹(二叉樹)。要求x爲數值矩陣,y爲數值向量或cell數組。name-value pair 爲可選參數,用於指定算法的參數(劃分準則,葉子節點最少記錄值等)。x, y 每一行爲一個樣本。

返回tree爲決策樹的數據結構。

利用tree進行分類:

label = predict(tree, x)

4 Python 中的調用

scikit-learn 庫提供了決策樹分類和迴歸的方法.

訓練

>>> from sklearn import tree
>>> X = [[0, 0], [1, 1]]
>>> Y = [0, 1]
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, Y)

分類

>>> clf.predict([[2., 2.]])
array([1])

DecisionTreeClassifier能夠進行二元分類(標籤爲[- 1,1])和多類分類(標籤爲[0,…,K-1])。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章