/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorProgress;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
/**
 * MultiModels are used for multi-class learning tasks. This variant holds one regression model per
 * class and combines them: for each example, every model is applied and the class belonging to the
 * model with the highest numerical output becomes the predicted label.
 *
 * @author Ingo Mierswa, Simon Fischer
 */
public class MultiModelByRegression extends PredictionModel implements MetaModel {

	private static final long serialVersionUID = 4526668088304067678L;

	/** Number of examples combined between two progress updates. */
	private static final int OPERATOR_PROGRESS_STEPS = 5000;

	/**
	 * The inner regression models. {@link #performPrediction(ExampleSet, Attribute)} treats the model
	 * at index {@code i} as the scorer for class index {@code i} of the predicted label's nominal
	 * mapping.
	 */
	private Model[] models;

	/**
	 * Creates a combined model.
	 *
	 * @param exampleSet
	 *            the example set passed on to the {@link PredictionModel} constructor
	 * @param models
	 *            one regression model per class, ordered by class index of the label mapping
	 *            (NOTE(review): ordering is assumed from how predictions are mapped — confirm with
	 *            the creating learner)
	 */
	public MultiModelByRegression(ExampleSet exampleSet, Model[] models) {
		super(exampleSet, null, null);
		this.models = models;
	}

	/** Returns the number of inner regression models. */
	public int getNumberOfModels() {
		return models.length;
	}

	/**
	 * Returns the inner regression model for the given class index.
	 *
	 * @param index
	 *            the class / model index
	 * @return the model stored at {@code index}
	 */
	public Model getModel(int index) {
		return models[index];
	}

	/**
	 * Applies every inner regression model and, for each example, predicts the class whose model
	 * returned the highest function value. The winning class receives confidence 1.0. If no model
	 * returns a comparable value (all NaN or negative infinity), the prediction is missing and no
	 * confidence is set.
	 *
	 * @param exampleSet
	 *            the example set to label; it is modified in place and returned
	 * @param predictedLabelAttribute
	 *            the predicted label attribute (not used here; the predicted label is looked up on
	 *            {@code exampleSet} instead)
	 * @return {@code exampleSet} with predicted labels and confidences set
	 * @throws OperatorException
	 *             if applying an inner model fails
	 */
	@Override
	public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
		// initialize progress: the first 50% covers applying the inner models, the rest covers
		// combining their outputs
		OperatorProgress progress = null;
		if (getShowProgress() && getOperator() != null && getOperator().getProgress() != null) {
			progress = getOperator().getProgress();
			progress.setTotal(100);
		}

		int numberOfModels = getNumberOfModels();

		// each inner model is applied to its own clone, so each clone carries that model's output
		// as its predicted label
		ExampleSet[] eSet = new ExampleSet[numberOfModels];
		for (int i = 0; i < numberOfModels; i++) {
			eSet[i] = getModel(i).apply((ExampleSet) exampleSet.clone());
			if (progress != null) {
				progress.setCompleted((int) (50.0 * (i + 1) / numberOfModels));
			}
		}

		List<Iterator<Example>> reader = new ArrayList<Iterator<Example>>(eSet.length);
		for (ExampleSet modelResult : eSet) {
			reader.add(modelResult.iterator());
		}

		Iterator<Example> originalReader = exampleSet.iterator();
		Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
		int progressCounter = 0;
		while (originalReader.hasNext()) {
			// advance all per-model iterators in lockstep and keep the index of the largest output;
			// NaN and negative infinity never win the strict comparison
			double bestLabel = Double.NaN;
			double highestFunctionValue = Double.NEGATIVE_INFINITY;
			for (int k = 0; k < reader.size(); k++) {
				double functionValue = reader.get(k).next().getPredictedLabel();
				if (functionValue > highestFunctionValue) {
					highestFunctionValue = functionValue;
					bestLabel = k;
				}
			}

			Example example = originalReader.next();
			example.setPredictedLabel(bestLabel);
			// only set a confidence when some model won: (int) NaN is 0, which would otherwise
			// give the first class full confidence although the prediction itself is missing
			if (!Double.isNaN(bestLabel)) {
				example.setConfidence(predictedLabel.getMapping().mapIndex((int) bestLabel), 1.0d);
			}

			if (progress != null && ++progressCounter % OPERATOR_PROGRESS_STEPS == 0) {
				progress.setCompleted((int) (50 + 50.0 * progressCounter / exampleSet.size()));
			}
		}
		return exampleSet;
	}

	/** Returns the base description followed by the descriptions of all inner models, one per line. */
	@Override
	public String toString() {
		String lineSeparator = Tools.getLineSeparator();
		StringBuilder result = new StringBuilder(super.toString()).append(lineSeparator);
		for (int i = 0; i < models.length; i++) {
			if (i > 0) {
				result.append(lineSeparator);
			}
			result.append(models[i].toString());
		}
		return result.toString();
	}

	/** Returns display names "Model 1" .. "Model n", one per inner model. */
	@Override
	public List<String> getModelNames() {
		int numberOfModels = this.getNumberOfModels();
		List<String> names = new ArrayList<String>(numberOfModels);
		for (int i = 0; i < numberOfModels; i++) {
			names.add("Model " + (i + 1));
		}
		return names;
	}

	/** Returns a fixed-size list view backed by the inner model array. */
	@Override
	public List<Model> getModels() {
		return Arrays.asList(models);
	}
}