/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.awt.Component; import java.util.Iterator; import java.util.List; import javax.swing.JTabbedPane; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.gui.tools.ExtendedJTabbedPane; import com.rapidminer.operator.IOContainer; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.Tools; /** * The model for the internal Bagging implementation. * * @author Martin Scholz, Ingo Mierswa * @version $Id: BaggingModel.java,v 1.7 2008/09/23 14:07:33 ingomierswa Exp $ */ public class BaggingModel extends PredictionModel { private static final long serialVersionUID = -4691755811263523354L; /** Holds the models. */ private List<Model> models; public BaggingModel(ExampleSet exampleSet, List<Model> models) { super(exampleSet); this.models = models; } /** @return the number of embedded models */ public int getNumberOfModels() { return this.models.size(); } /** * Getter method for embedded models * * @param index * the number of a model part of this boost model * @return binary or nominal decision model */ public Model getModel(int index) { return this.models.get(index); } /** * Iterates over all models and averages confidences. * * @param origExampleSet * the set of examples to be classified */ public ExampleSet performPrediction(ExampleSet origExampleSet, Attribute predictedLabel) throws OperatorException { if (predictedLabel.isNominal()) { // nominal prediction final String attributePrefix = "BaggingModelPrediction"; final int numLabels = predictedLabel.getMapping().size(); final Attribute[] specialAttributes = new Attribute[numLabels]; for (int i = 0; i < numLabels; i++) { specialAttributes[i] = com.rapidminer.example.Tools.createSpecialAttribute(origExampleSet, attributePrefix + i, Ontology.NUMERICAL); } Iterator<Example> reader = origExampleSet.iterator(); while (reader.hasNext()) { Example example = reader.next(); for (int i = 0; i < specialAttributes.length; i++) { example.setValue(specialAttributes[i], 0); } } reader = origExampleSet.iterator(); for (int modelNr = 0; modelNr < this.getNumberOfModels(); modelNr++) { Model model = this.getModel(modelNr); ExampleSet exampleSet = (ExampleSet) origExampleSet.clone(); exampleSet = model.apply(exampleSet); this.updateEstimates(exampleSet, modelNr, specialAttributes); PredictionModel.removePredictedLabel(exampleSet); } // Turn prediction weights into confidences and a crisp prediction: this.evaluateSpecialAttributes(origExampleSet, specialAttributes); // Clean up attributes: for (int i = 0; i < numLabels; i++) { origExampleSet.getAttributes().remove(specialAttributes[i]); origExampleSet.getExampleTable().removeAttribute(specialAttributes[i]); } return origExampleSet; } else { // numerical prediction double[] predictionSums = new double[origExampleSet.size()]; for (Model model : models) { ExampleSet resultSet = model.apply(origExampleSet); int index = 0; Attribute innerPredictedLabel = resultSet.getAttributes().getPredictedLabel(); for (Example example : resultSet) { predictionSums[index++] += example.getValue(innerPredictedLabel); } PredictionModel.removePredictedLabel(resultSet); } int index = 0; for (Example example : origExampleSet) { example.setValue(predictedLabel, predictionSums[index++] / (double)models.size()); } return origExampleSet; } } private void updateEstimates(ExampleSet exampleSet, int modelNr, Attribute[] specialAttributes) { final int numModels = this.getNumberOfModels(); final int numClasses = this.getLabel().getMapping().size(); Iterator<Example> reader = exampleSet.iterator(); while (reader.hasNext()) { Example example = reader.next(); for (int i=0; i<numClasses; i++) { String consideredPrediction = this.getLabel().getMapping().mapIndex(i); double confidence = example.getConfidence(consideredPrediction); double value = example.getValue(specialAttributes[i]); value += confidence / numModels; example.setValue(specialAttributes[i], value); } } } private void evaluateSpecialAttributes(ExampleSet exampleSet, Attribute[] specialAttributes) { Attribute exSetLabel = exampleSet.getAttributes().getLabel(); Iterator<Example> reader = exampleSet.iterator(); while (reader.hasNext()) { Example example = reader.next(); int bestLabel = 0; double bestConf = -1; for (int n = 0; n < specialAttributes.length; n++) { double curConf = example.getValue(specialAttributes[n]); String curPredS = this.getLabel().getMapping().mapIndex(n); example.setConfidence(curPredS, curConf); if (curConf > bestConf) { bestConf = curConf; bestLabel = n; } } example.setValue(example.getAttributes().getPredictedLabel(), exSetLabel.getMapping().mapString(this.getLabel().getMapping().mapIndex(bestLabel))); } } public Component getVisualizationComponent(IOContainer container) { JTabbedPane tabPane = new ExtendedJTabbedPane(); for (int i = 0; i < this.getNumberOfModels(); i++) { Model model = this.getModel(i); tabPane.add("Model " + (i + 1), model.getVisualizationComponent(container)); } return tabPane; } /** @return a <code>String</code> representation of this boosting model. */ public String toString() { StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator() + "Number of inner models: " + this.getNumberOfModels() + Tools.getLineSeparators(2)); for (int i = 0; i < this.getNumberOfModels(); i++) { Model model = this.getModel(i); result.append((i > 0 ? Tools.getLineSeparator() : "") + "Embedded model #" + i + ":" + Tools.getLineSeparator() + model.toResultString()); } return result.toString(); } }