
Building a Random Forest with PySpark

  • Decision Tree
  • Random Forests


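  • The basic building block of a random forest (RF) is the decision tree (DT)
  • Decision trees are commonly used for classification and regression tasks
  • Entropy
    • Entropy of the target
    • Entropy of the target given a feature
  • Information Gain (IG)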
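# example dataset
import pandas as pd
toy_data = pd.DataFrame({
    'Age_group': ['old', 'teenager', 'young', 'old', 'young', 'teenager', 'teenager', 'old', 'teenager', 'young', 'young', 'teenager', 'young', 'old'],
    'Smoker': ['yes', 'yes', 'yes', 'no', 'yes', 'no', 'no', 'no', 'no', 'no', 'yes', 'yes', 'no', 'yes'],
    'Medical_condition': ['yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no', 'no', 'yes', 'yes', 'no', 'no', 'no', 'no'],
    'Salary_level': ['high', 'medium', 'medium', 'high', 'high', 'low', 'low', 'low', 'medium', 'low', 'high', 'medium', 'medium', 'medium'],
    'insurance_premium': ['high', 'high', 'low', 'high', 'low', 'high', 'low', 'high', 'high', 'high', 'low', 'low', 'high', 'high']})
print(toy_data)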
    Age_group  Smoker  Medical_condition  Salary_level  insurance_premium
 0        old     yes                yes          high               high
 1   teenager     yes                yes        medium               high
 2      young     yes                yes        medium                low
 3        old      no                yes          high               high
 4      young     yes                yes          high                low
 5   teenager      no                yes           low               high
 6   teenager      no                 no           low                low
 7        old      no                 no           low               high
 8   teenager      no                yes        medium               high
 9      young      no                yes           low               high
10      young     yes                 no          high                low
11   teenager     yes                 no        medium                low
12      young      no                 no        medium               high
13        old     yes                 no        medium               high


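For a target with two classes that occur with probabilities P and Q = 1 - P, the entropy (in bits) is:

$$Entropy = -P\log_2 P - Q\log_2 Q$$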

Computing the entropy of the target

  • target column: insurance_premium
    • high: 9
    • low: 5
  • probability of high: 9/14 = 0.64
  • probability of low: 5/14 = 0.36

$$Entropy_{target} = -P(high)\log_2 P(high) - P(low)\log_2 P(low)$$

$$= -0.64\log_2(0.64) - 0.36\log_2(0.36)$$

$$= 0.94$$
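The following minimal plain-Python sketch checks the arithmetic. The helper `entropy` is our own name, not a library function. It takes raw class counts, so it also works for more than two classes.

```python
from math import log2

def entropy(counts):
    """Entropy in bits of a class distribution given as raw counts."""
    total = sum(counts)
    # Classes with zero count are skipped, since 0 * log2(0) is taken as 0
    return -sum(c / total * log2(c / total) for c in counts if c > 0)

# insurance_premium: 9 rows "high", 5 rows "low"
target_entropy = entropy([9, 5])
print(round(target_entropy, 2))  # 0.94
```

Using the exact fractions 9/14 and 5/14 instead of the rounded 0.64 and 0.36 still gives 0.94.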

Computing the entropy of the target given each feature

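For each feature, split the rows by the feature's values v and compute the target entropy within each group. Then weight each group's entropy by that group's share of the rows:

$$Entropy_{(feature)} = \sum_{v} P_{v} \cdot EntropyTarget_{(feature=v)}$$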

  • feature : smoker
    • yes : high 3, low 4
    • no : high 6, low 1

$$Entropy_{(smoker)} = P_{yes} \cdot EntropyTarget_{(smoker=yes)} + P_{no} \cdot EntropyTarget_{(smoker=no)}$$

$$P_{yes} = \frac{7}{14} = 0.5$$

$$P_{no} = \frac{7}{14} = 0.5$$

$$EntropyTarget_{(smoker=yes)} = -\frac{3}{7}\log_2\left(\frac{3}{7}\right) - \frac{4}{7}\log_2\left(\frac{4}{7}\right)$$

$$EntropyTarget_{(smoker=yes)} = 0.99$$

$$EntropyTarget_{(smoker=no)} = -\frac{6}{7}\log_2\left(\frac{6}{7}\right) - \frac{1}{7}\log_2\left(\frac{1}{7}\right)$$

$$EntropyTarget_{(smoker=no)} = 0.59$$

$$Entropy_{(smoker)} = 0.5 \cdot 0.99 + 0.5 \cdot 0.59 = 0.79$$


Repeating this for every feature gives:

  • Entropy(smoker) = 0.79
  • Entropy(age_group) = 0.69
  • Entropy(medical_condition) = 0.89
  • Entropy(salary_level) = 0.91
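The plain-Python sketch below recomputes all four values directly from the toy table. It stores the columns as lists instead of using the pandas DataFrame, and `feature_entropy` is an illustrative name, not a library function.

```python
from collections import Counter, defaultdict
from math import log2

# Toy data from the table above, one list per column
columns = {
    "age_group": ["old", "teenager", "young", "old", "young", "teenager", "teenager",
                  "old", "teenager", "young", "young", "teenager", "young", "old"],
    "smoker": ["yes", "yes", "yes", "no", "yes", "no", "no",
               "no", "no", "no", "yes", "yes", "no", "yes"],
    "medical_condition": ["yes", "yes", "yes", "yes", "yes", "yes", "no",
                          "no", "yes", "yes", "no", "no", "no", "no"],
    "salary_level": ["high", "medium", "medium", "high", "high", "low", "low",
                     "low", "medium", "low", "high", "medium", "medium", "medium"],
}
premium = ["high", "high", "low", "high", "low", "high", "low",
           "high", "high", "high", "low", "low", "high", "high"]

def entropy(labels):
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())

def feature_entropy(values, labels):
    """Target entropy within each feature value, weighted by that value's share of rows."""
    groups = defaultdict(list)
    for value, label in zip(values, labels):
        groups[value].append(label)
    n = len(labels)
    return sum(len(g) / n * entropy(g) for g in groups.values())

for name, values in columns.items():
    print(f"Entropy({name}) = {feature_entropy(values, premium):.2f}")
```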

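Information Gain (IG)

Information gain measures how much the target entropy drops after splitting on a feature. The feature with the largest gain gives the best split.

$$IG_{(feature)} = Entropy_{target} - Entropy_{(feature)}$$

$$IG_{(smoker)} = 0.94 - 0.79 = 0.15$$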


  • IG(smoker) = 0.15
  • IG(age_group) = 0.25
  • IG(medical_condition) = 0.05
  • IG(salary_level) = 0.03
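The same calculation can be written as a short sketch. It tallies the [high, low] counts for every value of every feature (counted by hand from the toy table here), subtracts each feature's weighted entropy from the target entropy, and picks the largest gain. The helper names are our own.

```python
from math import log2

def entropy(counts):
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)

# [high, low] insurance_premium counts per feature value, tallied from the toy table
feature_counts = {
    "age_group": {"old": [4, 0], "teenager": [3, 2], "young": [2, 3]},
    "smoker": {"yes": [3, 4], "no": [6, 1]},
    "medical_condition": {"yes": [6, 2], "no": [3, 3]},
    "salary_level": {"high": [2, 2], "medium": [4, 2], "low": [3, 1]},
}
target_entropy = entropy([9, 5])  # 9 high, 5 low

def information_gain(groups):
    n = sum(sum(counts) for counts in groups.values())
    weighted = sum(sum(counts) / n * entropy(counts) for counts in groups.values())
    return target_entropy - weighted

gains = {name: information_gain(groups) for name, groups in feature_counts.items()}
for name, gain in gains.items():
    print(f"IG({name}) = {gain:.2f}")
print("best split:", max(gains, key=gains.get))  # best split: age_group
```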


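age_group has the highest information gain, so it becomes the root node. The data is then split into one subset per age group:

  • toy_data[toy_data.Age_group == 'teenager']
  • toy_data[toy_data.Age_group == 'young']
  • toy_data[toy_data.Age_group == 'old']

The same procedure (compute information gain, pick the best feature, split) is then repeated on each subset until a node is pure or no features remain. The old subset is already pure (all four rows have a high premium), so it becomes a leaf.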
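Splitting recursively like this, choosing the feature with the highest information gain at every node, is the classic ID3 algorithm. Below is a minimal pure-Python sketch of it on the toy data. It assumes categorical features and does no pruning. `id3`, `split` and `predict` are illustrative names, and an unseen feature value in `predict` would raise a KeyError.

```python
from collections import Counter
from math import log2

FEATURES = ["age_group", "smoker", "medical_condition", "salary_level"]
# Toy table from above: four feature values, then the insurance_premium label
DATA = [
    ("old", "yes", "yes", "high", "high"),
    ("teenager", "yes", "yes", "medium", "high"),
    ("young", "yes", "yes", "medium", "low"),
    ("old", "no", "yes", "high", "high"),
    ("young", "yes", "yes", "high", "low"),
    ("teenager", "no", "yes", "low", "high"),
    ("teenager", "no", "no", "low", "low"),
    ("old", "no", "no", "low", "high"),
    ("teenager", "no", "yes", "medium", "high"),
    ("young", "no", "yes", "low", "high"),
    ("young", "yes", "no", "high", "low"),
    ("teenager", "yes", "no", "medium", "low"),
    ("young", "no", "no", "medium", "high"),
    ("old", "yes", "no", "medium", "high"),
]

def entropy(rows):
    counts = Counter(row[-1] for row in rows)
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in counts.values())

def split(rows, i):
    """Group rows by the value of feature column i."""
    groups = {}
    for row in rows:
        groups.setdefault(row[i], []).append(row)
    return groups

def information_gain(rows, i):
    n = len(rows)
    remainder = sum(len(g) / n * entropy(g) for g in split(rows, i).values())
    return entropy(rows) - remainder

def id3(rows, features):
    labels = [row[-1] for row in rows]
    # Stop when the node is pure or no features are left: return the majority label
    if len(set(labels)) == 1 or not features:
        return Counter(labels).most_common(1)[0][0]
    best = max(features, key=lambda i: information_gain(rows, i))
    remaining = [i for i in features if i != best]
    return {FEATURES[best]: {value: id3(subset, remaining)
                             for value, subset in split(rows, best).items()}}

def predict(tree, sample):
    # Follow the branches until a leaf (a plain label string) is reached
    while isinstance(tree, dict):
        feature, branches = next(iter(tree.items()))
        tree = branches[sample[feature]]
    return tree

tree = id3(DATA, list(range(len(FEATURES))))
print(tree)
print(predict(tree, {"age_group": "young", "smoker": "no",
                     "medical_condition": "yes", "salary_level": "low"}))  # high
```

On this data the teenager branch splits next on medical_condition and the young branch on smoker. Spark's `DecisionTreeClassifier` and `RandomForestClassifier` work differently. They grow binary trees over binned features and use Gini impurity by default, although `impurity='entropy'` switches to the criterion used here.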


Random Forests



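A random forest trains many decision trees and combines their predictions:

  • Regression: average (or weighted average) of the trees' predictions
  • Classification: majority vote over the trees' predicted classes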
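The sketch below shows the two aggregation rules on made-up tree outputs. `vote` and `average` are hypothetical helper names. Library implementations differ in detail, so this illustrates the idea rather than MLlib's exact code.

```python
from collections import Counter
from statistics import mean

def vote(predictions):
    """Classification: return the label predicted by the most trees."""
    return Counter(predictions).most_common(1)[0][0]

def average(predictions, weights=None):
    """Regression: plain mean, or weighted mean when per-tree weights are given."""
    if weights is None:
        return mean(predictions)
    return sum(w * p for w, p in zip(weights, predictions)) / sum(weights)

# Hypothetical outputs of five trees for a single sample
print(vote(["high", "low", "high", "high", "low"]))  # high
print(average([10.0, 12.0, 11.0, 13.0, 9.0]))        # 11.0
print(average([10.0, 12.0], weights=[3.0, 1.0]))      # 10.5
```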


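Advantages and drawbacks:

  • Feature importance: the model reports feature importances, which can guide feature selection
  • Better performance: usually much better than a single decision tree
  • Less overfitting
  • Higher computational cost: many decision trees have to be trained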


Let’s build a random forest model using Spark’s MLlib.

  • create a SparkSession & load the dataset
  • EDA
  • feature engineering
  • split into train/test sets
  • build & train the models
  • evaluate

SparkSession & load data

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('random_forest').getOrCreate()

# load data
df = spark.read.csv('./Data/affairs.csv', inferSchema=True, header=True)


print((df.count(), len(df.columns)))
(6366, 6)
df.printSchema()
root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
df.show(5, False)
|rate_marriage|age |yrs_married|children|religious|affairs|
|5            |32.0|6.0        |1.0     |3        |0      |
|4            |22.0|2.5        |0.0     |2        |0      |
|3            |32.0|9.0        |3.0     |3        |1      |
|3            |27.0|13.0       |3.0     |1        |1      |
|4            |22.0|2.5        |0.0     |1        |1      |
only showing top 5 rows
df.describe().show()
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|           affairs|
|  count|              6366|              6366|             6366|              6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|0.3224945020420987|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785| 0.467467779921086|
|    min|                 1|              17.5|              0.5|               0.0|                 1|                 0|
|    max|                 5|              42.0|             23.0|               5.5|                 4|                 1|
df.groupBy('affairs').count().show()         # about a third (2053 of 6366, ~32%) had an affair
|affairs|count|
|      1| 2053|
|      0| 4313|
df.groupBy('rate_marriage').count().show()    # most people rate their marriage 4 or 5
|rate_marriage|count|
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
# affairs at each marriage rating
temp_df = df.groupBy('rate_marriage', 'affairs').count().orderBy('rate_marriage', 'affairs', 'count', ascending=True)
temp_df.show()
|rate_marriage|affairs|count|
|            1|      0|   25|
|            1|      1|   74|
|            2|      0|  127|
|            2|      1|  221|
|            3|      0|  446|
|            3|      1|  547|
|            4|      0| 1518|
|            4|      1|  724|
|            5|      0| 2197|
|            5|      1|  487|
# number of people with an affair at each rating
temp_df = temp_df.filter(temp_df.affairs == 1)
temp_df.show()
|rate_marriage|affairs|count|
|            1|      1|   74|
|            2|      1|  221|
|            3|      1|  547|
|            4|      1|  724|
|            5|      1|  487|
# total number of people at each rating
temp_2 = df.groupBy('rate_marriage').count()
temp_2.show()
|rate_marriage|count|
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
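temp_df holds the number of people with an affair at each rating, and temp_2 holds the total number of people at each rating. Dividing one by the other gives the affair rate per rating. In PySpark this would be a join of the two DataFrames on rate_marriage, after renaming one of the two count columns. The sketch below checks the arithmetic in plain Python, using the counts printed above.

```python
# Counts copied from the two outputs above
affairs_by_rating = {1: 74, 2: 221, 3: 547, 4: 724, 5: 487}   # temp_df
total_by_rating = {1: 99, 2: 348, 3: 993, 4: 2242, 5: 2684}   # temp_2

affair_rate = {r: affairs_by_rating[r] / total_by_rating[r] for r in sorted(total_by_rating)}
for rating, rate in affair_rate.items():
    print(f"rate_marriage={rating}: {rate:.1%} had an affair")
```

The rate falls steadily as the rating rises, from about 75% at rating 1 to about 18% at rating 5. This suggests that rate_marriage will be an informative feature.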
# religious
df.groupBy('religious', 'affairs').count().orderBy('religious', 'affairs', 'count', ascending=True).show()
|religious|affairs|count|
|        1|      0|  613|
|        1|      1|  408|
|        2|      0| 1448|
|        2|      1|  819|
|        3|      0| 1715|
|        3|      1|  707|
|        4|      0|  537|
|        4|      1|  119|
# children
df.groupBy('children', 'affairs').count().orderBy('children', 'affairs', 'count', ascending=True).show()
|children|affairs|count|
|     0.0|      0| 1912|
|     0.0|      1|  502|
|     1.0|      0|  747|
|     1.0|      1|  412|
|     2.0|      0|  873|
|     2.0|      1|  608|
|     3.0|      0|  460|
|     3.0|      1|  321|
|     4.0|      0|  197|
|     4.0|      1|  131|
|     5.5|      0|  124|
|     5.5|      1|   79|
# average of every column, with vs. without an affair
df.groupBy('affairs').mean().show()
|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|
|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|
|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|

create feature data

from pyspark.ml.feature import VectorAssembler

df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol='features')

df = df_assembler.transform(df)
df.show(5)
|rate_marriage| age|yrs_married|children|religious|affairs|            features|
|            5|32.0|        6.0|     1.0|        3|      0|[5.0,32.0,6.0,1.0...|
|            4|22.0|        2.5|     0.0|        2|      0|[4.0,22.0,2.5,0.0...|
|            3|32.0|        9.0|     3.0|        3|      1|[3.0,32.0,9.0,3.0...|
|            3|27.0|       13.0|     3.0|        1|      1|[3.0,27.0,13.0,3....|
|            4|22.0|        2.5|     0.0|        1|      1|[4.0,22.0,2.5,0.0...|
only showing top 5 rows
df.select(['features', 'affairs']).show(5)
|            features|affairs|
|[5.0,32.0,6.0,1.0...|      0|
|[4.0,22.0,2.5,0.0...|      0|
|[3.0,32.0,9.0,3.0...|      1|
|[3.0,27.0,13.0,3....|      1|
|[4.0,22.0,2.5,0.0...|      1|
only showing top 5 rows
data = df.select(['features', 'affairs'])

splitting train/test sets

train_df, test_df = data.randomSplit([0.75, 0.25])
print('train set (%d, %d)'%(train_df.count(), len(train_df.columns)))
print('test set (%d, %d)'%(test_df.count(), len(test_df.columns)))
train set (4784, 2)
test set (1582, 2)
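`randomSplit` does not produce an exact 75% / 25% split (here 4784 and 1582 rows), because rows are assigned to the splits at random. The split also changes from run to run unless a seed is passed, e.g. `data.randomSplit([0.75, 0.25], seed=42)`. The stdlib sketch below illustrates per-row random assignment. It is a simplification, not Spark's implementation, and `random_split` is a hypothetical name.

```python
import random

def random_split(rows, weights, seed=None):
    """Send each row to one split at random, with probability proportional to weights."""
    rng = random.Random(seed)
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.75, 1.0] for weights [0.75, 0.25]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # First split whose upper bound exceeds u; fall back to the last one for rounding
        idx = next((i for i, b in enumerate(bounds) if u < b), len(bounds) - 1)
        splits[idx].append(row)
    return splits

train, test = random_split(range(6366), [0.75, 0.25], seed=42)
# Close to, but generally not exactly, 4774.5 / 1591.5 rows
print(len(train), len(test))
```

Because each row is assigned independently, the split sizes fluctuate around the requested proportions.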

build model

  • Random Forest vs. Logistic Regression vs. Decision Tree
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, DecisionTreeClassifier

rf = RandomForestClassifier(labelCol='affairs', numTrees=50).fit(train_df)

lr = LogisticRegression(labelCol='affairs').fit(train_df)

dt = DecisionTreeClassifier(labelCol='affairs').fit(train_df)

rf_pred = rf.transform(test_df)

lr_pred = lr.transform(test_df)

dt_pred = dt.transform(test_df)


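The models are compared on three metrics:

  • Accuracy
  • Precision (weighted across classes)
  • AUC (area under the ROC curve)
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator      # AUC (areaUnderROC is the default metric)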

rf_accuracy = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(rf_pred)
print("RF's accuracy is %f"%rf_accuracy)
lr_accuracy = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(lr_pred)
print("LR's accuracy is %f"%lr_accuracy)
dt_accuracy= MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(dt_pred)
print("DT's accuracy is %f"%dt_accuracy)
RF's accuracy is 0.727560
LR's accuracy is 0.724399
DT's accuracy is 0.719343
rf_precision = MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision').evaluate(rf_pred)
print("RF's precision is %f"%rf_precision)
lr_precision = MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision').evaluate(lr_pred)
print("LR's precision is %f"%lr_precision)
dt_precision= MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision').evaluate(dt_pred)
print("DT's precision is %f"%dt_precision)
RF's precision is 0.709906
LR's precision is 0.706239
DT's precision is 0.707323
rf_auc = BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf_pred)
print("RF's AUC is %f"%rf_auc)
lr_auc = BinaryClassificationEvaluator(labelCol='affairs').evaluate(lr_pred)
print("LR's AUC is %f"%lr_auc)
dt_auc = BinaryClassificationEvaluator(labelCol='affairs').evaluate(dt_pred)
print("DT's AUC is %f"%dt_auc)
RF's AUC is 0.752915
LR's AUC is 0.745961
DT's AUC is 0.609049
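To make the three numbers concrete, the plain-Python sketch below uses six made-up rows. All data and helper names in it are hypothetical.

- `weightedPrecision` averages the per-class precision, weighting each class by its share of the true labels.
- `BinaryClassificationEvaluator` reports areaUnderROC by default, computed from the model's rawPrediction scores.
- The rank-based AUC below is the probability that a random positive is scored above a random negative. This is a standard equivalent of the ROC area, not Spark's code.

For context, predicting "no affair" for every row would already reach about 68% accuracy on the full data (4313 of 6366 rows). The 72 to 73% accuracies above are therefore a modest gain. AUC does not depend on a decision threshold and separates the models more clearly.

```python
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def weighted_precision(y_true, y_pred):
    """Per-class precision, weighted by how often each class occurs in y_true."""
    n = len(y_true)
    total = 0.0
    for c in set(y_true):
        predicted_c = [t for t, p in zip(y_true, y_pred) if p == c]
        precision = predicted_c.count(c) / len(predicted_c) if predicted_c else 0.0
        total += y_true.count(c) / n * precision
    return total

def roc_auc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((p > q) + 0.5 * (p == q) for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical labels and positive-class scores; hard predictions use a 0.5 threshold
y_true = [0, 0, 0, 0, 1, 1]
scores = [0.1, 0.4, 0.2, 0.8, 0.7, 0.9]
y_pred = [int(s >= 0.5) for s in scores]

print(f"accuracy           {accuracy(y_true, y_pred):.4f}")            # 0.8333
print(f"weighted precision {weighted_precision(y_true, y_pred):.4f}")  # 0.8889
print(f"AUC                {roc_auc(y_true, scores):.4f}")             # 0.8750
```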

feature importances

rf.featureImportances
SparseVector(5, {0: 0.5652, 1: 0.0286, 2: 0.2444, 3: 0.0781, 4: 0.0836})
df.schema['features'].metadata['ml_attr']['attrs']
{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}
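`rf.featureImportances` is indexed by position in the assembled features vector, and Spark normalizes the importances to sum to 1. The ml_attr metadata written by VectorAssembler maps each position back to its source column. Pairing the two and sorting them is plain Python. The sketch below copies the printed values so it runs without Spark; in a live session, `rf.featureImportances.toArray()` would give the dense values.

```python
# Values copied from the two outputs above
importances = {0: 0.5652, 1: 0.0286, 2: 0.2444, 3: 0.0781, 4: 0.0836}
attrs = {'numeric': [{'idx': 0, 'name': 'rate_marriage'},
                     {'idx': 1, 'name': 'age'},
                     {'idx': 2, 'name': 'yrs_married'},
                     {'idx': 3, 'name': 'children'},
                     {'idx': 4, 'name': 'religious'}]}

# Map vector positions to column names, then rank by importance
names = {attr['idx']: attr['name'] for attr in attrs['numeric']}
ranked = sorted(((names[i], score) for i, score in importances.items()),
                key=lambda pair: pair[1], reverse=True)
for name, score in ranked:
    print(f"{name:<14}{score:.4f}")
```

rate_marriage dominates, followed by yrs_married. This matches the large gaps in those two group averages during EDA.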
