mmlspark-101: TrainClassifier

mmlspark 101

預測一個人的收入是否超過$50k
數據下載地址:https://www.kaggle.com/uciml/adult-census-income/data
注意!!!
本文安裝的 mmlspark 版本為 0.17,部分 API 已經發生變化,而官方 GitHub 上的 notebook 所用版本較舊
shell

pyspark --master=spark://Lord:7077 --packages Azure:mmlspark:0.17

在這裏插入圖片描述
啟動時會自動下載 --packages 指定的 mmlspark 套件

from pyspark import SparkConf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField
from pyspark.mllib.evaluation import RankingMetrics, RegressionMetrics
from pyspark.sql.types import StringType, FloatType, IntegerType, LongType

read and clean data

# Create (or reuse) the Spark session, requesting the mmlspark 0.17 package
# through the spark.jars.packages setting.
spark = (
    SparkSession.builder
    .appName("MyApp")
    .config("spark.jars.packages", "Azure:mmlspark:0.17")
    .getOrCreate()
)

# Load the adult-census CSV from HDFS; the first row is the header and the
# column types are inferred from the data.
data = spark.read.csv(
    'hdfs:///user/hadoop/adult.csv',
    header=True,
    inferSchema=True,
)

# Show the first ten rows as a pandas table (notebook display).
data.limit(10).toPandas()
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week native.country income
0 90 ? 77053 HS-grad 9 Widowed ? Not-in-family White Female 0 4356 40 United-States <=50K
1 82 Private 132870 HS-grad 9 Widowed Exec-managerial Not-in-family White Female 0 4356 18 United-States <=50K
2 66 ? 186061 Some-college 10 Widowed ? Unmarried Black Female 0 4356 40 United-States <=50K
3 54 Private 140359 7th-8th 4 Divorced Machine-op-inspct Unmarried White Female 0 3900 40 United-States <=50K
4 41 Private 264663 Some-college 10 Separated Prof-specialty Own-child White Female 0 3900 40 United-States <=50K
5 34 Private 216864 HS-grad 9 Divorced Other-service Unmarried White Female 0 3770 45 United-States <=50K
6 38 Private 150601 10th 6 Separated Adm-clerical Unmarried White Male 0 3770 40 United-States <=50K
7 74 State-gov 88638 Doctorate 16 Never-married Prof-specialty Other-relative White Female 0 3683 20 United-States >50K
8 68 Federal-gov 422013 HS-grad 9 Divorced Prof-specialty Not-in-family White Female 0 3683 40 United-States <=50K
9 41 Private 70037 Some-college 10 Never-married Craft-repair Unmarried White Male 0 3004 60 ? >50K

withColumnRenamed

# Replace the dotted column names with underscore versions
# (presumably so later column references need no special quoting).
column_renames = {
    'education.num': 'education_num',
    'marital.status': 'marital_status',
    'capital.gain': 'capital_gain',
    'capital.loss': 'capital_loss',
    'hours.per.week': 'hours_per_week',
    'native.country': 'native_country',
}
for dotted_name, underscore_name in column_renames.items():
    data = data.withColumnRenamed(dotted_name, underscore_name)

# Print the resulting schema to confirm the new names.
data.printSchema()
root
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: integer (nullable = true)
 |-- education: string (nullable = true)
 |-- education_num: integer (nullable = true)
 |-- marital_status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital_gain: integer (nullable = true)
 |-- capital_loss: integer (nullable = true)
 |-- hours_per_week: integer (nullable = true)
 |-- native_country: string (nullable = true)
 |-- income: string (nullable = true)

EDA

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# add label column
def add_label(income_col):
    """Return an integer label column derived from the income column.

    Rows whose income equals '<=50K' get 0; every other value (including
    null) gets 1.  This is written as a built-in column expression rather
    than a Python UDF so the rows are not shipped to Python workers.

    Args:
        income_col: the Spark Column holding the income string.

    Returns:
        A Spark Column of integer type containing 0 or 1.
    """
    return F.when(income_col == '<=50K', 0).otherwise(1).cast(IntegerType())


data = data.withColumn('label', add_label(data['income']))
# Summary statistics for rows with income above $50K (label == 1),
# restricted to the age, education_num and hours_per_week columns.
(
    data.filter(data.label == 1)
    .describe()
    .toPandas()[['summary', 'age', 'education_num', 'hours_per_week']]
)
summary age education_num hours_per_week
0 count 7841 7841 7841
1 mean 44.24984058155847 11.611656676444332 45.473026399693914
2 stddev 10.519027719851813 2.385128632665079 11.01297093020927
3 min 19 2 1
4 max 90 16 99
# Summary statistics for rows with income at or below $50K (label == 0),
# restricted to the age, education_num and hours_per_week columns.
(
    data.filter(data.label == 0)
    .describe()
    .toPandas()[['summary', 'age', 'education_num', 'hours_per_week']]
)
summary age education_num hours_per_week
0 count 24720 24720 24720
1 mean 36.78373786407767 9.595064724919094 38.840210355987054
2 stddev 14.020088490824895 2.4361467923083993 12.31899464185489
3 min 17 1 1
4 max 90 16 99

