/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.ExampleSetUtilities;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
/**
 * A simple vote model. For classification problems, the majority class is chosen. For regression
 * problems, the average prediction value is used. This model only supports simple prediction
 * models.
 *
 * <p>
 * For nominal labels, the confidence of every label value is set on the example to the fraction
 * of base models that voted for it. If several classes share the highest number of votes, one of
 * them is picked using the global random generator.
 * </p>
 *
 * @author Ingo Mierswa
 */
public class SimpleVoteModel extends SimplePredictionModel implements MetaModel {

    private static final long serialVersionUID = 1089932073805038503L;

    /** The base models whose individual predictions are combined by this model. */
    private List<? extends SimplePredictionModel> baseModels;

    /** {@code true} if the label is nominal (voting), {@code false} otherwise (averaging). */
    private boolean labelIsNominal;

    /** Indices of all values of the label mapping; only initialized for nominal labels. */
    private List<Double> labelIndices;

    /**
     * Creates a new vote model on top of the given base models.
     *
     * @param exampleSet
     *            the example set handed to the super class constructor
     * @param baseModels
     *            the models whose predictions are combined; the list is stored as is (no copy)
     */
    public SimpleVoteModel(ExampleSet exampleSet, List<? extends SimplePredictionModel> baseModels) {
        super(exampleSet, ExampleSetUtilities.SetsCompareOption.EQUAL,
                ExampleSetUtilities.TypesCompareOption.ALLOW_SAME_PARENTS);
        this.baseModels = baseModels;
        labelIsNominal = getLabel().isNominal();
        if (labelIsNominal) {
            // Remember the index of every label value so that predict() can assign a confidence
            // to each of them, including those that received no vote.
            labelIndices = new LinkedList<>();
            NominalMapping mapping = getLabel().getMapping();
            for (String value : mapping.getValues()) {
                double index = mapping.getIndex(value);
                labelIndices.add(index);
            }
        }
    }

    /**
     * Combines the predictions of all base models for the given example.
     *
     * @param example
     *            the example to predict; for nominal labels its confidences are updated
     * @return the index of the majority class for nominal labels, otherwise the mean of the base
     *         model predictions
     * @throws OperatorException
     *             if one of the base models fails to predict
     */
    @Override
    public double predict(Example example) throws OperatorException {
        return labelIsNominal ? predictByMajorityVote(example) : predictByAveraging(example);
    }

    /** Counts the votes of all base models, sets the confidences and returns the winning class index. */
    private double predictByMajorityVote(Example example) throws OperatorException {
        // Count how many base models voted for each predicted class index.
        Map<Double, AtomicInteger> classVotes = new TreeMap<>();
        for (SimplePredictionModel model : baseModels) {
            double prediction = model.predict(example);
            AtomicInteger counter = classVotes.get(prediction);
            if (counter == null) {
                classVotes.put(prediction, new AtomicInteger(1));
            } else {
                counter.incrementAndGet();
            }
        }

        NominalMapping mapping = getLabel().getMapping();
        int modelCount = baseModels.size();

        List<Double> bestClasses = new LinkedList<>();
        int bestClassesVotes = -1;
        for (double currentClass : labelIndices) {
            String className = mapping.mapIndex((int) currentClass);
            AtomicInteger votes = classVotes.get(currentClass);
            if (votes == null) {
                example.setConfidence(className, 0.00);
                continue;
            }
            int currentVotes = votes.intValue();
            if (currentVotes > bestClassesVotes) {
                // New leader: discard all previous candidates.
                bestClasses.clear();
                bestClasses.add(currentClass);
                bestClassesVotes = currentVotes;
            } else if (currentVotes == bestClassesVotes) {
                // Tie with the current leader(s). This must be an "else if": otherwise a class
                // that just became the leader would be added a second time, which biases the
                // random tie-break and prevents the single-winner shortcut below.
                bestClasses.add(currentClass);
            }
            example.setConfidence(className, (double) currentVotes / (double) modelCount);
        }

        if (bestClasses.size() == 1) {
            return bestClasses.get(0);
        }
        // Several classes share the maximum number of votes: pick one of them at random.
        // NOTE(review): if no base model voted for a known label index the candidate list is
        // empty and the random generator is asked for a bound of 0 - confirm intended handling.
        return bestClasses.get(RandomGenerator.getGlobalRandomGenerator().nextInt(bestClasses.size()));
    }

    /** Returns the arithmetic mean of the predictions of all base models. */
    private double predictByAveraging(Example example) throws OperatorException {
        double sum = 0.0d;
        for (SimplePredictionModel model : baseModels) {
            sum += model.predict(example);
        }
        return sum / baseModels.size();
    }

    /**
     * Returns the string representations of all base models, each preceded by a header line.
     * The headers are numbered starting at 0, whereas {@link #getModelNames()} starts at 1.
     */
    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();
        int i = 0;
        for (SimplePredictionModel model : baseModels) {
            buffer.append("Model ").append(i).append(":").append(Tools.getLineSeparator());
            buffer.append("---").append(Tools.getLineSeparator());
            buffer.append(model.toString());
            buffer.append(Tools.getLineSeparators(2));
            i++;
        }
        return buffer.toString();
    }

    /**
     * Returns one generated name per base model: "Model 1", "Model 2", ...
     */
    @Override
    public List<String> getModelNames() {
        List<String> names = new LinkedList<>();
        for (int i = 0; i < this.baseModels.size(); i++) {
            names.add("Model " + (i + 1));
        }
        return names;
    }

    /**
     * Returns the list of base models held by this model (the internal list, not a copy).
     */
    @Override
    public List<? extends Model> getModels() {
        return baseModels;
    }
}