/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.awt.Component;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.swing.JTabbedPane;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.tools.ExtendedJTabbedPane;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.RandomGenerator;
/**
 * A simple vote model. For classification problems, the majority class is chosen;
 * ties between equally voted classes are broken uniformly at random using the
 * global random generator. The confidence of every class that received at least
 * one vote is set to the fraction of base models that voted for it.
 * For regression problems, the average prediction value is used. This model
 * only supports simple prediction models.
 *
 * @author Ingo Mierswa
 * @version $Id: SimpleVoteModel.java,v 1.6 2008/05/09 19:22:47 ingomierswa Exp $
 */
public class SimpleVoteModel extends SimplePredictionModel {

	private static final long serialVersionUID = 1089932073805038503L;

	/** The base models whose predictions are combined. Stored as given (not copied). */
	private List<SimplePredictionModel> baseModels;

	/**
	 * Creates a vote model over the given base models.
	 *
	 * @param exampleSet the example set handed to the super constructor
	 * @param baseModels the simple prediction models whose predictions are combined
	 */
	public SimpleVoteModel(ExampleSet exampleSet, List<SimplePredictionModel> baseModels) {
		super(exampleSet);
		this.baseModels = baseModels;
	}

	/**
	 * Predicts the label of the given example by combining the base models'
	 * predictions: majority vote for nominal labels, arithmetic mean otherwise.
	 *
	 * @param example the example to classify; for nominal labels its class
	 *                confidences are updated as a side effect
	 * @return the predicted label value, or {@code Double.NaN} if there are no base models
	 * @throws OperatorException if a base model fails to predict
	 */
	public double predict(Example example) throws OperatorException {
		if (getLabel().isNominal()) {
			return predictByMajorityVote(example);
		} else {
			return predictByAverage(example);
		}
	}

	/**
	 * Majority vote over the base models' nominal predictions. Sets the confidence
	 * of each voted class to (votes / number of base models).
	 */
	private double predictByMajorityVote(Example example) throws OperatorException {
		// Without base models there is nothing to vote on; mirror the regression
		// branch (0.0 / 0 == NaN) instead of failing inside nextInt(0).
		if (baseModels.isEmpty()) {
			return Double.NaN;
		}

		// Count votes per predicted class index.
		Map<Double, AtomicInteger> classVotes = new TreeMap<Double, AtomicInteger>();
		for (SimplePredictionModel model : baseModels) {
			double prediction = model.predict(example);
			AtomicInteger counter = classVotes.get(prediction);
			if (counter == null) {
				classVotes.put(prediction, new AtomicInteger(1));
			} else {
				counter.incrementAndGet();
			}
		}

		// Collect all classes sharing the maximum vote count.
		List<Double> bestClasses = new LinkedList<Double>();
		int bestClassesVotes = -1;
		double numberOfModels = baseModels.size();
		for (Map.Entry<Double, AtomicInteger> entry : classVotes.entrySet()) {
			double currentClass = entry.getKey();
			int currentVotes = entry.getValue().intValue();
			if (currentVotes > bestClassesVotes) {
				bestClasses.clear();
				bestClasses.add(currentClass);
				bestClassesVotes = currentVotes;
			} else if (currentVotes == bestClassesVotes) {
				// 'else' is essential: a new best class must be added exactly once,
				// otherwise it is over-weighted in the random tie-break below.
				bestClasses.add(currentClass);
			}
			example.setConfidence(getLabel().getMapping().mapIndex((int) currentClass), currentVotes / numberOfModels);
		}

		if (bestClasses.size() == 1) {
			return bestClasses.get(0);
		} else {
			// Uniform random tie-break among equally voted classes.
			return bestClasses.get(RandomGenerator.getGlobalRandomGenerator().nextInt(bestClasses.size()));
		}
	}

	/** Arithmetic mean of the base models' numerical predictions (NaN if there are none). */
	private double predictByAverage(Example example) throws OperatorException {
		double sum = 0.0d;
		for (SimplePredictionModel model : baseModels) {
			sum += model.predict(example);
		}
		return sum / baseModels.size();
	}

	/**
	 * Returns a tabbed pane with one tab per base model, labelled "Model 1",
	 * "Model 2", ..., each containing that model's visualization component.
	 *
	 * @param container passed through to each base model's visualization
	 * @return the tabbed pane holding all base model visualizations
	 */
	public Component getVisualizationComponent(IOContainer container) {
		JTabbedPane tabPane = new ExtendedJTabbedPane();
		int index = 1;
		for (Model model : baseModels) {
			tabPane.add("Model " + index, model.getVisualizationComponent(container));
			index++;
		}
		return tabPane;
	}
}