/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.awt.BorderLayout; import java.awt.Component; import java.util.Iterator; import javax.swing.JLabel; import javax.swing.JPanel; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.IOContainer; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.Tools; /** * Model for TransformedRegression. Applies the inverse transformation on the * predictions of the inner model. * * @author Stefan Rueping * @version $Id: TransformedRegressionModel.java,v 1.11 2006/03/21 15:35:48 * ingomierswa Exp $ */ public class TransformedRegressionModel extends PredictionModel { private static final long serialVersionUID = -1273082758742436998L; public static final String[] METHODS = { "log", "exp", "rank", "none" }; public static final int LOG = 0; public static final int EXP = 1; public static final int RANK = 2; public static final int NONE = 3; private int method; private double[] rank; private double mean; private double stddev; private Model model; private boolean interpolate; private boolean zscale; public TransformedRegressionModel(ExampleSet exampleSet, int method, double[] rank, Model model, boolean zscale, double mean, double stddev, boolean interpolate) { super(exampleSet); this.method = method; this.rank = rank; this.model = model; this.zscale = zscale; this.mean = mean; this.stddev = stddev; this.interpolate = interpolate; } /** Iterates over all examples and applies this model. */ public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException { ExampleSet eSet = (ExampleSet) exampleSet.clone(); eSet = model.apply(eSet); Iterator<Example> reader = eSet.iterator(); Iterator<Example> originalReader = exampleSet.iterator(); switch (method) { case LOG: while (originalReader.hasNext()) { double functionValue = reader.next().getPredictedLabel(); if (zscale) { // if(zscale) is quicker and has less chance of // numerical errors functionValue = functionValue * stddev + mean; } Example example = originalReader.next(); example.setPredictedLabel(Math.exp(functionValue) - rank[0]); } break; case EXP: while (originalReader.hasNext()) { double functionValue = reader.next().getPredictedLabel(); if (zscale) { // if(zscale) is quicker and has less chance of // numerical errors functionValue = functionValue * stddev + mean; } Example example = originalReader.next(); example.setPredictedLabel(Math.log(functionValue)); } break; case RANK: while (originalReader.hasNext()) { double predictedRank = reader.next().getPredictedLabel(); if (zscale) { // if(zscale) is quicker and has less chance of // numerical errors predictedRank = predictedRank * stddev + mean; } Example example = originalReader.next(); if (interpolate) { int lower = (int) Math.round(Math.floor(predictedRank)); int upper = (int) Math.round(Math.ceil(predictedRank)); if (lower < 0) lower = 0; if (lower >= rank.length) lower = rank.length - 1; if (upper < 0) upper = 0; if (upper >= rank.length) upper = rank.length - 1; if (!(upper == lower)) { predictedRank = (upper - predictedRank) * rank[lower] + (predictedRank - lower) * rank[upper]; } else { predictedRank = rank[lower]; } } else { int thisRank = (int) Math.round(predictedRank); if (thisRank < 0) thisRank = 0; if (thisRank >= rank.length) thisRank = rank.length - 1; predictedRank = rank[thisRank]; } example.setPredictedLabel(predictedRank); } break; case NONE: if (zscale) { while (originalReader.hasNext()) { double functionValue = reader.next().getPredictedLabel() * stddev + mean; Example example = originalReader.next(); example.setPredictedLabel(functionValue); } } break; default: // cannot happen break; } return exampleSet; } public Component getVisualizationComponent(IOContainer container) { JPanel result = new JPanel(); result.setLayout(new BorderLayout()); result.add(new JLabel("Method: " + METHODS[method]), BorderLayout.NORTH); result.add(model.getVisualizationComponent(container), BorderLayout.CENTER); return result; } public String toString() { StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator()); result.append("Method: " + METHODS[method] + Tools.getLineSeparator()); result.append(model.toString()); return result.toString(); } }