/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.tree; import java.util.HashSet; import java.util.Iterator; import java.util.Set; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.Statistics; import com.rapidminer.tools.math.MathFunctions; /** * This class provides a pruner based on some heuristic statistics. It cuts the tree * to reduce overfitting. * * @author Sebastian Land, Ingo Mierswa * @version $Id: PessimisticPruner.java,v 1.5 2008/05/09 19:22:53 ingomierswa Exp $ */ public class PessimisticPruner implements Pruner { private static final double PRUNE_PREFERENCE = 0.001; private double confidenceLevel; private LeafCreator leafCreator; public PessimisticPruner(double confidenceLevel, LeafCreator leafCreator) { this.confidenceLevel = confidenceLevel; this.leafCreator = leafCreator; } public void prune(Tree root) { Iterator<Edge> childIterator = root.childIterator(); while (childIterator.hasNext()) { pruneChild(childIterator.next().getChild(), root); } } private void pruneChild(Tree currentNode, Tree father) { // going down to fathers of leafs if (!currentNode.isLeaf()) { Iterator<Edge> childIterator = currentNode.childIterator(); while (childIterator.hasNext()) { pruneChild(childIterator.next().getChild(), currentNode); } if (!childrenHaveChildren(currentNode)) { // calculating error estimate for leafs double leafsErrorEstimate = 0; childIterator = currentNode.childIterator(); Set<String> classSet = new HashSet<String>(); while (childIterator.hasNext()) { Tree leafNode = childIterator.next().getChild(); ExampleSet leafExampleSet = leafNode.getTrainingSet(); classSet.add(leafNode.getLabel()); int examples = leafExampleSet.size(); double currentErrorRate = getErrorNumber(leafExampleSet, leafExampleSet.getAttributes().getLabel().getMapping().getIndex(leafNode.getLabel())) / (double) leafExampleSet.size();; leafsErrorEstimate += pessimisticErrors(examples, currentErrorRate, confidenceLevel) * (((double) examples) / currentNode.getTrainingSet().size()); } // calculating error estimate for current node ExampleSet currentNodeExampleSet = currentNode.getTrainingSet(); if (classSet.size() <= 1) { currentNode.removeChildren(); leafCreator.changeTreeToLeaf(currentNode, currentNodeExampleSet); } else { double currentNodeLabel = prunedLabel(currentNodeExampleSet); int examples = currentNodeExampleSet.size(); double currentErrorRate = getErrorNumber(currentNodeExampleSet, currentNodeLabel) / (double) currentNodeExampleSet.size(); double nodeErrorEstimate = pessimisticErrors(examples, currentErrorRate, confidenceLevel); // if currentNode error level is less than children: prune if (nodeErrorEstimate - PRUNE_PREFERENCE <= leafsErrorEstimate) { currentNode.removeChildren(); leafCreator.changeTreeToLeaf(currentNode, currentNodeExampleSet); } } } } } private boolean childrenHaveChildren(Tree node) { Iterator<Edge> iterator = node.childIterator(); while (iterator.hasNext()) { if (!iterator.next().getChild().isLeaf()) return true; } return false; } private int getErrorNumber(ExampleSet exampleSet, double label) { int errors = 0; Iterator<Example> iterator = exampleSet.iterator(); while (iterator.hasNext()) { if (iterator.next().getLabel() != label) { errors++; } } return errors; } public double prunedLabel(ExampleSet exampleSet) { Attribute labelAttribute = exampleSet.getAttributes().getLabel(); exampleSet.recalculateAttributeStatistics(labelAttribute); double test = exampleSet.getStatistics(labelAttribute, Statistics.MODE); return test; } // calculates the pessimistic number of errors, using some confidence level. public double pessimisticErrors(double numberOfExamples, double errorRate, double confidenceLevel) { if (errorRate < 1E-6) { return errorRate + numberOfExamples * (1.0 - Math.exp(Math.log(confidenceLevel) / numberOfExamples)); } else if (errorRate + 0.5 >= numberOfExamples) { return errorRate + 0.67 * (numberOfExamples - errorRate); } else { double coefficient = MathFunctions.normalInverse(1 - confidenceLevel); coefficient *= coefficient; double pessimisticRate = (errorRate + 0.5 + coefficient / 2.0d + Math.sqrt(coefficient * ((errorRate + 0.5) * (1 - (errorRate + 0.5) / numberOfExamples) + coefficient / 4.0d))) / (numberOfExamples + coefficient); return (numberOfExamples * pessimisticRate); } } }