/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.ml.factories;

import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.GeneralRegressionModel;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.PCell;
import org.dmg.pmml.PPCell;
import org.dmg.pmml.Parameter;
import org.dmg.pmml.Predictor;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.Value;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator;
import org.elasticsearch.ml.modelinput.PMMLVectorRange;
import org.elasticsearch.ml.modelinput.VectorModelInput;
import org.elasticsearch.ml.modelinput.VectorModelInputEvaluator;
import org.elasticsearch.ml.modelinput.VectorRange;
import org.elasticsearch.ml.models.EsLogisticRegressionModel;
import org.elasticsearch.script.pmml.ProcessPMMLHelper;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Builds an {@link EsLogisticRegressionModel} together with its {@link VectorModelInputEvaluator} from a PMML
 * {@link GeneralRegressionModel}.
 * <p>
 * Only classification models are accepted, and only if their model type is {@code multinomialLogistic} or
 * {@code generalizedLinear} with a {@code binomial} distribution and a {@code logit} link function. The parameter
 * matrix must contain coefficients for exactly one target category, i.e. only two-class problems are supported.
 * Correlated predictors (a parameter that appears more than once in the PPMatrix) are rejected.
 */
public class GeneralizedLinearRegressionModelFactory extends ModelFactory<VectorModelInput, String, GeneralRegressionModel> {

    public GeneralizedLinearRegressionModelFactory() {
        super(GeneralRegressionModel.class);
    }

    /**
     * Creates one vector range per field, in the iteration order of {@code fieldToPPCellMap}.
     *
     * @param grModel                  the PMML model, handed through to {@link ProcessPMMLHelper}
     * @param transformationDictionary the PMML transformation dictionary, handed through to {@link ProcessPMMLHelper}
     * @param dataDictionary           the PMML data dictionary, handed through to {@link ProcessPMMLHelper}
     * @param fieldToPPCellMap         the PPCells of the model grouped by the field they refer to
     * @param orderedParameterList     output parameter: the parameter name of every PPCell is appended, field by field,
     *                                 in the same order in which the ranges are created
     * @return the vector ranges, one per field
     */
    private List<VectorRange> convertToFeatureEntries(GeneralRegressionModel grModel, TransformationDictionary transformationDictionary,
                                                      DataDictionary dataDictionary, TreeMap<String, List<PPCell>> fieldToPPCellMap,
                                                      List<String> orderedParameterList) {
        List<VectorRange> vectorRangeList = new ArrayList<>();
        // running offset: each range starts where the previous one ended
        int indexCounter = 0;
        for (Map.Entry<String, List<PPCell>> fieldAndCells : fieldToPPCellMap.entrySet()) {
            final String fieldName = fieldAndCells.getKey();
            final List<PPCell> cells = fieldAndCells.getValue();
            PMMLVectorRange featureEntries = ProcessPMMLHelper.extractVectorRange(grModel, dataDictionary, transformationDictionary,
                fieldName, () -> {
                    // the values of the cells, in the order in which the cells are stored for this field
                    return cells.stream().map(PPCell::getValue).collect(Collectors.toList());
                }, indexCounter, null);
            for (PPCell cell : cells) {
                orderedParameterList.add(cell.getParameterName());
            }
            indexCounter += featureEntries.size();
            vectorRangeList.add(featureEntries);
        }
        return vectorRangeList;
    }

