/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2007 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
 *  USA.
 */
package com.rapidminer.operator.learner;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.Tools;

/**
 * A DecisionStump clone that allows specifying different utility functions.
 * It is quick for nominal attributes, but does not yet apply pruning for
 * continuous attributes. Currently it can only handle boolean class labels.
 *
 * @author Martin Scholz
 * @version $Id: MultiCriterionDecisionStumps.java,v 1.3 2007/07/13 22:52:12 ingomierswa Exp $
 */
public class MultiCriterionDecisionStumps extends AbstractLearner {

    private static final String ACC = "accuracy"; // TP + TN = p + N - n ~ p - n

    private static final String ENTROPY = "entropy";

    private static final String SQRT_PN = "sqrt(TP*FP) + sqrt(FN*TN)"; // sqrt(pn) + sqrt((P-p)(N-n))

    private static final String GINI = "gini index"; // coverage-weighted (p/cov) * (n/cov) impurity of the covered and uncovered parts

    private static final String CHI_SQUARE = "chi square test";

    private static final String[] UTILITY_FUNCTION_LIST = new String[] { ENTROPY, ACC, SQRT_PN, GINI, CHI_SQUARE };

    private static final String PARAMETER_UTILITY_FUNCTION = "utility_function";
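    // Illustrative example (hypothetical attribute names, not taken from any particular data set):
    // for a numerical attribute "temperature" and a nominal attribute "outlook", a learned stump
    // prints along the lines of
    //     (label=yes) <-- temperature <= 70.5
    //      unknown: predict 'no'.
    // or, for a nominal test,
    //     (label=yes) <-- outlook = sunny
    // following the format produced by DecisionStumpModel.toString() below.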
    public static class DecisionStumpModel extends SimplePredictionModel {

        private static final long serialVersionUID = -261158567126510415L;

        private final Attribute testAttribute;

        private final double testValue;

        private final boolean prediction;

        private boolean includeNaNs;

        private final boolean numerical;

        // nominal attribute: test is "equals"
        // numerical attribute: test is "<="
        // if true, then the provided prediction is made
        public DecisionStumpModel(Attribute attribute, double testValue, ExampleSet exampleSet, boolean prediction, boolean includeNaNs) {
            super(exampleSet);
            this.prediction = prediction;
            this.includeNaNs = includeNaNs;
            this.testAttribute = attribute;
            this.testValue = testValue;
            if (testAttribute == null || !testAttribute.isNominal()) {
                this.numerical = true;
            } else {
                this.numerical = false;
            }
        }

        public double predict(Example example) {
            boolean evaluatesToTrue;
            if (this.testAttribute == null) {
                evaluatesToTrue = true;
            } else {
                double exampleValue = example.getValue(testAttribute);
                if (Double.isNaN(exampleValue)) {
                    evaluatesToTrue = includeNaNs;
                } else if (this.numerical) {
                    evaluatesToTrue = (exampleValue <= testValue);
                } else {
                    evaluatesToTrue = (exampleValue == testValue);
                }
            }
            if (evaluatesToTrue == prediction) {
                return this.getLabel().getMapping().getPositiveIndex();
            } else {
                return this.getLabel().getMapping().getNegativeIndex();
            }
        }

        /** @return a <code>String</code> representation of this rule model. */
        public String toString() {
            String posIndexS = getLabel().getMapping().getPositiveString();
            String negIndexS = getLabel().getMapping().getNegativeString();
            StringBuffer result = new StringBuffer(super.toString());
            result.append(Tools.getLineSeparator() + " (" + this.getLabel().getName() + "=");
            result.append((prediction ? posIndexS : negIndexS) + ") <-- ");
            result.append(((testAttribute != null) ?
                    (testAttribute.getName() + (numerical ?
                            (" <= " + testValue) :
                            (" = " + testAttribute.getMapping().mapIndex((int) testValue)))) : ""));
            result.append(Tools.getLineSeparator() + " unknown: predict '" + (includeNaNs ? posIndexS : negIndexS) + "'.");
            return result.toString();
        }
    }

    private int posIndex;

    private double globalP;

    private double globalN;

    private Model bestModel;

    private double bestScore;

    private String utilityFunction;

    public MultiCriterionDecisionStumps(OperatorDescription description) {
        super(description);
    }

    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == com.rapidminer.operator.learner.LearnerCapability.NUMERICAL_ATTRIBUTES)
            return true;
        if (lc == com.rapidminer.operator.learner.LearnerCapability.POLYNOMINAL_ATTRIBUTES)
            return true;
        if (lc == com.rapidminer.operator.learner.LearnerCapability.BINOMINAL_ATTRIBUTES)
            return true;
        if (lc == com.rapidminer.operator.learner.LearnerCapability.BINOMINAL_CLASS)
            return true;
        if (lc == com.rapidminer.operator.learner.LearnerCapability.WEIGHTED_EXAMPLES)
            return true;
        return false;
    }

    protected void initHighscore() {
        this.bestModel = null;
        this.bestScore = Double.NEGATIVE_INFINITY;
    }

    /** @return the best decision stump found */
    protected Model getBestModel() {
        return this.bestModel;
    }

    private void setBestModel(DecisionStumpModel model, double score) {
        this.bestModel = model;
        this.bestScore = score;
    }

    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.utilityFunction = UTILITY_FUNCTION_LIST[this.getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];
        this.initHighscore();
        this.posIndex = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        double[] globalCounts = this.computePriors(exampleSet);
        this.globalP = globalCounts[0];
        this.globalN = globalCounts[1];
        { // initialize with the better one of the two default models
            boolean defaultModelPrediction = (this.getScore(globalCounts, true) >= this.getScore(globalCounts, false));
            this.setBestModel(new DecisionStumpModel(null, 0, exampleSet, defaultModelPrediction, true),
                    this.getScore(globalCounts, defaultModelPrediction));
        }
        this.evaluateNominalAttributes(exampleSet);
        this.evaluateNumericalAttributes(exampleSet);
        return this.getBestModel();
    }
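    // Summary of the numerical search implemented below: the label and weight of every example are
    // cached, each numerical attribute's values are sorted, and the sorted list is scanned once
    // while accumulating weighted class counts. Whenever the attribute value changes, the midpoint
    // between the previous and the current value is a candidate threshold for a "<=" test; if the
    // score reached so far beats the current best model, the candidate becomes the new best stump.
    // NaN and +Infinity values are sorted to the end and terminate the scan; whether NaNs are
    // covered by the test is decided by the majority class (predictNaN).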
    private void evaluateNumericalAttributes(ExampleSet exampleSet) throws OperatorException {
        int numAttr = exampleSet.getAttributes().size();
        int[] mapAttribToIndex = new int[numAttr];
        Attribute[] mapIndexToAttrib = new Attribute[numAttr];
        int index = 0;
        {
            int i = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                if (!attribute.isNominal()) {
                    mapIndexToAttrib[index] = attribute;
                    mapAttribToIndex[i] = index++;
                } else {
                    mapAttribToIndex[i] = -1;
                }
                i++;
            }
        }
        if (index == 0)
            return;

        boolean hasWeight = (exampleSet.getAttributes().getWeight() != null);
        double[][] weightedLabel = new double[exampleSet.size()][2];
        double[][][] values = new double[index][exampleSet.size()][];
        Iterator<Example> reader = exampleSet.iterator();
        int exampleNum = 0;
        double[] weightedPriors = new double[2];
        while (reader.hasNext()) {
            Example example = reader.next();
            int label = (example.getLabel() == posIndex) ? 0 : 1;
            double weight = (hasWeight ? example.getWeight() : 1.0d);
            weightedPriors[label] += weight;
            weightedLabel[exampleNum] = new double[] { label, weight };
            for (int i = 0; i < index; i++) {
                double attribValue = example.getValue(mapIndexToAttrib[i]);
                values[i][exampleNum] = new double[] { attribValue, exampleNum };
            }
            exampleNum++;
        }

        final boolean predictNaN = (weightedPriors[0] >= weightedPriors[1]);
        Comparator<double[]> cmp = new Comparator<double[]>() {
            public int compare(double[] arg0, double[] arg1) {
                return Double.compare(arg0[0], arg1[0]);
            }
        };

        for (int i = 0; i < index; i++) {
            final Attribute currentAttribute = mapIndexToAttrib[i];
            final double[][] currentAttributeValues = values[i];
            Arrays.sort(currentAttributeValues, cmp);

            final double[] counts = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
            double lastValue = Double.NEGATIVE_INFINITY;
            double lastScore = Double.NEGATIVE_INFINITY;
            boolean betterPrediction = false;
            for (int j = 0; j < currentAttributeValues.length; j++) {
                final double curAttribValue = currentAttributeValues[j][0];
                if (Double.isNaN(curAttribValue) || curAttribValue == Double.POSITIVE_INFINITY) {
                    break;
                }
                final int curExampleNumber = (int) currentAttributeValues[j][1];
                final int curLabel = (int) weightedLabel[curExampleNumber][0];
                final double curWeight = weightedLabel[curExampleNumber][1];

                if (curAttribValue != lastValue && lastScore > this.bestScore) {
                    double testValue = (curAttribValue + lastValue) / 2.0d;
                    boolean includeNaNs = (predictNaN == betterPrediction);
                    DecisionStumpModel dsm =
                        new DecisionStumpModel(currentAttribute, testValue, exampleSet, betterPrediction, includeNaNs);
                    this.setBestModel(dsm, lastScore);
                }

                counts[curLabel] += curWeight;
                double scorePos = this.getScore(counts, true);
                double scoreNeg = this.getScore(counts, false);
                lastScore = Math.max(scorePos, scoreNeg);
                betterPrediction = (scorePos >= scoreNeg);
                lastValue = curAttribValue;
            }
        }
    }
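    // Summary of the nominal search implemented below: a single pass over the data fills a matrix
    // of weighted class counts per attribute value (plus a separate NaN counter per attribute).
    // Every single nominal value then defines a candidate "=" test, and getScore(double[], double[])
    // evaluates it for all four combinations of predicted class and NaN handling, keeping the
    // best-scoring stump found so far.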
    private void evaluateNominalAttributes(ExampleSet exampleSet) throws OperatorException {
        int numAttr = exampleSet.getAttributes().size();
        int[] mapAttribToIndex = new int[numAttr];
        Attribute[] mapIndexToAttrib = new Attribute[numAttr];
        int index = 0;
        {
            int i = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                if (attribute.isNominal()) {
                    mapIndexToAttrib[index] = attribute;
                    mapAttribToIndex[i] = index++;
                } else {
                    mapAttribToIndex[i] = -1;
                }
                i++;
            }
        }
        if (index == 0)
            return;

        double[][][] counter = new double[index][][];
        double[][] countNaNs = new double[index][exampleSet.getAttributes().getLabel().getMapping().size()];
        for (int i = 0; i < index; i++) {
            int numValues = mapIndexToAttrib[i].getMapping().size();
            counter[i] = new double[numValues][exampleSet.getAttributes().getLabel().getMapping().size()];
        }

        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example example = reader.next();
            double weight = (weightAttr == null) ? 1.0d : example.getWeight();
            int label = (example.getLabel() == posIndex) ? 0 : 1;
            for (int i = 0; i < index; i++) {
                double attributeValue = example.getValue(mapIndexToAttrib[i]);
                if (Double.isNaN(attributeValue)) {
                    countNaNs[i][label] += weight;
                } else {
                    counter[i][(int) attributeValue][label] += weight;
                }
            }
        }

        for (int i = 0; i < index; i++) {
            double[][] attributeMatrix = counter[i];
            for (int j = 0; j < attributeMatrix.length; j++) {
                ScoreNaNInfo snp = this.getScore(attributeMatrix[j], countNaNs[i]);
                if (snp.score > this.bestScore) {
                    Attribute attribute = mapIndexToAttrib[i];
                    double testValue = j;
                    this.setBestModel(new DecisionStumpModel(attribute, testValue, exampleSet, snp.predicted, snp.includeNaNs), snp.score);
                }
            }
        }
    }

    // Helper class.
    private static class ScoreNaNInfo {

        public double score;

        public boolean includeNaNs;

        public boolean predicted;

        ScoreNaNInfo(double score, boolean includeNaNs, boolean predicted) {
            this.score = score;
            this.includeNaNs = includeNaNs;
            this.predicted = predicted;
        }

        public ScoreNaNInfo max(ScoreNaNInfo other) {
            if (this.score >= other.score)
                return this;
            else
                return other;
        }
    }

    // Evaluate all four combinations: with and without including NaNs, predicting pos or neg class
    private ScoreNaNInfo getScore(double[] counts, double[] countNaNs) throws UndefinedParameterError {
        ScoreNaNInfo snp, snp2;

        // exclude NaNs, predict true
        double score = this.getScore(counts, true);
        snp = new ScoreNaNInfo(score, false, true);

        // exclude NaNs, predict false
        score = this.getScore(counts, false);
        snp2 = new ScoreNaNInfo(score, false, false);
        snp = snp.max(snp2);

        if (countNaNs[0] > 0 || countNaNs[1] > 0) {
            counts[0] += countNaNs[0];
            counts[1] += countNaNs[1];

            // include NaNs, predict true
            score = this.getScore(counts, true);
            snp2 = new ScoreNaNInfo(score, true, true);
            snp = snp.max(snp2);

            // include NaNs, predict false
            score = this.getScore(counts, false);
            snp2 = new ScoreNaNInfo(score, true, false);
            snp = snp.max(snp2);
        }
        return snp;
    }
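    // Notation used by the utility functions below: p and n are the weighted positive/negative
    // counts covered by the rule, globalP and globalN the corresponding totals, so cov = p + n and
    // uncov = (globalP - p) + (globalN - n). A sketch of the formulas as implemented (all stated
    // as maximization problems):
    //   accuracy:   p - n (or n - p when predicting the negative class)
    //   entropy:    -( cov * H(covered) + uncov * H(uncovered) ) / (cov + uncov), H in bits
    //   sqrt(pn):   -( sqrt(p * n) + sqrt((globalP - p) * (globalN - n)) )
    //   gini:       -( cov * (p/cov)*(n/cov) + uncov * ((globalP-p)/uncov)*((globalN-n)/uncov) ) / (cov + uncov)
    //   chi square: sum over the 2x2 coverage/label table of (observed - expected)^2 / expected,
    //               with expected counts such as c11 = cov * globalP / total.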
    /**
     * Computes the score for the specified utility function, the provided counts, and class.
     */
    protected double getScore(double[] counts, boolean predictPositives) {
        double p = counts[0];
        double n = counts[1];
        double score;
        if (this.utilityFunction.equals(ACC)) {
            score = (predictPositives ? (p - n) : (n - p));
        } else if (this.utilityFunction.equals(ENTROPY)) {
            if ((p - n >= 0) ^ predictPositives) // symmetric, label has no effect
                return Double.NEGATIVE_INFINITY;
            double cov = p + n;
            double uncov = globalP + globalN - cov;
            double scoreCovered = (cov == 0 ? 0 : entropyLog2(p / cov) + entropyLog2(n / cov));
            double scoreUncovered = (uncov == 0) ? 0 : entropyLog2((globalP - p) / uncov) + entropyLog2((globalN - n) / uncov);
            score = (cov * scoreCovered + uncov * scoreUncovered) / (cov + uncov);
            score = -score; // maximization problem assumed
        } else if (this.utilityFunction.equals(SQRT_PN)) {
            if ((p - n >= 0) ^ predictPositives) // symmetric, label has no effect
                return Double.NEGATIVE_INFINITY;
            score = Math.sqrt(p * n) + Math.sqrt((globalP - p) * (globalN - n));
            score = -score; // maximization problem assumed
        } else if (this.utilityFunction.equals(GINI)) {
            if ((p - n >= 0) ^ predictPositives) // symmetric, label has no effect
                return Double.NEGATIVE_INFINITY;
            double cov = p + n;
            double uncov = globalP + globalN - cov;
            double scoreCovered = (cov == 0 ? 0 : (p / cov) * (n / cov));
            double scoreUncovered = (uncov == 0) ? 0 : ((globalP - p) / uncov) * ((globalN - n) / uncov);
            score = (cov * scoreCovered + uncov * scoreUncovered) / (cov + uncov);
            score = -score; // maximization problem assumed
        } else if (this.utilityFunction.equals(CHI_SQUARE)) {
            double q = globalP - p;
            double r = globalN - n;
            double cov = p + n;
            double uncov = q + r;
            double total = cov + uncov;
            double c11, c12, c21, c22;
            c11 = cov * globalP / total;
            c12 = cov * globalN / total;
            c21 = uncov * globalP / total;
            c22 = uncov * globalN / total;
            if (cov > 0 && uncov > 0) {
                score = (p - c11) * (p - c11) / c11
                    + (n - c12) * (n - c12) / c12
                    + (q - c21) * (q - c21) / c21
                    + (r - c22) * (r - c22) / c22;
            } else {
                score = 0;
            }
        } else {
            score = Double.NaN;
            logWarning("Found unknown utility function: " + this.utilityFunction);
        }
        return score;
    }

    // log base 2 is more intuitive than log_e, although it should make no difference
    private double entropyLog2(double p) {
        if (Double.isNaN(p) || p == 0) { // NaN may e.g. occur when coverage is 0
            return 0;
        } else {
            return (-p * Math.log(p) / Math.log(2.0d));
        }
    }

    /**
     * @param exampleSet the exampleSet to get the weighted priors for
     * @return a double[2] object; the first element is p, the second is n.
     */
    protected double[] computePriors(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        double p = 0;
        double n = 0;
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example example = reader.next();
            double weight = (weightAttr == null) ? 1 : example.getValue(weightAttr);
            if (example.getLabel() == posIndex) {
                p += weight;
            } else {
                n += weight;
            }
        }
        return (new double[] { p, n });
    }

    /**
     * Adds the parameter "utility function".
     */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION, "The function to be optimized by the rule.", UTILITY_FUNCTION_LIST, 0));
        return types;
    }
}