/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.ml.models;

import org.elasticsearch.ml.modelinput.VectorModelInput;

import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleUnaryOperator;

/**
 * Naive Bayes style classifier whose per-class score is the class prior plus the sum of one
 * {@link DoubleUnaryOperator} contribution per input feature.
 *
 * <p>The scores are exponentiated when probabilities are reported, so priors and function outputs
 * are presumably log-space values (the bundled {@link GaussFunction} and {@link ProbFunction} both
 * return logarithms) - NOTE(review): confirm against the code that builds the model.
 */
public class EsNaiveBayesModelWithMixedInput extends EsModelEvaluator<VectorModelInput, String> {

    /** Indexed as {@code functions[classIndex][featureIndex]}. */
    private final DoubleUnaryOperator[][] functions;
    /** Initial score of every class, indexed like {@link #classLabels}. */
    private final double[] classPriors;
    private final String[] classLabels;

    /**
     * @param classLabels label of every class; its length defines the number of classes
     * @param functions   per-class, per-feature score functions, {@code functions[class][feature]}
     * @param classPriors initial score per class; at least {@code classLabels.length} entries are read
     */
    public EsNaiveBayesModelWithMixedInput(String[] classLabels, DoubleUnaryOperator[][] functions, double[] classPriors) {
        this.functions = functions;
        this.classPriors = classPriors;
        this.classLabels = classLabels;
    }

    /**
     * Evaluates the model and reports the winning class together with per-class probabilities.
     *
     * @param modelInput the feature vector to classify
     * @return a map with key {@code "class"} (the winning label) and key {@code "probs"}
     *         (a {@code Map<String, Double>} from class label to normalised probability)
     */
    @Override
    public Map<String, Object> evaluateDebug(VectorModelInput modelInput) {
        double[] classProbs = getClassProbs(modelInput);
        return prepareResult(classProbs);
    }

    /**
     * Computes the unnormalised score of every class: the class prior plus, for each entry of the
     * input, the output of that class' function for the entry's feature index.
     */
    private double[] getClassProbs(VectorModelInput modelInput) {
        double[] classProbs = new double[classLabels.length];
        System.arraycopy(classPriors, 0, classProbs, 0, classProbs.length);
        for (int i = 0; i < modelInput.getSize(); i++) {
            for (int j = 0; j < classProbs.length; j++) {
                classProbs[j] += functions[j][modelInput.getIndex(i)].applyAsDouble(modelInput.getValue(i));
            }
        }
        return classProbs;
    }

    /**
     * Evaluates the model and returns only the label of the highest scoring class.
     *
     * @param modelInput the feature vector to classify
     * @return the label of the class with the highest score; the first such class wins ties
     */
    @Override
    public String evaluate(VectorModelInput modelInput) {
        double[] classProbs = getClassProbs(modelInput);
        return classLabels[argMax(classProbs)];
    }

    /**
     * Returns the index of the largest value. The first occurrence wins ties, and {@code 0} is
     * returned when no element is strictly greater than negative infinity (empty array, or only
     * {@code NaN} / negative-infinity entries).
     */
    private static int argMax(double[] values) {
        int best = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (bestValue < values[i]) {
                best = i;
                bestValue = values[i];
            }
        }
        return best;
    }

    /**
     * Turns per-class scores into the debug result map.
     *
     * <p>Probabilities are {@code exp(score)} normalised over all classes. The largest score is
     * subtracted before exponentiating: this leaves the normalised result mathematically unchanged
     * but keeps strongly negative scores from all underflowing to zero, which would otherwise make
     * the normaliser zero and every probability {@code NaN}.
     *
     * @param val the score of every class, indexed like {@code classLabels}
     * @return a map holding the winning label under {@code "class"} and the label-to-probability
     *         map under {@code "probs"}
     */
    private Map<String, Object> prepareResult(double... val) {
        int bestClass = argMax(val);
        String classValue = classLabels[bestClass];

        // Only shift by a finite maximum; for an infinite or NaN maximum the shift itself would
        // produce NaN differences, so fall back to plain exponentiation in that degenerate case.
        double maxScore = val[bestClass];
        double shift = Double.isFinite(maxScore) ? maxScore : 0.0;

        double[] expScores = new double[val.length];
        double sumProb = 0;
        for (int i = 0; i < val.length; i++) {
            expScores[i] = Math.exp(val[i] - shift);
            sumProb += expScores[i];
        }

        Map<String, Double> probMap = new HashMap<>();
        for (int i = 0; i < val.length; i++) {
            probMap.put(classLabels[i], expScores[i] / sumProb);
        }

        Map<String, Object> results = new HashMap<>();
        results.put("class", classValue);
        results.put("probs", probMap);
        return results;
    }

    /**
     * Logarithm of the Gaussian density with the given mean and variance:
     * {@code -(x - mean)^2 / (2 * variance) - log(sqrt(2 * pi * variance))}.
     */
    public static class GaussFunction implements DoubleUnaryOperator {
        double variance;
        double mean;
        /** Precomputed {@code log(sqrt(2 * pi * variance))}, the log of the normalisation constant. */
        double varianceFactor;

        /**
         * @param variance variance of the distribution; no check is made that it is positive
         * @param mean     mean of the distribution
         */
        public GaussFunction(double variance, double mean) {
            this.variance = variance;
            this.mean = mean;
            varianceFactor = Math.log(Math.sqrt(2 * Math.PI * variance));
        }

        /**
         * @param value the observed feature value
         * @return the log density of {@code value}
         */
        @Override
        public double applyAsDouble(double value) {
            return -Math.pow((value - mean), 2) / (2 * variance) - varianceFactor;
        }
    }

    /**
     * Constant function returning the logarithm of a fixed probability, regardless of the input value.
     */
    public static class ProbFunction implements DoubleUnaryOperator {
        /** The stored logarithm (of the probability, or of the threshold when the probability was zero). */
        double prob;

        /**
         * @param prob      the probability whose logarithm is returned
         * @param threshold substitute used instead of {@code prob} when {@code prob} is exactly
         *                  {@code 0.0}, so that the stored logarithm is not negative infinity
         */
        public ProbFunction(double prob, double threshold) {
            if (prob == 0.0) {
                this.prob = Math.log(threshold);
            } else {
                this.prob = Math.log(prob);
            }
        }

        /**
         * @param value ignored
         * @return the stored logarithm
         */
        @Override
        public double applyAsDouble(double value) {
            return prob;
        }
    }
}