age,education_num, hours_per_week分佈情況

# Collect the age column to the driver separately for each income class.
ages_0 = data.filter(data.label == 0).select('age').collect()
ages_1 = data.filter(data.label == 1).select('age').collect()

# Overlay the two age distributions on one figure.
# NOTE(review): sns.distplot is deprecated in newer seaborn releases —
# confirm the installed version still provides it.
plt.figure(figsize=(10, 5))
sns.distplot(ages_0, label='<=$50K')
sns.distplot(ages_1, label='>$50K')
plt.xlabel('age', fontsize=15)
plt.legend()

在這裏插入圖片描述

顯而易見,收入大於 $50K 的人羣,年齡整體高於收入不超過 $50K 的人羣

# Collect the education_num column to the driver separately for each income class.
edus_0 = data.filter(data.label == 0).select('education_num').collect()
edus_1 = data.filter(data.label == 1).select('education_num').collect()

# Overlay the two education_num distributions on one figure.
# NOTE(review): sns.distplot is deprecated in newer seaborn releases —
# confirm the installed version still provides it.
plt.figure(figsize=(10, 5))
sns.distplot(edus_0, label='<=$50K')
sns.distplot(edus_1, label='>$50K')
plt.xlabel('education_num', fontsize=15)
plt.legend()

在這裏插入圖片描述

# Collect the hours_per_week column to the driver separately for each income class.
hours_per_week_0 = data.filter(data.label == 0).select('hours_per_week').collect()
hours_per_week_1 = data.filter(data.label == 1).select('hours_per_week').collect()

# Overlay the two hours_per_week distributions on one figure.
# NOTE(review): sns.distplot is deprecated in newer seaborn releases —
# confirm the installed version still provides it.
plt.figure(figsize=(10, 5))
sns.distplot(hours_per_week_0, label='<=$50K')
sns.distplot(hours_per_week_1, label='>$50K')
plt.xlabel('hours_per_week', fontsize=15)
plt.legend()

在這裏插入圖片描述

Split Data and Train Models

# Keep only the columns used for training; "income" serves as the label
# column in the TrainClassifier calls.
data = data.select([
    "age",
    "education",
    "education_num",
    "marital_status",
    "hours_per_week",
    "income",
])

# 75% / 25% train/test split with a fixed seed.
train, test = data.randomSplit([0.75, 0.25], seed=20200420)
from mmlspark import TrainClassifier

from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, NaiveBayes
# Fit three classifiers on the training split with identical settings:
#   model_lr -> LogisticRegression
#   model_dt -> DecisionTreeClassifier
#   model_nb -> NaiveBayes
# Keyword arguments shared by every TrainClassifier call.
train_classifier_args = {"labelCol": "income", "numFeatures": 256}

model_lr = TrainClassifier(model=LogisticRegression(), **train_classifier_args).fit(train)
model_dt = TrainClassifier(model=DecisionTreeClassifier(), **train_classifier_args).fit(train)
model_nb = TrainClassifier(model=NaiveBayes(), **train_classifier_args).fit(train)

# Persist the logistic-regression model, overwriting any earlier save.
model_lr.write().overwrite().save("../models/LrModel.mml")

evaluate

from mmlspark import ComputeModelStatistics, TrainedClassifierModel
# Reload the saved logistic-regression model and score the test split.
predictionModel = TrainedClassifierModel.load("../models/LrModel.mml")
prediction = predictionModel.transform(test)
# Compute evaluation statistics for the scored rows and show them as a
# pandas table (notebook display).
metrics = ComputeModelStatistics().transform(prediction)
metrics.toPandas()
evaluation_type confusion_matrix accuracy precision recall AUC
0 Classification DenseMatrix([[5766., 473.],\n [ 9... 0.824068 0.674242 0.503083 0.869809
# Decision tree: score the test split and compute its statistics.
dt_prediction = model_dt.transform(test)
dt_metrics = ComputeModelStatistics().transform(dt_prediction)

# Naive Bayes: score the test split and compute its statistics.
nb_prediction = model_nb.transform(test)
nb_metrics = ComputeModelStatistics().transform(nb_prediction)

# Show the decision-tree statistics as a pandas table (notebook display).
dt_metrics.toPandas()
evaluation_type confusion_matrix accuracy precision recall AUC
0 Classification DenseMatrix([[5943., 296.],\n [11... 0.825901 0.734052 0.419836 0.672863
# Show the naive-Bayes statistics as a pandas table (notebook display).
nb_metrics.toPandas()
evaluation_type confusion_matrix accuracy precision recall AUC
0 Classification DenseMatrix([[5824., 415.],\n [10... 0.818448 0.678295 0.44964 0.247985
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章