/** * Copyright (C) 2001-2017 by RapidMiner and the contributors * * Complete list of developers available at our web site: * * http://rapidminer.com * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.set.ExampleSetUtilities; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.SimplePredictionModel; import com.rapidminer.tools.RandomGenerator; import com.rapidminer.tools.Tools; /** * {@link MetaModel} that bases its decision on the arithmetic mean of the confidence values of the * given {@link SimplePredictionModel}s. This value is at the same time used as confidence value for * the prediction. * <p> * This meta model only works with {@link SimplePredictionModel}s that calculate meaningful * confidence values and predict a nominal label. * * @author Zoltan Prekopcsak, Michael Knopf * @since 7.0.0 */ public class ConfidenceVoteModel extends SimplePredictionModel implements MetaModel { private static final long serialVersionUID = 1L; /** List of voting models. */ private List<? extends SimplePredictionModel> models; /** * Creates a new {@link MetaModel} with confidence based voting for the given example set and * models. * * @param exampleSet * the example set * @param models * the voting models * @throws IllegalArgumentException * if the given example set's label is not nominal */ public ConfidenceVoteModel(ExampleSet exampleSet, List<? extends SimplePredictionModel> models) { super(exampleSet, ExampleSetUtilities.SetsCompareOption.EQUAL, ExampleSetUtilities.TypesCompareOption.ALLOW_SAME_PARENTS); if (!getLabel().isNominal()) { throw new IllegalArgumentException("Label must be nominal."); } this.models = models; } @Override public String toString() { StringBuilder buffer = new StringBuilder(); int i = 1; for (SimplePredictionModel model : models) { buffer.append(i); buffer.append(") "); buffer.append(model.getName()); buffer.append(Tools.getLineSeparator()); buffer.append("---"); buffer.append(Tools.getLineSeparator()); buffer.append(model.toString()); buffer.append(Tools.getLineSeparators(2)); i++; } return buffer.toString(); } @Override public List<? extends Model> getModels() { return Collections.unmodifiableList(models); } @Override public List<String> getModelNames() { List<String> names = new ArrayList<>(models.size()); for (SimplePredictionModel model : models) { names.add(model.getName()); } return names; } @Override public double predict(Example example) throws OperatorException { Map<String, Double> classConfidenceSums = new HashMap<>(); for (SimplePredictionModel model : models) { model.predict(example); for (String className : getLabel().getMapping().getValues()) { Double classConfidence = example.getConfidence(className); if (Double.isNaN(classConfidence)) { throw new OperatorException("Child model failed to compute confidence value."); } Double currentSum = classConfidenceSums.get(className); if (currentSum == null) { classConfidenceSums.put(className, classConfidence); } else { classConfidenceSums.put(className, currentSum + classConfidence); } } } // normalize confidence sums for (Entry<String, Double> entry : classConfidenceSums.entrySet()) { entry.setValue(entry.getValue() / models.size()); } List<String> bestClasses = new ArrayList<>(classConfidenceSums.size()); double maxConfidence = -1; for (Entry<String, Double> entry : classConfidenceSums.entrySet()) { String className = entry.getKey(); double confidence = entry.getValue(); if (confidence > maxConfidence) { maxConfidence = confidence; bestClasses.clear(); } if (confidence == maxConfidence) { bestClasses.add(className); } example.setConfidence(className, confidence); } if (bestClasses.size() == 1) { return getLabel().getMapping().getIndex(bestClasses.get(0)); } else { return getLabel().getMapping() .getIndex(bestClasses.get(RandomGenerator.getGlobalRandomGenerator().nextInt(bestClasses.size()))); } } }