    /**
     * Groups the PPCells of the model by the name of the field they refer to.
     *
     * @param grModel the PMML model
     * @return a sorted map with one entry per factor and covariate of the model; a field without PPCells maps to an empty list
     * @throws UnsupportedOperationException if a parameter occurs more than once in the PPMatrix
     * @throws ElasticsearchParseException   if a PPCell refers to a field that is neither a factor nor a covariate
     */
    private TreeMap<String, List<PPCell>> mapCellsToFields(GeneralRegressionModel grModel) {
        // Check that the PPMatrix only has one entry per parameter.
        // We need to implement correlations later, see http://dmg.org/pmml/v4-2-1/GeneralRegression.html (PPMatrix) but not now...
        Set<String> parametersInPPMatrix = new HashSet<>();
        for (PPCell ppcell : grModel.getPPMatrix().getPPCells()) {
            if (parametersInPPMatrix.add(ppcell.getParameterName()) == false) {
                throw new UnsupportedOperationException("Don't support correlated predictors for GeneralRegressionModel yet");
            }
        }

        // Collect all field names. FactorList and CovariateList are optional elements in the PMML schema, so guard
        // against the parser handing back null for a model that has only one of them.
        TreeMap<String, List<PPCell>> fieldToPPCellMap = new TreeMap<>();
        if (grModel.getFactorList() != null) {
            for (Predictor predictor : grModel.getFactorList().getPredictors()) {
                fieldToPPCellMap.put(predictor.getName().toString(), new ArrayList<PPCell>());
            }
        }
        if (grModel.getCovariateList() != null) {
            for (Predictor predictor : grModel.getCovariateList().getPredictors()) {
                fieldToPPCellMap.put(predictor.getName().toString(), new ArrayList<PPCell>());
            }
        }

        // Sort all cells by field. One vector range per field is created from this map later, and the order of the
        // parameters is derived from it as well.
        for (PPCell ppcell : grModel.getPPMatrix().getPPCells()) {
            String fieldName = ppcell.getField().toString();
            List<PPCell> cellsOfField = fieldToPPCellMap.get(fieldName);
            if (cellsOfField == null) {
                // fail with a parse error that names the field instead of a bare NullPointerException
                throw new ElasticsearchParseException("PPCell refers to field [" + fieldName
                    + "] which is neither a factor nor a covariate");
            }
            cellsOfField.add(ppcell);
        }
        return fieldToPPCellMap;
    }

    /**
     * Appends an intercept range for every model parameter that is not referenced by any PPCell.
     *
     * @param grModel              the PMML model
     * @param vectorRangeMap       output parameter: the intercept ranges are appended to this list
     * @param fieldToPPCellMap     the PPCells of the model grouped by field
     * @param orderedParameterList output parameter: the names of the intercept parameters are appended to this list
     */
    private void addIntercept(GeneralRegressionModel grModel, List<VectorRange> vectorRangeMap,
                              Map<String, List<PPCell>> fieldToPPCellMap, List<String> orderedParameterList) {
        // Every PPCell accounts for one position, so the first intercept goes to the position right after the last cell.
        int numFeatures = 0;
        Set<String> allFieldParameters = new HashSet<>();
        for (List<PPCell> cells : fieldToPPCellMap.values()) {
            for (PPCell cell : cells) {
                allFieldParameters.add(cell.getParameterName());
                numFeatures++;
            }
        }
        // now find the parameters which do not come from a field
        for (Parameter parameter : grModel.getParameterList().getParameters()) {
            if (allFieldParameters.contains(parameter.getName()) == false) {
                PMMLVectorRange.Intercept intercept = new PMMLVectorRange.Intercept(parameter.getName(), "double");
                intercept.addVectorEntry(numFeatures, null);
                numFeatures++;
                vectorRangeMap.add(intercept);
                orderedParameterList.add(parameter.getName());
            }
        }
    }

    /**
     * Converts the PMML model into a logistic regression model plus the evaluator that produces its input vector.
     *
     * @param grModel                  the PMML model to convert
     * @param dataDictionary           the data dictionary of the PMML document
     * @param transformationDictionary the transformation dictionary of the PMML document
     * @return the model together with its input evaluator
     * @throws UnsupportedOperationException if the model is not one of the supported logistic regression variants, if it
     *                                       has correlated predictors, or if the parameter matrix does not hold exactly
     *                                       one target category
     * @throws ElasticsearchParseException   if the target variable or the second target class cannot be determined
     */
    @SuppressWarnings("unchecked")
    @Override
    public ModelAndModelInputEvaluator<VectorModelInput, String> buildFromPMML(GeneralRegressionModel grModel,
                                                                               DataDictionary dataDictionary,
                                                                               TransformationDictionary transformationDictionary) {
        if (isSupportedModelType(grModel) == false) {
            throw new UnsupportedOperationException("Only implemented logistic regression with multinomialLogistic so far.");
        }
        TreeMap<String, List<PPCell>> fieldToPPCellMap = mapCellsToFields(grModel);

        // This list stores the order of the parameters as the vectors will return them. So, if p5 is a parameter that
        // has index 11 in the vector then this parameter will be at position 11 in the ordered parameter list.
        List<String> orderedParameterList = new ArrayList<>();
        List<VectorRange> vectorRangeList = convertToFeatureEntries(grModel, transformationDictionary, dataDictionary,
            fieldToPPCellMap, orderedParameterList);

        // add intercept if any
        addIntercept(grModel, vectorRangeList, fieldToPPCellMap, orderedParameterList);

        assert orderedParameterList.size() == grModel.getParameterList().getParameters().size();
        VectorModelInputEvaluator vectorEntries = new VectorModelInputEvaluator(vectorRangeList);

        // Now finally create the model: group the coefficients by target class first.
        Map<String, List<PCell>> targetClassPCellMap = mapParametersToTargetCategory(grModel);
        if (targetClassPCellMap.size() != 1) {
            throw new UnsupportedOperationException("We do not support more than two classes for GeneralizedRegression for "
                + "classification");
        }
        double[] coefficients = getGLMCoefficients(orderedParameterList, targetClassPCellMap);

        // Get the target class values. One we can get from the parameter matrix but the other one we have to find in
        // the data dictionary. This needs to be extended if more than two classes are ever supported.
        String targetVariable = findTargetVariableName(grModel);
        String[] targetCategories = findTargetCategories(dataDictionary, targetClassPCellMap, targetVariable);

        // Intercept parameters are already part of the coefficient array (see addIntercept); the second constructor
        // argument is passed as 0.0 (presumably a separate intercept term - confirm against EsLogisticRegressionModel).
        EsLogisticRegressionModel logisticRegressionModel = new EsLogisticRegressionModel(coefficients, 0.0, targetCategories);
        return new ModelAndModelInputEvaluator<>(vectorEntries, logisticRegressionModel);
    }

    /**
     * Checks whether the model is a classification model of type {@code multinomialLogistic}, or of type
     * {@code generalizedLinear} with {@code binomial} distribution and {@code logit} link function. A missing
     * distribution or link function counts as "not supported" instead of causing a NullPointerException.
     */
    private static boolean isSupportedModelType(GeneralRegressionModel grModel) {
        if ("classification".equals(grModel.getFunctionName().value()) == false) {
            return false;
        }
        if ("multinomialLogistic".equals(grModel.getModelType().value())) {
            return true;
        }
        return "generalizedLinear".equals(grModel.getModelType().value())
            && grModel.getDistribution() != null
            && "binomial".equals(grModel.getDistribution().value())
            && grModel.getLinkFunction() != null
            && "logit".equals(grModel.getLinkFunction().value());
    }

    /**
     * Groups the cells of the parameter matrix by their target category.
     *
     * @param grModel the PMML model
     * @return the PCells per target category, each list in the order of the parameter matrix
     */
    private Map<String, List<PCell>> mapParametersToTargetCategory(GeneralRegressionModel grModel) {
        Map<String, List<PCell>> targetClassPCellMap = new HashMap<>();
        for (PCell pCell : grModel.getParamMatrix().getPCells()) {
            String targetClassName = pCell.getTargetCategory();
            targetClassPCellMap.computeIfAbsent(targetClassName, key -> new ArrayList<PCell>()).add(pCell);
        }
        return targetClassPCellMap;
    }

    /**
     * Determines the two target classes of the model.
     *
     * @param dataDictionary      the data dictionary that declares the values of the target variable
     * @param targetClassPCellMap the parameter matrix cells grouped by target category; the first key is the first class
     * @param targetVariable      the name of the target variable
     * @return an array of length two: the class taken from the parameter matrix followed by the class found in the data
     *         dictionary. If the data field lists several values that differ from the first class, the last one is used.
     * @throws ElasticsearchParseException if the target variable is not declared in the data dictionary or declares no
     *                                     value other than the first class
     */
    private String[] findTargetCategories(DataDictionary dataDictionary, Map<String, List<PCell>> targetClassPCellMap,
                                          String targetVariable) {
        String class1 = targetClassPCellMap.keySet().iterator().next();
        String class2 = null;
        // find the other class in the data field of the target variable
        for (DataField dataField : dataDictionary.getDataFields()) {
            if (dataField.getName().toString().equals(targetVariable)) {
                for (Value value : dataField.getValues()) {
                    String valueString = value.getValue();
                    if (valueString.equals(class1) == false) {
                        class2 = valueString;
                    }
                }
                break;
            }
        }
        // Checked after the loop so that a target variable which is missing from the data dictionary is reported as
        // well, instead of silently yielding a null class.
        if (class2 == null) {
            throw new ElasticsearchParseException("could not find target class");
        }
        return new String[] { class1, class2 };
    }

    /**
     * Finds the name of the target variable in the mining schema. The first mining field with usage type
     * {@code target} wins; if there is none, the first field with usage type {@code predicted} is used.
     *
     * @param grModel the PMML model
     * @return the name of the target variable, never {@code null}
     * @throws ElasticsearchParseException if the mining schema has neither a target nor a predicted field
     */
    private String findTargetVariableName(GeneralRegressionModel grModel) {
        String targetVariable = findMiningFieldWithUsage(grModel, "target");
        if (targetVariable == null) {
            // fall back to find the "predicted" field
            targetVariable = findMiningFieldWithUsage(grModel, "predicted");
        }
        if (targetVariable == null) {
            throw new ElasticsearchParseException("could not find target variable");
        }
        return targetVariable;
    }

    /**
     * Returns the name of the first mining field whose usage type has the given value, or {@code null} if no mining
     * field matches.
     */
    private static String findMiningFieldWithUsage(GeneralRegressionModel grModel, String usageType) {
        for (MiningField miningField : grModel.getMiningSchema().getMiningFields()) {
            FieldUsageType fieldUsageType = miningField.getUsageType();
            if (fieldUsageType != null && fieldUsageType.value().equals(usageType)) {
                return miningField.getName().getValue();
            }
        }
        return null;
    }

    /**
     * Gets the model coefficients and returns them in the order defined by {@code orderedParameterList}.
     *
     * @param orderedParameterList the parameter names in the order of the input vector
     * @param targetClassPCellMap  the parameter matrix cells grouped by target category; the cells of the first entry
     *                             are used
     * @return one coefficient per parameter; a parameter without a cell in the parameter matrix gets 0.0. If several
     *         cells name the same parameter, the last one wins.
     * @throws ElasticsearchParseException if there are more cells than parameters
     */
    private double[] getGLMCoefficients(List<String> orderedParameterList, Map<String, List<PCell>> targetClassPCellMap) {
        List<PCell> coefficientCells = targetClassPCellMap.values().iterator().next();
        if (coefficientCells.size() > orderedParameterList.size()) {
            throw new ElasticsearchParseException("Parameter list contains more entries than parameters");
        }
        // Index the cells by parameter name once instead of scanning all cells for every parameter.
        // TODO: what to do with df? I don't get the documentation: http://dmg.org/pmml/v4-2-1/GeneralRegression.html
        Map<String, Double> betaByParameter = new HashMap<>();
        for (PCell pCell : coefficientCells) {
            double beta = pCell.getBeta();
            betaByParameter.put(pCell.getParameterName(), beta);
        }
        // a freshly allocated double[] is already filled with 0.0
        double[] coefficients = new double[orderedParameterList.size()];
        for (int i = 0; i < coefficients.length; i++) {
            Double beta = betaByParameter.get(orderedParameterList.get(i));
            if (beta != null) {
                coefficients[i] = beta;
            }
        }
        return coefficients;
    }
}