1. spark-xgboost Java包
主要需要xgboost4j-spark-0.90.jar, xgboost4j-0.90.jar, 以及 調用代碼 sparkxgb.zip.
GitHub上面有xgboost java 實現的包,鏈接:xgboost;
但為了省事,我直接使用了知乎文章《xgboost的分佈式版本(pyspark)使用測試》中提供的下載鏈接。
注意:xgboost jar 包的版本號需要與 sparkxgb 內的代碼版本相對應。
2. xgboost多分類
我是使用pyspark 運行,通過 pyspark --jars **
把用到的這兩個jar包引入。
#!/usr/bin/env python
# -*- coding:utf8 -*-
import os
import sys
import time
import pandas as pd
import numpy as np
from pyspark import SparkConf, SparkContext
import pyspark.sql.types as typ
import pyspark.ml.feature as ft
from pyspark.sql.functions import isnan, isnull,col
import pyspark
from pyspark.sql.session import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer,VectorAssembler
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from sparkxgb import XGBoostClassifier
import sklearn.datasets as datasets
import numpy as np
import time
def normalize(x):
    """Min-max scale `x` into [0, 1] using the global min/max of the array.

    Note: scaling is over the whole array, not per feature/column.

    BUG FIX: the original formula divided by ``max - min`` unconditionally,
    producing NaN/inf for a constant input. A constant array now maps to
    all zeros.
    """
    lo = np.min(x)
    hi = np.max(x)
    if hi == lo:
        # Degenerate case: no spread to scale by.
        return np.zeros_like(x, dtype=float)
    return (x - lo) / (hi - lo)
def get_data():
    """Build train/test Spark DataFrames from a synthetic 10-class blob dataset.

    Returns:
        (spark_train, spark_test, y_train, y_test) — the Spark frames carry
        "label"/"features" columns for training and a single "features"
        column for inference; y_train/y_test are the raw numpy labels.

    NOTE(review): relies on the module-level ``spark`` context created in
    ``__main__``.
    """
    # Synthetic multi-class data: 100k samples, 10 clusters, 10 features.
    features, labels = datasets.make_blobs(n_samples=100000, centers=10,
                                           n_features=10, random_state=0)
    # Min-max scale the features into [0, 1].
    scaled = normalize(features)

    # 80/20 sequential split (make_blobs shuffles by default — TODO confirm).
    split = int(len(scaled) * 0.8)
    X_train, X_test = scaled[:split], scaled[split:]
    y_train, y_test = labels[:split], labels[split:]
    y_train = y_train.reshape(-1, 1)

    # Pack label + features into one array, then convert each row into
    # (int label, DenseVector) pairs for the training DataFrame.
    merged = np.concatenate([y_train, X_train], axis=1)
    train_rows = [(int(row[0]), Vectors.dense(row[1:])) for row in merged]
    spark_train = spark.createDataFrame(train_rows, schema=["label", "features"])

    # Test frame carries features only; labels stay on the driver for scoring.
    test_rows = [(Vectors.dense(row),) for row in X_test]
    spark_test = spark.createDataFrame(test_rows, schema=["features"])
    return spark_train, spark_test, y_train, y_test
def train_model(trainDF):
    """Fit an XGBoost multi-class classifier on the given training DataFrame.

    Args:
        trainDF: Spark DataFrame with "label" (int) and "features"
            (DenseVector) columns.

    Returns:
        The fitted pipeline model.
    """
    classifier = XGBoostClassifier(
        featuresCol="features",
        labelCol="label",
        predictionCol="prediction",
        objective='multi:softprob',
        numClass=10,
        missing=0.0
    )
    # Single-stage pipeline; leaves room to prepend feature transformers.
    model = Pipeline(stages=[classifier]).fit(trainDF)
    # Persisting/reloading the fitted model is possible via:
    #   model.write().overwrite().save(path); PipelineModel.load(path)
    return model
def test():
    """Smoke-test the SparkContext by collecting a small parallelized RDD.

    NOTE(review): reads the module-level ``sc`` created in ``__main__``.
    """
    sample = [1, 2, 3, 4, 5]
    rdd = sc.parallelize(sample)
    print("done", rdd.collect())
