/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.training; import org.dmg.pmml.BayesInput; import org.dmg.pmml.BayesInputs; import org.dmg.pmml.BayesOutput; import org.dmg.pmml.DataDictionary; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.dmg.pmml.FieldUsageType; import org.dmg.pmml.GaussianDistribution; import org.dmg.pmml.MiningField; import org.dmg.pmml.MiningFunctionType; import org.dmg.pmml.MiningSchema; import org.dmg.pmml.NaiveBayesModel; import org.dmg.pmml.OpType; import org.dmg.pmml.PMML; import org.dmg.pmml.PairCounts; import org.dmg.pmml.TargetValueCount; import org.dmg.pmml.TargetValueCounts; import org.dmg.pmml.TargetValueStat; import org.dmg.pmml.TargetValueStats; import org.dmg.pmml.Value; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.cluster.metadata.MappingMetaData; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.stats.extended.ExtendedStats; import org.jpmml.model.JAXBUtil; import javax.xml.bind.JAXBException; import javax.xml.transform.stream.StreamResult; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.Charset; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; import static org.elasticsearch.search.aggregations.AggregationBuilders.extendedStats; import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; /** * */ public class NaiveBayesModelTrainer implements ModelTrainer { @Override public String modelType() { return "naive_bayes"; } @Override public TrainingSession createTrainingSession(MappingMetaData mappingMetaData, List<ModelInputField> inputs, ModelTargetField target, Settings settings) { return new NaiveBayesTrainingSession(mappingMetaData, inputs, target, settings); } private static class NaiveBayesTrainingSession implements TrainingSession { final TermsAggregationBuilder termsAggregationBuilder; private NaiveBayesTrainingSession(MappingMetaData mappingMetaData, List<ModelInputField> fields, ModelTargetField target, Settings settings) { TermsAggregationBuilder topLevelClassAgg = terms(target.getName()); topLevelClassAgg.field(target.getName()); topLevelClassAgg.size(Integer.MAX_VALUE); topLevelClassAgg.shardMinDocCount(1); topLevelClassAgg.minDocCount(1); topLevelClassAgg.order(Terms.Order.term(true)); Map<String, Object> fieldMappings = getFiledMappings(mappingMetaData); for (ModelInputField field : fields) { String fieldType = getFieldType(fieldMappings, field.getName()); if (fieldType == null) { throw new IllegalArgumentException("input field [" + field.getName() + "] not found"); } if (fieldType.equals("text") || fieldType.equals("keyword")) { topLevelClassAgg.subAggregation(terms(field.getName()).field(field.getName()) .size(Integer.MAX_VALUE).shardMinDocCount(1).minDocCount(1) .order(Terms.Order.term(true))); } else if (fieldType.equals("double") || fieldType.equals("float") || fieldType.equals("integer") || fieldType.equals("long")) { topLevelClassAgg.subAggregation(extendedStats(field.getName()).field(field.getName())); } else { throw new UnsupportedOperationException("have not implemented naive bayes training for anything but " + "number and string field yet"); } } termsAggregationBuilder = topLevelClassAgg; } @SuppressWarnings("unchecked") private Map<String, Object> getFiledMappings(MappingMetaData mappingMetaData) { try { return (Map<String, Object>) mappingMetaData.sourceAsMap().get("properties"); } catch (IOException ex) { throw new IllegalStateException(ex); } } @SuppressWarnings("unchecked") private String getFieldType(Map<String, Object> fieldMappings, String field) { Map<String, Object> attributes = (Map<String, Object>) fieldMappings.get(field); return (String) attributes.get("type"); } @Override public AggregationBuilder trainingRequest() { return termsAggregationBuilder; } @Override public String model(SearchResponse searchResponse) { NaiveBayesModel naiveBayesModel = new NaiveBayesModel(); Aggregations aggs = searchResponse.getAggregations(); Terms classAgg = (Terms) aggs.asList().get(0); int numClasses = classAgg.getBuckets().size(); long[] classCounts = new long[numClasses]; String[] classLabels = new String[numClasses]; int classCounter = 0; for (Terms.Bucket bucket : classAgg.getBuckets()) { classCounts[classCounter] = bucket.getDocCount(); classLabels[classCounter] = bucket.getKeyAsString(); classCounter++; } if (classCounter < 2) { throw new RuntimeException("Need at least two classes for naive bayes!"); } setTargetValueCounts(naiveBayesModel, classAgg, classCounts, classLabels); // field, value, class -> count TreeMap<String, TreeMap<String, TreeMap<String, Long>>> stringFieldValueCounts = new TreeMap<>(); TreeMap<String, TreeSet<String>> allTermsPerField = new TreeMap<>(); TreeMap<String, TreeMap<String, Map<String, Double>>> numericFieldStats = new TreeMap<>(); for (Terms.Bucket bucket : classAgg.getBuckets()) { String className = bucket.getKeyAsString(); for (Aggregation aggregation : bucket.getAggregations()) { String fieldName = aggregation.getName(); if (aggregation instanceof Terms) { Terms termAgg = (Terms) aggregation; // init the data structure if not present if (stringFieldValueCounts.containsKey(fieldName) == false) { stringFieldValueCounts.put(fieldName, new TreeMap<>()); allTermsPerField.put(fieldName, new TreeSet<>()); } TreeMap<String, TreeMap<String, Long>> valueCounts = stringFieldValueCounts.get(fieldName); for (Terms.Bucket termBucket : termAgg.getBuckets()) { String value = termBucket.getKeyAsString(); if (valueCounts.containsKey(value) == false) { valueCounts.put(value, new TreeMap<>()); } TreeMap<String, Long> termCountsPerClass = valueCounts.get(value); allTermsPerField.get(fieldName).add(termBucket.getKeyAsString()); termCountsPerClass.put(className, termBucket.getDocCount()); } } else if (aggregation instanceof ExtendedStats) { ExtendedStats extendedStats = (ExtendedStats) aggregation; if (numericFieldStats.containsKey(fieldName) == false) { numericFieldStats.put(fieldName, new TreeMap<>()); } Map<String, Double> stats = new HashMap<>(); stats.put("mean", extendedStats.getAvg()); stats.put("variance", extendedStats.getVariance()); numericFieldStats.get(fieldName).put(className, stats); } else { throw new RuntimeException("unsupported agg " + aggregation.getClass().getName()); } } } setBayesInputs(naiveBayesModel, stringFieldValueCounts, numericFieldStats, classLabels); naiveBayesModel.setFunctionName(MiningFunctionType.CLASSIFICATION); final PMML pmml = new PMML(); setDataDictionary(pmml, allTermsPerField, numericFieldStats.keySet()); setMiningFields(naiveBayesModel, allTermsPerField.keySet(), numericFieldStats.keySet(), classAgg.getName()); naiveBayesModel.setThreshold(1.0 / searchResponse.getHits().totalHits()); pmml.addModels(naiveBayesModel); final StreamResult streamResult = new StreamResult(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); streamResult.setOutputStream(outputStream); AccessController.doPrivileged(new PrivilegedAction<Object>() { public Object run() { try { JAXBUtil.marshal(pmml, streamResult); } catch (JAXBException e) { throw new RuntimeException("No idea what went wrong here", e); } return null; } }); return new String(outputStream.toByteArray(), Charset.defaultCharset()); } } private static void setMiningFields(NaiveBayesModel naiveBayesModel, Set<String> categoricalFields, Set<String> numericFields, String classField) { MiningSchema miningSchema = new MiningSchema(); for (String fieldName : categoricalFields) { MiningField miningField = new MiningField(); miningField.setName(new FieldName(fieldName)); miningField.setUsageType(FieldUsageType.ACTIVE); miningSchema.addMiningFields(miningField); } for (String fieldName : numericFields) { MiningField miningField = new MiningField(); miningField.setName(new FieldName(fieldName)); miningField.setUsageType(FieldUsageType.ACTIVE); miningSchema.addMiningFields(miningField); } MiningField miningField = new MiningField(); miningField.setName(new FieldName(classField)); miningField.setUsageType(FieldUsageType.PREDICTED); miningSchema.addMiningFields(miningField); naiveBayesModel.setMiningSchema(miningSchema); } private static void setBayesInputs(NaiveBayesModel naiveBayesModel, TreeMap<String, TreeMap<String, TreeMap<String, Long>>> stringFieldValueCounts, TreeMap<String, TreeMap<String, Map<String, Double>>> numericFieldStats, String[] classNames) { BayesInputs bayesInputs = new BayesInputs(); for (Map.Entry<String, TreeMap<String, TreeMap<String, Long>>> categoricalField : stringFieldValueCounts.entrySet()) { String fieldName = categoricalField.getKey(); BayesInput bayesInput = new BayesInput(); bayesInput.setFieldName(new FieldName(fieldName)); for (Map.Entry<String, TreeMap<String, Long>> valueCounts : categoricalField.getValue().entrySet()) { String value = valueCounts.getKey(); PairCounts pairCounts = new PairCounts(); pairCounts.setValue(value); TargetValueCounts targetValueCounts = new TargetValueCounts(); TreeMap<String, Long> classCounts = valueCounts.getValue(); for (String className : classNames) { if (classCounts.containsKey(className)) { targetValueCounts.addTargetValueCounts(new TargetValueCount().setValue(className).setCount(classCounts.get (className))); } else { targetValueCounts.addTargetValueCounts(new TargetValueCount().setValue(className).setCount(0)); } } pairCounts.setTargetValueCounts(targetValueCounts); bayesInput.addPairCounts(pairCounts); } bayesInputs.addBayesInputs(bayesInput); } for (Map.Entry<String, TreeMap<String, Map<String, Double>>> continuousField : numericFieldStats.entrySet()) { String fieldName = continuousField.getKey(); BayesInput bayesInput = new BayesInput(); bayesInput.setFieldName(new FieldName(fieldName)); TargetValueStats targetValueStats = new TargetValueStats(); for (Map.Entry<String, Map<String, Double>> valueStats : continuousField.getValue().entrySet()) { String className = valueStats.getKey(); GaussianDistribution gaussianDistribution = new GaussianDistribution(); gaussianDistribution.setMean(valueStats.getValue().get("mean")); gaussianDistribution.setVariance(valueStats.getValue().get("variance")); TargetValueStat targetValueStat = new TargetValueStat(); targetValueStat.setValue(className); targetValueStat.setContinuousDistribution(gaussianDistribution); targetValueStats.addTargetValueStats(targetValueStat); } bayesInput.setTargetValueStats(targetValueStats); bayesInputs.addBayesInputs(bayesInput); } naiveBayesModel.setBayesInputs(bayesInputs); } private static void setDataDictionary(PMML pmml, TreeMap<String, TreeSet<String>> allTermsPerField, Set<String> numericFieldsNames) { DataDictionary dataDictionary = new DataDictionary(); for (Map.Entry<String, TreeSet<String>> fieldNameAndTerms : allTermsPerField.entrySet()) { DataField dataField = new DataField(); dataField.setName(new FieldName(fieldNameAndTerms.getKey())); dataField.setOpType(OpType.CATEGORICAL); dataField.setDataType(DataType.STRING); for (String term : fieldNameAndTerms.getValue()) { dataField.addValues(new Value(term)); } dataDictionary.addDataFields(dataField); } for (String fieldname : numericFieldsNames) { DataField dataField = new DataField(); dataField.setName(new FieldName(fieldname)); dataField.setOpType(OpType.CONTINUOUS); // TODO: handle ints etc. dataField.setDataType(DataType.DOUBLE); dataDictionary.addDataFields(dataField); } pmml.setDataDictionary(dataDictionary); } private static void setTargetValueCounts(NaiveBayesModel naiveBayesModel, Terms classAgg, long[] classCounts, String[] classLabels) { TargetValueCounts targetValueCounts = new TargetValueCounts(); for (int i = 0; i < classLabels.length; i++) { TargetValueCount targetValueCount = new TargetValueCount(); targetValueCount.setValue(classLabels[i]); targetValueCount.setCount(classCounts[i]); targetValueCounts.addTargetValueCounts(targetValueCount); } naiveBayesModel.setBayesOutput(new BayesOutput().setFieldName(new FieldName(classAgg.getName())).setTargetValueCounts (targetValueCounts)); } }