/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.factories; import org.dmg.pmml.BayesInput; import org.dmg.pmml.ContinuousDistribution; import org.dmg.pmml.DataDictionary; import org.dmg.pmml.GaussianDistribution; import org.dmg.pmml.NaiveBayesModel; import org.dmg.pmml.OpType; import org.dmg.pmml.PairCounts; import org.dmg.pmml.TargetValueCount; import org.dmg.pmml.TargetValueCounts; import org.dmg.pmml.TargetValueStat; import org.dmg.pmml.TargetValueStats; import org.dmg.pmml.TransformationDictionary; import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator; import org.elasticsearch.ml.modelinput.PMMLVectorRange; import org.elasticsearch.ml.modelinput.VectorModelInput; import org.elasticsearch.ml.modelinput.VectorModelInputEvaluator; import org.elasticsearch.ml.modelinput.VectorRange; import org.elasticsearch.ml.models.EsModelEvaluator; import org.elasticsearch.ml.models.EsNaiveBayesModelWithMixedInput; import org.elasticsearch.ml.models.EsNaiveBayesModelWithMixedInput.GaussFunction; import org.elasticsearch.ml.models.EsNaiveBayesModelWithMixedInput.ProbFunction; import org.elasticsearch.script.pmml.ProcessPMMLHelper; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.TreeSet; import java.util.function.DoubleUnaryOperator; public class NaiveBayesModelFactory extends ModelFactory<VectorModelInput, String, NaiveBayesModel> { public NaiveBayesModelFactory() { super(NaiveBayesModel.class); } @Override public ModelAndModelInputEvaluator<VectorModelInput, String> buildFromPMML(NaiveBayesModel naiveBayesModel, DataDictionary dataDictionary, TransformationDictionary transformationDictionary) { if (naiveBayesModel.getFunctionName().value().equals("classification")) { // for each Bayes input // find the whole tranform pipeline (cp glm) // create vector range // append Ijk/Tk List<VectorRange> vectorRanges = new ArrayList<>(); List<TargetValueStats> targetValueStats = new ArrayList<>(); int indexCounter = 0; Map<String, OpType> types = new HashMap<>(); for (BayesInput bayesInput : naiveBayesModel.getBayesInputs()) { PMMLVectorRange vectorRange = ProcessPMMLHelper.extractVectorRange(naiveBayesModel, dataDictionary, transformationDictionary, bayesInput.getFieldName().getValue(), () -> { // sort values first TreeSet<String> sortedValues = new TreeSet<>(); for (PairCounts pairCount : bayesInput.getPairCounts()) { sortedValues.add(pairCount.getValue()); } return sortedValues; }, indexCounter, types); vectorRanges.add(vectorRange); indexCounter += vectorRange.size(); targetValueStats.add(bayesInput.getTargetValueStats()); } VectorModelInputEvaluator vectorPMML = new VectorModelInputEvaluator(vectorRanges); EsModelEvaluator<VectorModelInput, String> model = buildEsNaiveBayesModel(naiveBayesModel, types); return new ModelAndModelInputEvaluator<>(vectorPMML, model); } else { throw new UnsupportedOperationException("Naive does not support the following parameters yet: " + " functionName:" + naiveBayesModel.getFunctionName().value()); } } private EsModelEvaluator<VectorModelInput, String> buildEsNaiveBayesModel(NaiveBayesModel naiveBayesModel, Map<String, OpType> types) { Map<String, Integer> classIndexMap = new HashMap<>(); // get class priors int numClasses = naiveBayesModel.getBayesOutput().getTargetValueCounts().getTargetValueCounts().size(); // sort first TreeMap<String, Double> sortedClassLabelsAndCounts = new TreeMap<>(); double[] classPriors = new double[numClasses]; double[] classCounts = new double[numClasses]; String[] classLabels = new String[numClasses]; double sumCounts = 0; for (TargetValueCount targetValueCount : naiveBayesModel.getBayesOutput().getTargetValueCounts().getTargetValueCounts()) { sortedClassLabelsAndCounts.put(targetValueCount.getValue(), targetValueCount.getCount()); sumCounts += targetValueCount.getCount(); } int classCounter = 0; for (Map.Entry<String, Double> classCount : sortedClassLabelsAndCounts.entrySet()) { classPriors[classCounter] = Math.log(classCount.getValue() / sumCounts); classLabels[classCounter] = classCount.getKey(); classCounts[classCounter] = classCount.getValue(); classIndexMap.put(classCount.getKey(), classCounter); classCounter++; } List<List<DoubleUnaryOperator>> functionLists = initFunctions(naiveBayesModel, types, classCounts, classIndexMap, classLabels); DoubleUnaryOperator[][] functions = new DoubleUnaryOperator[functionLists.size()][functionLists.get(0).size()]; classCounter = 0; for (List<DoubleUnaryOperator> classFunctions : functionLists) { int functionCounter = 0; for (DoubleUnaryOperator classFunction : classFunctions) { functions[classCounter][functionCounter] = classFunction; functionCounter++; } classCounter++; } return new EsNaiveBayesModelWithMixedInput(classLabels, functions, classPriors); } private List<List<DoubleUnaryOperator>> initFunctions(NaiveBayesModel naiveBayesModel, Map<String, OpType> types, double[] classCounts, Map<String, Integer> classIndexMap, String[] classLabels) { List<List<DoubleUnaryOperator>> functionLists = new ArrayList<>(); for (int i = 0; i < classLabels.length; i++) { functionLists.add(new ArrayList<>()); } double threshold = naiveBayesModel.getThreshold(); for (BayesInput bayesInput : naiveBayesModel.getBayesInputs()) { String fieldName = bayesInput.getFieldName().getValue(); if (types.containsKey(fieldName) == false) { throw new UnsupportedOperationException("Cannot determine type of field " + bayesInput.getFieldName().getValue() + "probably messed up parsing"); } if (types.get(fieldName).equals(OpType.CONTINUOUS)) { for (TargetValueStat targetValueStat : bayesInput.getTargetValueStats()) { ContinuousDistribution continuousDistribution = targetValueStat.getContinuousDistribution(); if (continuousDistribution instanceof GaussianDistribution == false) { throw new UnsupportedOperationException("Only Gaussian distribution implemented so fay for naive bayes model"); } GaussianDistribution gaussianDistribution = (GaussianDistribution) continuousDistribution; String classAssignment = targetValueStat.getValue(); functionLists.get(classIndexMap.get(classAssignment)).add( new GaussFunction(gaussianDistribution.getVariance(), gaussianDistribution.getMean())); } } else if (types.get(fieldName).equals(OpType.CATEGORICAL)) { TreeMap<String, TargetValueCounts> sortedValues = new TreeMap<>(); for (PairCounts pairCount : bayesInput.getPairCounts()) { sortedValues.put(pairCount.getValue(), pairCount.getTargetValueCounts()); } for (Map.Entry<String, TargetValueCounts> counts : sortedValues.entrySet()) { for (TargetValueCount targetValueCount : counts.getValue()) { Integer classIndex = classIndexMap.get(targetValueCount.getValue()); double prob = targetValueCount.getCount() / classCounts[classIndex]; functionLists.get(classIndex).add(new ProbFunction(prob, threshold)); } } } else { throw new UnsupportedOperationException("cannot deal with bayes input that is not categorical and also not continuous"); } } return functionLists; } }