def cal_acc(pred, true):
    """Return the fraction of predictions matching the labels, rounded to 4 dp.

    Args:
        pred: sequence of predicted values, compared element-wise to `true`.
        true: sequence (list or numpy array) of ground-truth labels; its
            length is the denominator.

    Returns:
        Accuracy in [0, 1]; 0.0 for empty `true` instead of raising
        ZeroDivisionError.

    BUG FIX: the original loop rebound the `pred` parameter inside the loop
    (`pred = row`), shadowing the argument — removed the shadowing.
    """
    # len() (not truthiness) so numpy arrays don't raise on bool coercion.
    if len(true) == 0:
        return 0.0
    correct = sum(1 for i, p in enumerate(pred) if p == true[i])
    return round(correct / len(true), 4)
if __name__ == "__main__":
    # Ship the xgboost4j jars to the executors; the jar versions must match
    # the sparkxgb wrapper code.
    conf = SparkConf().set(
        "spark.jars",
        "/home/xgboost4j-0.90.jar,/home/xgboost4j-spark-0.90.jar")
    # BUG FIX: the original did `SparkContext(conf=conf).getOrCreate()`,
    # instantiating a fresh context and then calling the classmethod on the
    # instance. Call the classmethod directly so an existing context is
    # reused instead of being created twice.
    sc = SparkContext.getOrCreate(conf=conf)
    spark = SQLContext(sc)  # get_data()/test() read these module globals

    trainDf, testDf, y_train, y_test = get_data()
    print('get df')
    model = train_model(trainDf)
    # BUG FIX: collect() yields Row objects; a Row (a tuple subclass) never
    # compares equal to a numeric label, so extract the scalar prediction
    # before scoring.
    rows = model.transform(testDf).select("prediction").collect()
    prediction = [row[0] for row in rows]
    acc = cal_acc(prediction, y_test)
    print("acc:{}".format(acc))
運行結果:acc:0.9992
預測結果:
model.transform(testDf).show()
+--------------------+--------------------+--------------------+----------+
| features| rawPrediction| probability|prediction|
+--------------------+--------------------+--------------------+----------+
|[0.36383649267021...|[0.33353492617607...|[0.06999947130680...| 9.0|
|[0.85080275306445...|[0.33345550298690...|[0.06996602565050...| 2.0|
|[0.54471116142668...|[1.99881935119628...|[0.37008801102638...| 0.0|
|[0.61089833342796...|[0.33345550298690...|[0.06995990127325...| 5.0|
|[0.25437385667790...|[0.33415806293487...|[0.07003305852413...| 6.0|
|[0.47371795998355...|[1.99881935119628...|[0.37008947134017...| 0.0|
|[0.75258857302126...|[0.33345550298690...|[0.07017561793327...| 2.0|
|[0.38430822786126...|[0.33345550298690...|[0.06999430805444...| 9.0|
|[0.84192691973241...|[0.33345550298690...|[0.06999272853136...| 7.0|
|[0.89822104638187...|[0.33345550298690...|[0.06999462842941...| 2.0|
|[0.87335367752325...|[0.33345550298690...|[0.06999401748180...| 2.0|
|[0.34598394310439...|[0.33365276455879...|[0.07000749558210...| 9.0|
|[0.37907532566580...|[0.33345550298690...|[0.06999314576387...| 8.0|
|[0.85996665363900...|[0.33345550298690...|[0.06998810172080...| 7.0|
|[0.52503470825319...|[1.99881935119628...|[0.37008947134017...| 0.0|
|[0.51847376135870...|[0.33345550298690...|[0.06998340785503...| 5.0|
|[0.51366954373353...|[1.98586511611938...|[0.36707320809364...| 0.0|
|[0.38344970186248...|[0.33345550298690...|[0.06998835504055...| 4.0|
|[0.31206934826790...|[0.33353492617607...|[0.06996974349021...| 6.0|
|[0.68235540326326...|[0.33345550298690...|[0.06998881697654...| 1.0|
+--------------------+--------------------+--------------------+----------+
參考: