/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.awt.BorderLayout;
import java.awt.Component;
import java.util.Iterator;
import javax.swing.JLabel;
import javax.swing.JPanel;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
/**
 * Model for TransformedRegression. Applies the inverse transformation on the
 * predictions of the inner model.
 *
 * <p>
 * The inner model is applied to a clone of the given example set. Each of its
 * predictions is optionally rescaled with {@code value * stddev + mean}, mapped
 * back according to the configured method and finally written as predicted
 * label of the corresponding example of the given example set.
 * </p>
 *
 * @author Stefan Rueping
 * @version $Id: TransformedRegressionModel.java,v 1.11 2006/03/21 15:35:48
 *          ingomierswa Exp $
 */
public class TransformedRegressionModel extends PredictionModel {

    private static final long serialVersionUID = -1273082758742436998L;

    /** Display names of the methods, indexed by the method constants below. */
    public static final String[] METHODS = { "log", "exp", "rank", "none" };

    /** Inner predictions are mapped back with {@code Math.exp(value) - rank[0]}. */
    public static final int LOG = 0;

    /** Inner predictions are mapped back with {@code Math.log(value)}. */
    public static final int EXP = 1;

    /** Inner predictions are treated as positions in the {@code rank} array. */
    public static final int RANK = 2;

    /** Inner predictions are only rescaled (if z-scaling is enabled). */
    public static final int NONE = 3;

    private final int method;

    private final double[] rank;

    private final double mean;

    private final double stddev;

    private final Model model;

    private final boolean interpolate;

    private final boolean zscale;

    /**
     * Creates a new model wrapping the given inner model.
     *
     * @param exampleSet  passed on to the {@link PredictionModel} constructor
     * @param method      one of {@link #LOG}, {@link #EXP}, {@link #RANK} or {@link #NONE}
     * @param rank        lookup values; {@code rank[0]} is used as offset for {@link #LOG}
     *                    and the whole array is indexed for {@link #RANK}, so it must be
     *                    non-empty for these two methods. The array is stored, not copied.
     * @param model       the inner model whose predictions are transformed back
     * @param zscale      if true, inner predictions are rescaled with
     *                    {@code value * stddev + mean} before the method is inverted
     * @param mean        offset used for the rescaling
     * @param stddev      factor used for the rescaling
     * @param interpolate if true, {@link #RANK} interpolates linearly between the two
     *                    neighbouring entries of {@code rank}; otherwise the nearest
     *                    entry is used
     */
    public TransformedRegressionModel(ExampleSet exampleSet, int method, double[] rank, Model model, boolean zscale, double mean, double stddev, boolean interpolate) {
        super(exampleSet);
        this.method = method;
        this.rank = rank;
        this.model = model;
        this.zscale = zscale;
        this.mean = mean;
        this.stddev = stddev;
        this.interpolate = interpolate;
    }

    /**
     * Iterates over all examples and applies this model.
     *
     * <p>
     * The inner model is applied to a clone of {@code exampleSet}; the
     * back-transformed predictions are then written to the examples of
     * {@code exampleSet} itself. If {@code method} is not one of the known
     * constants, the inner model is still applied but no predicted label is
     * written.
     * </p>
     *
     * @param exampleSet              the examples to predict; also the returned object
     * @param predictedLabelAttribute not used by this implementation
     * @return the given {@code exampleSet}
     * @throws OperatorException if applying the inner model fails
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
        ExampleSet eSet = (ExampleSet) exampleSet.clone();
        eSet = model.apply(eSet);

        Iterator<Example> reader = eSet.iterator();
        Iterator<Example> originalReader = exampleSet.iterator();

        // cannot happen for properly constructed models; keep the historic
        // behavior of not writing any prediction for unknown method codes
        if (method != LOG && method != EXP && method != RANK && method != NONE) {
            return exampleSet;
        }

        // The predictions are read from the cloned set and written to the
        // original one for EVERY method. Formerly NONE without z-scaling
        // skipped this loop, so no predicted label was written in that case.
        while (originalReader.hasNext()) {
            double innerPrediction = undoZScale(reader.next().getPredictedLabel());
            Example example = originalReader.next();
            example.setPredictedLabel(invertTransformation(innerPrediction));
        }
        return exampleSet;
    }

    /** Rescales the value with {@code value * stddev + mean} if z-scaling is enabled. */
    private double undoZScale(double value) {
        if (zscale) {
            // if(zscale) is quicker and has less chance of
            // numerical errors
            return value * stddev + mean;
        }
        return value;
    }

    /** Maps an (already rescaled) inner prediction back according to the method. */
    private double invertTransformation(double value) {
        switch (method) {
        case LOG:
            return Math.exp(value) - rank[0];
        case EXP:
            return Math.log(value);
        case RANK:
            return rankValue(value);
        default:
            // NONE: rescaling was the only transformation
            return value;
        }
    }

    /** Translates a predicted position in the rank array into the stored value. */
    private double rankValue(double predictedRank) {
        // Math.round(NaN) is 0, which would silently turn a missing
        // prediction into rank[0]; propagate NaN like LOG and EXP do
        if (Double.isNaN(predictedRank)) {
            return Double.NaN;
        }
        if (!interpolate) {
            return rank[clampToIndex(Math.round(predictedRank))];
        }
        int lower = clampToIndex(Math.round(Math.floor(predictedRank)));
        int upper = clampToIndex(Math.round(Math.ceil(predictedRank)));
        if (upper == lower) {
            // integral position, or position outside of the array
            return rank[lower];
        }
        // linear interpolation between the two neighbouring entries
        return (upper - predictedRank) * rank[lower] + (predictedRank - lower) * rank[upper];
    }

    /**
     * Clamps a position to the valid indices of the rank array. The clamping
     * is done on the long value returned by Math.round BEFORE narrowing to
     * int: narrowing first wraps around for positions beyond the int range
     * (e.g. Long.MAX_VALUE, the result for +Infinity, becomes -1), which
     * formerly sent very large predictions to the first instead of the last
     * entry.
     */
    private int clampToIndex(long position) {
        if (position < 0) {
            return 0;
        }
        if (position >= rank.length) {
            return rank.length - 1;
        }
        return (int) position;
    }

    /**
     * Returns a panel showing the method name above the visualization
     * component of the inner model.
     */
    public Component getVisualizationComponent(IOContainer container) {
        JPanel result = new JPanel();
        result.setLayout(new BorderLayout());
        result.add(new JLabel("Method: " + METHODS[method]), BorderLayout.NORTH);
        result.add(model.getVisualizationComponent(container), BorderLayout.CENTER);
        return result;
    }

    /**
     * Returns the description of the super class, followed by the method
     * name and the description of the inner model.
     */
    public String toString() {
        StringBuilder result = new StringBuilder(super.toString() + Tools.getLineSeparator());
        result.append("Method: " + METHODS[method] + Tools.getLineSeparator());
        result.append(model.toString());
        return result.toString();
    }
}