/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.awt.Component; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.PrintStream; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import javax.swing.JTabbedPane; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.gui.tools.ExtendedJTabbedPane; import com.rapidminer.operator.IOContainer; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.LogService; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.Tools; /** * A subgroup discovery model. * * @author Martin Scholz * @version $Id: SDEnsemble.java,v 1.11 2008/05/09 19:22:47 ingomierswa Exp $ */ public class SDEnsemble extends PredictionModel { private static final long serialVersionUID = 1320495411014477089L; public static final short RULE_COMBINE_ADDITIVE = 1; public static final short RULE_COMBINE_MULTIPLY = 2; // Holds the models and their weights in array format. // Please access with getter methods. private List modelInfo; // If set to a value i >= 0 then only the first i models are applied private int maxModelNumber = -1; private static final String MAX_MODEL_NUMBER = "iteration"; // name of a parameter that allows to specify a file to print predictions to private static final String PRED_TO_FILE = "predictions_to_file"; // the file to print to, null if turned off private File predictionsFile = null; // if set to true then some statistics are printed to stdout private boolean print_to_stdout = false; // The classes priors in the training set, starting with index 0. private double[] priors; /** * @param exampleSet * the example set used for training * @param modelInfo * a <code>List</code> of <code>Object[2]</code> arrays, each * entry holding a model and a <code>double[][]</code> array * containing weights for all prediction/label combinations. * @param priors * an array of the prior probabilities of labels */ public SDEnsemble(ExampleSet exampleSet, List modelInfo, double[] priors, short combinationMethod) { super(exampleSet); this.modelInfo = modelInfo; this.priors = priors; } public Component getVisualizationComponent(IOContainer container) { JTabbedPane tabPane = new ExtendedJTabbedPane(); for (int i = 0; i < this.getNumberOfModels(); i++) { Model model = this.getModel(i); tabPane.add("Model " + (i + 1), model.getVisualizationComponent(container)); } return tabPane; } /** @return a <code>String</code> representation of the ruleset. */ public String toString() { StringBuffer result = new StringBuffer(super.toString() + (Tools.getLineSeparator() + "Number of inner models: " + this.getNumberOfModels())); for (int i = 0; i < this.getNumberOfModels(); i++) { Model model = this.getModel(i); result.append((i > 0 ? Tools.getLineSeparator() : "") // + "Weights: " + this.getFactorForModel(i, true) + "," // + this.getFactorForModel(i, false) + " - " + "(Embedded model #" + i + "):" + model.toResultString()); } return result.toString(); } /** * Setting the parameter <code>MAX_MODEL_NUMBER</code> allows to discard * all but the first n models for specified n. <code>PRED_TO_FILE</code> * requires a filename on the local disk system the predictions of the * single classifiers are written to. <code>print_to_stdout</code> prints * some statistics about the base classifiers to the standard output. */ public void setParameter(String name, String value) throws OperatorException { if (name.equalsIgnoreCase("print_to_stdout")) { this.print_to_stdout = true; return; } else if (name.equalsIgnoreCase(PRED_TO_FILE)) { if (value != null) { String filename = value; File file = new File(filename); if (file.exists()) { boolean result = file.delete(); if (!result) LogService.getGlobal().logError("Cannot delete file: " + file); } try { file.createNewFile(); } catch (IOException e) { throw new UserError(null, 303, filename, e.getMessage()); } this.predictionsFile = file; return; } } else try { if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) { this.maxModelNumber = Integer.parseInt(value); return; } } catch (NumberFormatException e) {} super.setParameter(name, value); } /** @return the number of embedded models */ public int getNumberOfModels() { if (this.maxModelNumber >= 0) return Math.min(this.maxModelNumber, modelInfo.size()); else return modelInfo.size(); } /** * Gets weights for models in the case of general nominal class labels. The * indices are not in RapidMiner format, so add * <code>Attribute.FIRST_CLASS_INDEX</code> before calling this method and * before reading from the returned array. * * @return a <code>double[]</code> object with the weights to be applied * for each class if the corresponding rule yields * <code>predicted</code>. * @param modelNr * the number of the model * @param predicted * the predicted label * @return a <code>double[]</code> with one weight per class label. */ private double[] getWeightsForModel(int modelNr, int predicted) { Object[] obj = (Object[]) this.modelInfo.get(modelNr); double[][] weight = (double[][]) obj[1]; return weight[predicted]; } /** * Getter method for prior class probabilities estimated as the relative * frequencies in the training set. * * @param classIndex * the index of a class starting with 0 (not the internal representation!) * @return the prior probability of the specified class */ private double getPriorOfClass(int classIndex) { return this.priors[classIndex]; } /** * Getter method for embedded models * * @param index * the number of a model part of this boost model * @return binary or nominal decision model for the given classification * index. */ public Model getModel(int index) { Object[] obj = (Object[]) this.modelInfo.get(index); return (Model)obj[0]; } /** * Iterates over all models and returns the class with maximum likelihood. * * @param exampleSet * the set of examples to be classified */ public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException { // If parameter is set than the single predictions are written to file: PrintStream predOut = null; if (this.predictionsFile != null) { try { // Create an output stream to write the predictions to: predOut = new PrintStream(new BufferedOutputStream(new FileOutputStream(this.predictionsFile))); } catch (IOException e) { throw new UserError(null, 303, this.predictionsFile.getName(), e.getMessage()); } finally { if (predOut != null) { predOut.close(); } } } // Prepare an ExampleSet for each model. ExampleSet[] eSet = new ExampleSet[this.getNumberOfModels()]; for (int i = 0; i < this.getNumberOfModels(); i++) { Model model = this.getModel(i); eSet[i] = (ExampleSet) exampleSet.clone(); eSet[i] = model.apply(eSet[i]); } // Prepare one ExampleReader per ExampleSet List<Iterator<Example>> reader = new ArrayList<Iterator<Example>>(eSet.length); for (int r = 0; r < eSet.length; r++) { reader.add(eSet[r].iterator()); } // Apply all models: Iterator<Example> originalReader = exampleSet.iterator(); final int posIndex = SDRulesetInduction.getPosIndex(exampleSet.getAttributes().getLabel()); // <statistics per rule> int[] numCovered = new int[this.getNumberOfModels()]; int[] posCovered = new int[this.getNumberOfModels()]; int posTotal = 0; // </statistics per rule> while (originalReader.hasNext()) { Example example = originalReader.next(); double sumPos = 0; double sumTotal = 0; for (int k = 0; k < reader.size(); k++) { Example e = reader.get(k).next(); if (predOut != null) { predOut.print(e.getPredictedLabel() + " "); } double[] modelWeights; int predicted = ((int) e.getPredictedLabel()); modelWeights = this.getWeightsForModel(k, predicted); for (int i = 0; i < modelWeights.length; i++) { sumTotal += modelWeights[i]; } sumPos += modelWeights[posIndex]; if (this.print_to_stdout) { // statistics per rule int label = ((int) e.getLabel()); if (k == 0 && label == posIndex) { posTotal++; } // If "posIndex" is the wrong subset this will be corrected // later on. if (predicted == posIndex) { numCovered[k]++; if (label == predicted) posCovered[k]++; } } } // end of loop evaluating all models for a single example if (predOut != null) { predOut.println(example.getLabel()); // end line for the // predictions of this // example } if (sumTotal > 0) { sumPos /= sumTotal; } else { sumPos = this.getPriorOfClass(posIndex); } example.setPredictedLabel(sumPos); } // Closes the file storing the single predictions: if (predOut != null) { predOut.close(); } return exampleSet; } /** * Creates a predicted label with the given name. If name is null, the name * "prediction(labelname)" is used. */ protected Attribute createPredictedLabel(ExampleSet exampleSet) { Attribute predictedLabel = super.createPredictedLabel(exampleSet, getLabel()); return exampleSet.getAttributes().replace(predictedLabel, AttributeFactory.changeValueType(predictedLabel, Ontology.REAL)); } }