R語言筆記之樹模型(迴歸樹和決策樹)

迴歸樹過程:

例:有10名學生,他們的身高分佈如下:
R1:
女生(7):156,167,165,163,160,170,160
R2:
男生(3):172,180,176
那麼,落入R1的樣本均值爲163,落入R2的樣本均值爲176,那麼對於新樣本,如果是女生,樹模型預測的身高是163,是男生,則爲176.
那麼如何劃分出區域R1,R2(建造樹模型)?
需要使用自上到下的貪婪算法—–遞歸二元分割,即從根節點逐步向下分隔,每次產生兩個樹枝(二元分割)
R中可以建造迴歸樹的包:ctree,rpart,tree

> library(rpart)
> library(tree)
Error in library(tree) : 不存在叫‘tree’這個名字的程輯包
> install.packages("tree")
trying URL 'https://cran.rstudio.com/bin/windows/contrib/3.4/tree_1.0-37.zip'
Content type 'application/zip' length 122090 bytes (119 KB)
downloaded 119 KB

package ‘tree’ successfully unpacked and MD5 sums checked

The downloaded binary packages are in
    C:\Users\LLJiang\AppData\Local\Temp\RtmpmMgvpx\downloaded_packages
> library(tree)
Warning message:
程輯包‘tree’是用R版本3.4.3 來建造的 
> dat=read.csv("https://raw.githubusercontent.com/happyrabbit/DataScientistR/master/Data/SegData.csv")
> dat=subset(dat,store_exp>0&online_exp>0)
> trainx=dat[,grep("Q",names(dat))]
> trainy=dat$store_exp+dat$online_exp
> set.seed(100)
> rpartTrue=train(trainx,trainy,method="rpart2",tuneLength=10,trControl=trainControl(method="cv"))
> plot(rpartTrue)
> 

這裏寫圖片描述
如上圖,樹的最大深度大於2,RMSE就不再變化了,這裏我們就用深度2來建立樹

> rpartTrue=rpart(trainy~.,data=trainx,maxdepth=2)
> print(rpartTrue)
n= 999 

node), split, n, deviance, yval
      * denotes terminal node

1) root 999 15812720000  3479.113  
  2) Q3< 3.5 799  2373688000  1818.720  
    4) Q5< 1.5 250     3534392   705.193 *
    5) Q5>=1.5 549  1919009000  2325.791 *
  3) Q3>=3.5 200  2436211000 10112.380 *
> 

Q3,Q5被最終預測總花銷
以下爲對rpart生成的樹繪製圖形

> library(partykit)
載入需要的程輯包:grid
載入需要的程輯包:libcoin
載入需要的程輯包:mvtnorm
Warning messages:
1: 程輯包‘partykit’是用R版本3.4.3 來建造的 
2: 程輯包‘libcoin’是用R版本3.4.3 來建造的 
> rpartTrue2=as.party(rpartTrue)
> plot(rpartTrue2)
> 

這裏寫圖片描述

決策樹

其目標是把數據劃分爲更小,同質性更強的組
與迴歸樹不同在於,因變量是分類變量而不是數值。
故預測並不是基於平均而是基於每個類別樣本的頻數。葉節點的預測值就是落入相應區域訓練集樣本中頻數最高的類別。
其分裂準則不是RSS,而是熵或者Gini係數
當自變量是連續型時,確定最佳分裂點的劃分過程很之間,然而當自變量是分類型時,有兩種處理方式
1.不對分類變量進行變換,每個分類型自變量作爲單獨的個體輸入到模型中。
2.分類型自變量先被重新編碼爲二元虛擬變量,這樣講類別信息分解成獨立信息塊。
如果某些類對結果又強預測性,第一種方法更合適。
下面我們用不同方法對服裝消費者性別進行判定(分類模型)

> library(pROC)
Type 'citation("pROC")' for a citation.

載入程輯包:‘pROC’

The following objects are masked from ‘package:stats’:

    cov, smooth, var

Warning message:
程輯包‘pROC’是用R版本3.4.3 來建造的 
> dat=read.csv("https://raw.githubusercontent.com/happyrabbit/DataScientistR/master/Data/SegData.csv")
#將10個問卷調查變量當做自變量

> trainx1=dat[,grep("Q",names(dat))]
#將類別也作爲自變量
#用兩種方法編碼分類變量
#trainx1不對消費者類別進行變換

> trainx1$segment=dat$segment
 #trainx2中消費者類別被轉換成虛擬變量

> dumMod=dummyVars(~.,data=trainx1,levelsOnly=F)
#用原變量名加上因子層級的名稱作爲新的名義變量名
> trainx2=predict(dumMod,trainx1)
#性別作爲因變量
> trainy=dat$gender

