/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.script; import org.dmg.pmml.DataDictionary; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.dmg.pmml.FieldUsageType; import org.dmg.pmml.MiningField; import org.dmg.pmml.MiningFunctionType; import org.dmg.pmml.MiningSchema; import org.dmg.pmml.NumericPredictor; import org.dmg.pmml.OpType; import org.dmg.pmml.PMML; import org.dmg.pmml.RegressionModel; import org.dmg.pmml.RegressionNormalizationMethodType; import org.dmg.pmml.RegressionTable; import org.jpmml.model.JAXBUtil; import javax.xml.bind.JAXBException; import javax.xml.transform.stream.StreamResult; import java.io.ByteArrayOutputStream; import java.io.UnsupportedEncodingException; import java.nio.charset.Charset; public class PMMLGenerator { public static String generateSVMPMMLModel(double intercept, double[] weights, double[] labels) throws JAXBException, UnsupportedEncodingException { PMML pmml = new PMML(); // create DataDictionary DataDictionary dataDictionary = createDataDictionary(weights); pmml.setDataDictionary(dataDictionary); // create model RegressionModel regressionModel = new RegressionModel(); regressionModel.setModelName("linear SVM"); regressionModel.setFunctionName(MiningFunctionType.CLASSIFICATION); regressionModel.setNormalizationMethod(RegressionNormalizationMethodType.NONE); MiningSchema miningSchema = createMiningSchema(weights); regressionModel.setMiningSchema(miningSchema); RegressionTable regressionTable0 = new RegressionTable(); regressionTable0.setIntercept(intercept); regressionTable0.setTargetCategory(Double.toString(labels[0])); NumericPredictor[] numericPredictors = createNumericPredictors(weights); regressionTable0.addNumericPredictors(numericPredictors); RegressionTable regressionTable1 = new RegressionTable(); regressionTable1.setIntercept(0.0); regressionTable1.setTargetCategory(Double.toString(labels[1])); regressionModel.addRegressionTables(regressionTable0, regressionTable1); pmml.addModels(regressionModel); // marshal return convertPMMLToString(pmml); // write to string } public static String generateLRPMMLModel(double intercept, double[] weights, double[] labels) throws JAXBException, UnsupportedEncodingException { PMML pmml = new PMML(); // create DataDictionary DataDictionary dataDictionary = createDataDictionary(weights); pmml.setDataDictionary(dataDictionary); // create model RegressionModel regressionModel = new RegressionModel(); regressionModel.setModelName("logistic regression"); regressionModel.setFunctionName(MiningFunctionType.CLASSIFICATION); regressionModel.setNormalizationMethod(RegressionNormalizationMethodType.LOGIT); MiningSchema miningSchema = createMiningSchema(weights); regressionModel.setMiningSchema(miningSchema); RegressionTable regressionTable0 = new RegressionTable(); regressionTable0.setIntercept(intercept); regressionTable0.setTargetCategory(Double.toString(labels[0])); NumericPredictor[] numericPredictors = createNumericPredictors(weights); regressionTable0.addNumericPredictors(numericPredictors); RegressionTable regressionTable1 = new RegressionTable(); regressionTable1.setIntercept(0.0); regressionTable1.setTargetCategory(Double.toString(labels[1])); regressionModel.addRegressionTables(regressionTable0, regressionTable1); pmml.addModels(regressionModel); // marshal return convertPMMLToString(pmml); // write to string } public static String convertPMMLToString(PMML pmml) throws JAXBException, UnsupportedEncodingException { ByteArrayOutputStream baor = new ByteArrayOutputStream(); StreamResult streamResult = new StreamResult(); streamResult.setOutputStream(baor); JAXBUtil.marshal(pmml, streamResult); return baor.toString(Charset.defaultCharset().toString()); } public static DataDictionary createDataDictionary(double[] weights) { DataDictionary dataDictionary = new DataDictionary(); DataField[] dataFields = new DataField[weights.length + 1]; for (int i = 0; i < weights.length; i++) { dataFields[i] = createDataField("field_" + Integer.toString(i), DataType.DOUBLE, OpType.CONTINUOUS); } dataFields[weights.length] = createDataField("target", DataType.STRING, OpType.CATEGORICAL); dataDictionary.addDataFields(dataFields); dataDictionary.setNumberOfFields(weights.length + 1); return dataDictionary; } public static NumericPredictor[] createNumericPredictors(double[] weights) { NumericPredictor[] numericPredictors = new NumericPredictor[weights.length]; for (int i = 0; i < weights.length; i++) { numericPredictors[i] = creatNumericPredictor("field_" + Integer.toString(i), weights[i]); } return numericPredictors; } public static MiningSchema createMiningSchema(double[] weights) { MiningSchema miningSchema = new MiningSchema(); MiningField[] miningFields = new MiningField[weights.length + 1]; for (int i = 0; i < weights.length; i++) { miningFields[i] = creatMiningField("field_" + Integer.toString(i), FieldUsageType.ACTIVE); } miningFields[weights.length] = creatMiningField("target", FieldUsageType.TARGET); miningSchema.addMiningFields(miningFields); return miningSchema; } private static NumericPredictor creatNumericPredictor(String name, double coefficient) { NumericPredictor numericPredictor = new NumericPredictor(); numericPredictor.setCoefficient(coefficient); numericPredictor.setName(FieldName.create(name)); return numericPredictor; } private static MiningField creatMiningField(String name, FieldUsageType fieldUsageType) { MiningField field = new MiningField(); field.setName(new FieldName(name)); field.setUsageType(fieldUsageType); return field; } public static DataField createDataField(String name, DataType datatype, OpType opType) { DataField field = new DataField(); field.setName(new FieldName(name)); field.setDataType(datatype); field.setOpType(opType); return field; } }