Parsing PMML-format machine learning models in Java

Background

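In real engineering projects we usually train and iterate on models in Python, because it offers powerful algorithm libraries and very convenient data-processing tools that make quick experiments possible. When a model has to be deployed as a service, however, Java and its frameworks have clear advantages. To combine the strengths of both languages, PMML can act as an intermediate format: the model is exported from Python as a .pmml file, which is then parsed and deployed in Java.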

What is PMML?

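PMML (Predictive Model Markup Language) is an XML-based file format for exporting machine learning models. The file content describes the model's rules; for example, the following is a decision tree trained on the Iris dataset: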

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
	<Header>
		<Application name="JPMML-SkLearn" version="1.5.34"/>
		<Timestamp>2020-03-24T06:07:44Z</Timestamp>
	</Header>
	<MiningBuildTask>
		<Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'))])</Extension>
	</MiningBuildTask>
	<DataDictionary>
		<DataField name="y" optype="categorical" dataType="integer">
			<Value value="0"/>
			<Value value="1"/>
			<Value value="2"/>
		</DataField>
		<DataField name="x2" optype="continuous" dataType="float"/>
		<DataField name="x3" optype="continuous" dataType="float"/>
		<DataField name="x4" optype="continuous" dataType="float"/>
	</DataDictionary>
	<TransformationDictionary/>
	<TreeModel functionName="classification" missingValueStrategy="nullPrediction">
		<MiningSchema>
			<MiningField name="y" usageType="target"/>
			<MiningField name="x3"/>
			<MiningField name="x4"/>
			<MiningField name="x2"/>
		</MiningSchema>
		<Output>
			<OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
			<OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
			<OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
		</Output>
		<LocalTransformations>
			<DerivedField name="double(x3)" optype="continuous" dataType="double">
				<FieldRef field="x3"/>
			</DerivedField>
			<DerivedField name="double(x4)" optype="continuous" dataType="double">
				<FieldRef field="x4"/>
			</DerivedField>
			<DerivedField name="double(x2)" optype="continuous" dataType="double">
				<FieldRef field="x2"/>
			</DerivedField>
		</LocalTransformations>
		<Node>
			<True/>
			<Node score="0" recordCount="50.0">
				<SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.449999988079071"/>
				<ScoreDistribution value="0" recordCount="50.0"/>
				<ScoreDistribution value="1" recordCount="0.0"/>
				<ScoreDistribution value="2" recordCount="0.0"/>
			</Node>
			<Node>
				<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.75"/>
				<Node>
					<SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.950000047683716"/>
					<Node score="1" recordCount="47.0">
						<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.6500000357627869"/>
						<ScoreDistribution value="0" recordCount="0.0"/>
						<ScoreDistribution value="1" recordCount="47.0"/>
						<ScoreDistribution value="2" recordCount="0.0"/>
					</Node>
					<Node score="2" recordCount="1.0">
						<True/>
						<ScoreDistribution value="0" recordCount="0.0"/>
						<ScoreDistribution value="1" recordCount="0.0"/>
						<ScoreDistribution value="2" recordCount="1.0"/>
					</Node>
				</Node>
				<Node score="2" recordCount="3.0">
					<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.550000011920929"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="3.0"/>
				</Node>
				<Node score="1" recordCount="2.0">
					<SimplePredicate field="double(x3)" operator="lessOrEqual" value="5.450000047683716"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="2.0"/>
					<ScoreDistribution value="2" recordCount="0.0"/>
				</Node>
				<Node score="2" recordCount="1.0">
					<True/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="1.0"/>
				</Node>
			</Node>
			<Node>
				<SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.8500001430511475"/>
				<Node score="2" recordCount="2.0">
					<SimplePredicate field="double(x2)" operator="lessOrEqual" value="3.100000023841858"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="2.0"/>
				</Node>
				<Node score="1" recordCount="1.0">
					<True/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="1.0"/>
					<ScoreDistribution value="2" recordCount="0.0"/>
				</Node>
			</Node>
			<Node score="2" recordCount="43.0">
				<True/>
				<ScoreDistribution value="0" recordCount="0.0"/>
				<ScoreDistribution value="1" recordCount="0.0"/>
				<ScoreDistribution value="2" recordCount="43.0"/>
			</Node>
		</Node>
	</TreeModel>
</PMML>
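To make the "rules" in this file concrete, here is a hand-written Java sketch that mirrors the tree above as nested if/else statements. It assumes the usual PMML tree semantics: the child Nodes of a Node are tried in document order, evaluation descends into the first child whose predicate is true, and a `<True/>` child therefore acts as the final "else". The class name `IrisTreeRules` is hypothetical and the thresholds are copied from the XML; in practice the JPMML evaluator does this walk for you.

```java
// Hand-translated version of the TreeModel above (illustration only).
// Hypothetical class; x2, x3 and x4 are the only fields this tree uses.
public class IrisTreeRules {

    // Returns the predicted class (0, 1 or 2).
    // Each if/else level corresponds to one level of <Node> elements,
    // tried in document order; a <True/> node becomes the final else.
    public static int classify(double x2, double x3, double x4) {
        if (x3 <= 2.449999988079071) {
            return 0;                                   // leaf with 50 class-0 samples
        } else if (x4 <= 1.75) {
            if (x3 <= 4.950000047683716) {
                return (x4 <= 1.6500000357627869) ? 1 : 2;
            } else if (x4 <= 1.550000011920929) {
                return 2;
            } else if (x3 <= 5.450000047683716) {
                return 1;
            } else {
                return 2;
            }
        } else if (x3 <= 4.8500001430511475) {
            return (x2 <= 3.100000023841858) ? 2 : 1;
        } else {
            return 2;                                   // leaf with 43 class-2 samples
        }
    }

    public static void main(String[] args) {
        // The two samples used in the Java demo below (x2, x3, x4 only)
        System.out.println(classify(3.5, 1.4, 0.2));    // prints 0
        System.out.println(classify(3.1, 5.1, 2.3));    // prints 2
    }
}
```

Reading the tree this way also shows why the DataDictionary lists no x1: the trained tree never splits on it.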

Generating a PMML file with sklearn

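  • Install sklearn2pmml (it also needs a Java runtime on the PATH, because the conversion itself is done by the Java library JPMML-SkLearn):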
pip install --user --upgrade git+https://github.com/jpmml/sklearn2pmml.git
  • Train a decision tree with sklearn's built-in DecisionTreeClassifier on the Iris dataset and export it as a PMML file:
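from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

iris = load_iris()
clf = tree.DecisionTreeClassifier()
# Wrap the estimator in a PMMLPipeline so that sklearn2pmml can convert it
pipeline = PMMLPipeline([("classifier", clf)])
pipeline.fit(iris.data, iris.target)

# Export to PMML. Because the input is a plain numpy array, the features are
# named x1..x4 and the target y; with_repr=True stores the pipeline's repr
# in the <Extension> element. Adjust the path to where the Java side reads it.
sklearn2pmml(pipeline, "DecisionTreeIris.pmml", with_repr=True)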
  • Create a new Java Maven project and add the following dependencies:
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>

Two input forms for the Java interface

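  • Parse the decision tree model in Java and output the predictions.
  • In addition, the PMML content can be passed in as a string and converted into a byte stream, so the interface also works when the model is given as a string parameter.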
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Created by sanyin
 *  on 2020/03/24.
 */
public class PMMLDemo {
    // Input form 1: load the model from a .pmml file path
    private Evaluator loadPmmlFromFile(String path) {
        try (InputStream is = new FileInputStream(path)) {
            return loadPmml(is);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    // Input form 2: load the model from the PMML content passed as a string,
    // so an interface built on top of it does not need a file path
    private Evaluator loadPmmlFromString(String pmmlText) {
        try (InputStream is = new ByteArrayInputStream(pmmlText.getBytes(StandardCharsets.UTF_8))) {
            return loadPmml(is);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    // Shared step: unmarshal the XML into a PMML object and build an evaluator.
    // Returns null if the document cannot be parsed, instead of evaluating an empty PMML.
    private Evaluator loadPmml(InputStream is) {
        PMML pmml;
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        } catch (SAXException | JAXBException e) {
            e.printStackTrace();
            return null;
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    private int predict(Evaluator evaluator, float a, float b, float c, float d) {
        // Assign the 4 Iris features under the names used in the PMML file.
        // a..d must follow the training column order (x1..x4 = columns 0..3 of iris.data);
        // the map is looked up by name, so its iteration order does not matter.
        Map<String, Float> data = new HashMap<>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        // Only x2, x3 and x4 are input fields of this tree, so x1 is never looked up
        List<InputField> inputFields = evaluator.getInputFields();
        // Build the model input; prepare() converts each raw value to the field's PMML type
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        // This model has a single target field, y
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();

        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
        int primitiveValue = -1;
        // For classification the target value is a Computable; getResult() gives the
        // predicted class label (y is declared with dataType="integer")
        if (targetFieldValue instanceof Computable) {
            Computable computable = (Computable) targetFieldValue;
            primitiveValue = (Integer) computable.getResult();
        }
        return primitiveValue;
    }

    public static void main(String[] args) {
        PMMLDemo demo = new PMMLDemo();
        Evaluator model = demo.loadPmmlFromFile("/Users/hzp/Desktop/DecisionTreeIris.pmml");
        if (model == null) {
            System.err.println("Failed to load the PMML model");
            return;
        }
        System.out.println(demo.predict(model, 5.1f, 3.5f, 1.4f, 0.2f));
        System.out.println(demo.predict(model, 6.9f, 3.1f, 5.1f, 2.3f));
    }
}
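The predict() method looks feature values up by field name, so what matters is that each value is stored under the name sklearn2pmml gave its training column. When training on a plain numpy array, as in the Python script, those default names are x1, x2, ... in column order. A minimal stdlib-only sketch of a hypothetical helper (`FeatureRows.toFeatureMap` is not part of JPMML) that builds that map from a row given in training column order:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: maps a feature row, given in the same column order as
// iris.data in the Python script, to the default field names x1..xN.
public class FeatureRows {

    public static Map<String, Float> toFeatureMap(float[] row) {
        Map<String, Float> data = new LinkedHashMap<>();
        for (int i = 0; i < row.length; i++) {
            data.put("x" + (i + 1), row[i]);   // column 0 -> "x1", column 1 -> "x2", ...
        }
        return data;
    }

    public static void main(String[] args) {
        Map<String, Float> data = toFeatureMap(new float[] {5.1f, 3.5f, 1.4f, 0.2f});
        System.out.println(data);   // prints {x1=5.1, x2=3.5, x3=1.4, x4=0.2}
    }
}
```

Inside predict(), the four `data.put(...)` calls could then become `Map<String, Float> data = FeatureRows.toFeatureMap(new float[] {a, b, c, d});`. Generating the names from the column index keeps the Java side tied to the training column order instead of to hand-typed keys.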

Notes

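  1. Mind the data types used in Python: the Java input types must match them (in the file above, the features are declared as float and the target y as integer).
  2. Mind the feature dimensions of the model trained in Python: the Java input must provide the same features, under the same field names (x1..x4 by default).
  3. The same approach works for other models, as long as sklearn (through sklearn2pmml) can export them to PMML.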
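To check points 1 and 2 before wiring up the evaluator, the DataDictionary of a PMML document can be inspected with the JDK's built-in DOM parser. The sketch also uses the string-to-byte-stream conversion described above (a `ByteArrayInputStream` over the UTF-8 bytes). `PmmlFields.dataFields` is a hypothetical helper, not part of JPMML, and it only reads the `name` and `dataType` attributes of `DataField` elements.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

// Hypothetical helper: lists the fields declared in a PMML DataDictionary.
public class PmmlFields {

    // Returns field name -> dataType, in document order.
    public static Map<String, String> dataFields(String pmmlText) throws Exception {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setNamespaceAware(true);   // PMML elements live in a namespace
        DocumentBuilder builder = factory.newDocumentBuilder();
        // Wrap the string's UTF-8 bytes in a stream, as when passing PMML as a string
        try (InputStream is = new ByteArrayInputStream(pmmlText.getBytes(StandardCharsets.UTF_8))) {
            Document doc = builder.parse(is);
            // "*" matches any namespace, e.g. http://www.dmg.org/PMML-4_3
            NodeList nodes = doc.getElementsByTagNameNS("*", "DataField");
            Map<String, String> fields = new LinkedHashMap<>();
            for (int i = 0; i < nodes.getLength(); i++) {
                Element field = (Element) nodes.item(i);
                fields.put(field.getAttribute("name"), field.getAttribute("dataType"));
            }
            return fields;
        }
    }

    public static void main(String[] args) throws Exception {
        String pmml = "<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" version=\"4.3\">"
                + "<DataDictionary>"
                + "<DataField name=\"y\" optype=\"categorical\" dataType=\"integer\"/>"
                + "<DataField name=\"x3\" optype=\"continuous\" dataType=\"float\"/>"
                + "</DataDictionary></PMML>";
        System.out.println(dataFields(pmml));   // prints {y=integer, x3=float}
    }
}
```

Comparing this map against the keys and value types the Java code puts into its input map catches name or type mismatches early.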