不對分類變量進行編碼,cp指複雜度參數,是樹生長的停止準則,cp=0.01意味着相應分裂度量(熵,Gini)每一步分裂都需要比之前提高0.01,在交互校檢結果中不滿0.01提升的部分會被修剪掉

> set.seed(100)
> rpartTune1=caret::train(trainx1,trainy,method="rpart",tuneLength=30,metric="ROC",trControl=trainControl(method="cv",summaryFunction=twoClassSummary,classProbs=TRUE,savePredictions=TRUE))
> rpartTune1
CART 

1000 samples
  11 predictor
   2 classes: 'Female', 'Male' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 901, 899, 900, 900, 901, 900, ... 
Resampling results across tuning parameters:

  cp           ROC        Sens       Spec     
  0.000000000  0.6936668  0.6516883  0.6883838
  0.008350085  0.7026106  0.6118506  0.7354545
  0.016700170  0.6851629  0.5324351  0.8204545
  0.025050255  0.6802976  0.5107468  0.8498485
  0.033400340  0.6802976  0.5107468  0.8498485
  0.041750425  0.6802976  0.5107468  0.8498485
  0.050100510  0.6802976  0.5107468  0.8498485
  0.058450595  0.6802976  0.5107468  0.8498485
  0.066800680  0.6802976  0.5107468  0.8498485
  0.075150765  0.6802976  0.5107468  0.8498485
  0.083500850  0.6802976  0.5107468  0.8498485
  0.091850936  0.6802976  0.5107468  0.8498485
  0.100201021  0.6802976  0.5107468  0.8498485
  0.108551106  0.6802976  0.5107468  0.8498485
  0.116901191  0.6802976  0.5107468  0.8498485
  0.125251276  0.6802976  0.5107468  0.8498485
  0.133601361  0.6802976  0.5107468  0.8498485
  0.141951446  0.6802976  0.5107468  0.8498485
  0.150301531  0.6802976  0.5107468  0.8498485
  0.158651616  0.6802976  0.5107468  0.8498485
  0.167001701  0.6802976  0.5107468  0.8498485
  0.175351786  0.6802976  0.5107468  0.8498485
  0.183701871  0.6802976  0.5107468  0.8498485
  0.192051956  0.6802976  0.5107468  0.8498485
  0.200402041  0.6802976  0.5107468  0.8498485
  0.208752126  0.6802976  0.5107468  0.8498485
  0.217102211  0.6802976  0.5107468  0.8498485
  0.225452296  0.6802976  0.5107468  0.8498485
  0.233802381  0.6340747  0.5936039  0.6745455
  0.242152466  0.5556313  0.7872727  0.3239899

ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.008350085.
> 

對將分類變量進行編碼後的數據集進行訓練

> rpartTune2=caret::train(trainx2,trainy,method="rpart",tuneLength=30,metric="ROC",trControl=trainControl(method="cv",summaryFunction=twoClassSummary,classProbs=TRUE,savePredictions=TRUE))
> rpartRoc=roc(response=rpartTune1$pred$obs,predictor=rpartTune1$pred$Female,levels=rev(levels(rpartTune$pred$obs)))
Error in levels(rpartTune$pred$obs) : object 'rpartTune' not found
> rpartRoc=roc(response=rpartTune1$pred$obs,predictor=rpartTune1$pred$Female,levels=rev(levels(rpartTune1$pred$obs)))
> rpartFactorRoc=roc(response=rpartTune2$pred$obs,predictor=rpartTune2$pred$Female,levels=rev(levels(rpartTune1$pred$obs)))
> plot(rpartRoc,type="s",print.thres=c(.5),print.thres.pch=3,print.thres.pattern="",print.thres.cex=1.2,col="red",legacy.axes=TRUE,print.thres.col="red")

這裏寫圖片描述

> plot(rpartFactorRoc,type="s",add=TRUE,print.thres=c(.5),print.thres.pch=16,legacy.axes=TRUE,print.thres.pattern="",print.thres.cex=1.2)

這裏寫圖片描述

> legend(.75,.2,c("Grouped Categories","Independent Categories"),lwd=c(1,1),col=c("black","red"),pch=c(16,3))
> 

這裏寫圖片描述
由上圖可以看出,對於使用CART構建的樹,對消費者類別變量編碼或者不編碼並不影響對受訪者性別做預測
下面我們通過partykit包對最終模型繪製圖形。

> library(partykit)
> plot(as.party(rpartTune2$finalModel))
> 

這裏寫圖片描述

單棵樹很直觀,容易解釋,但它有兩個缺點:
1.和很多回歸模型相比精確度差
2.非常不穩定,數據微小的變化會導致模型結果很大的變化。

發佈了53 篇原創文章 · 獲贊 27 · 訪問量 8萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章