/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.awt.Component; import java.util.Iterator; import javax.swing.JTabbedPane; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.gui.tools.ExtendedJTabbedPane; import com.rapidminer.operator.IOContainer; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.Tools; /** * The model created by an AdditiveRegression meta learner. * * @author Ingo Mierswa * @version $Id: AdditiveRegressionModel.java,v 1.7 2008/05/09 19:22:47 ingomierswa Exp $ */ public class AdditiveRegressionModel extends PredictionModel { private static final long serialVersionUID = -8036434608645810089L; private Model defaultModel; private Model[] residualModels; private double shrinkage; public AdditiveRegressionModel(ExampleSet exampleSet, Model defaultModel, Model[] residualModels, double shrinkage) { super(exampleSet); this.defaultModel = defaultModel; this.residualModels = residualModels; this.shrinkage = shrinkage; } public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException { // apply default model exampleSet = defaultModel.apply(exampleSet); double[] predictions = new double[exampleSet.size()]; Iterator<Example> e = exampleSet.iterator(); int counter = 0; while (e.hasNext()) { predictions[counter++] = e.next().getPredictedLabel(); } PredictionModel.removePredictedLabel(exampleSet); // apply all models to the example set sum up the predictions for (int i = 0; i < residualModels.length; i++) { exampleSet = residualModels[i].apply(exampleSet); e = exampleSet.iterator(); counter = 0; while (e.hasNext()) { predictions[counter++] += shrinkage * e.next().getPredictedLabel(); } PredictionModel.removePredictedLabel(exampleSet); } // set final predictions e = exampleSet.iterator(); counter = 0; Attribute newPredictedLabel = createPredictedLabel(exampleSet, getLabel()); while (e.hasNext()) { e.next().setValue(newPredictedLabel, predictions[counter++]); } return exampleSet; } public Component getVisualizationComponent(IOContainer container) { JTabbedPane tabPane = new ExtendedJTabbedPane(); tabPane.add("Default Model", defaultModel.getVisualizationComponent(container)); int index = 1; for (Model residualModel : residualModels) { tabPane.add("Model " + index, residualModel.getVisualizationComponent(container)); index++; } return tabPane; } public String toString() { StringBuffer result = new StringBuffer(super.toString()); result.append("Default model:" + Tools.getLineSeparator() + this.defaultModel.toString() + Tools.getLineSeparator() + Tools.getLineSeparator()); result.append("Number of base models: " + this.residualModels.length + Tools.getLineSeparator()); result.append("Shrinkage: " + this.shrinkage + Tools.getLineSeparator()); for (int i = 0; i < this.residualModels.length; i++) { result.append(Tools.getLineSeparator() + Tools.ordinalNumber(i+1) + " Model:" + Tools.getLineSeparator() + this.residualModels[i] + Tools.getLineSeparator()); } return result.toString(); } }