/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.factories; import org.dmg.pmml.DataDictionary; import org.dmg.pmml.NumericPredictor; import org.dmg.pmml.OpType; import org.dmg.pmml.RegressionModel; import org.dmg.pmml.RegressionTable; import org.dmg.pmml.TransformationDictionary; import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator; import org.elasticsearch.ml.modelinput.PMMLVectorRange; import org.elasticsearch.ml.modelinput.VectorModelInput; import org.elasticsearch.ml.modelinput.VectorModelInputEvaluator; import org.elasticsearch.ml.modelinput.VectorRange; import org.elasticsearch.ml.models.EsLinearSVMModel; import org.elasticsearch.ml.models.EsLogisticRegressionModel; import org.elasticsearch.ml.models.EsModelEvaluator; import org.elasticsearch.script.pmml.ProcessPMMLHelper; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * Factory for regression models */ public class RegressionModelFactory extends ModelFactory<VectorModelInput, String, RegressionModel> { public RegressionModelFactory() { super(RegressionModel.class); } @Override public ModelAndModelInputEvaluator<VectorModelInput, String> buildFromPMML(RegressionModel model, DataDictionary dataDictionary, TransformationDictionary transformationDictionary) { if (model.getModelName().equals("logistic regression")) { return initModel(model, dataDictionary, transformationDictionary, EsLogisticRegressionModel::new); } else if (model.getModelName().equals("linear SVM")) { return initModel(model, dataDictionary, transformationDictionary, EsLinearSVMModel::new); } else { throw new UnsupportedOperationException("We only implemented logistic regression so far but your model is of type " + model.getModelName()); } } private interface RegressionModelConstructor { EsModelEvaluator<VectorModelInput, String> create(double[] coefficients, double intercept, String[] classes); } private ModelAndModelInputEvaluator<VectorModelInput, String> initModel(RegressionModel model, DataDictionary dataDictionary, TransformationDictionary transformationDictionary, RegressionModelConstructor constructor) { List<VectorRange> vectorRanges = new ArrayList<>(); int indexCounter = 0; Map<String, OpType> types = new HashMap<>(); // TODO: add RegressionTable regressionTable = model.getRegressionTables().get(0); for (NumericPredictor predictor : regressionTable.getNumericPredictors()) { PMMLVectorRange vectorRange = ProcessPMMLHelper.extractVectorRange(model, dataDictionary, transformationDictionary, predictor.getName().getValue(), () -> { throw new IllegalArgumentException("Categorical fields are not supported yet"); }, indexCounter, types); vectorRanges.add(vectorRange); indexCounter += vectorRange.size(); } VectorModelInputEvaluator vectorPMML = new VectorModelInputEvaluator(vectorRanges); EsModelEvaluator<VectorModelInput, String> modelEvaluator = buildLinerModel(model, constructor); return new ModelAndModelInputEvaluator<>(vectorPMML, modelEvaluator); } private static EsModelEvaluator<VectorModelInput, String> buildLinerModel(RegressionModel model, RegressionModelConstructor constructor) { RegressionTable regressionTable = model.getRegressionTables().get(0); List<NumericPredictor> numericPredictors = regressionTable.getNumericPredictors(); double[] coefficients = new double[numericPredictors.size()]; int i = 0; for (NumericPredictor numericPredictor : numericPredictors) { coefficients[i] = numericPredictor.getCoefficient(); i++; } String[] classes = new String[]{ model.getRegressionTables().get(0).getTargetCategory(), model.getRegressionTables().get(1).getTargetCategory() }; return constructor.create(coefficients, regressionTable.getIntercept(), classes); } }