/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.tools.math.MathFunctions;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
/**
 * This class provides a pruner based on some heuristic statistics. It cuts the tree to reduce
 * overfitting.
 * <p>
 * Pruning works bottom-up. Whenever all children of an inner node are leaves, the node is
 * collapsed into a leaf if either all of those leaves carry the same label, or the pessimistic
 * error estimate of the node (minus a small preference offset) does not exceed the combined
 * pessimistic error estimate of its leaves. The root node passed to {@link #prune(Tree)} is never
 * turned into a leaf itself; only the subtrees below it are pruned.
 *
 * @author Sebastian Land, Ingo Mierswa
 */
public class PessimisticPruner implements Pruner {

	/**
	 * Offset subtracted from the node's error estimate before it is compared with the leaves'
	 * estimate, so that near-ties are resolved in favour of pruning.
	 */
	private static final double PRUNE_PREFERENCE = 0.001;

	/** Confidence level handed to {@link #pessimisticErrors(double, double, double)}. */
	private final double confidenceLevel;

	/** Turns an inner node whose children have been removed into a leaf. */
	private final LeafCreator leafCreator;

	/**
	 * Creates a new pruner.
	 *
	 * @param confidenceLevel
	 *            confidence level used for all pessimistic error estimates
	 * @param leafCreator
	 *            strategy used to convert a pruned inner node into a leaf
	 */
	public PessimisticPruner(double confidenceLevel, LeafCreator leafCreator) {
		this.confidenceLevel = confidenceLevel;
		this.leafCreator = leafCreator;
	}

	/**
	 * Prunes, in place, every subtree hanging below {@code root}. The root itself is not collapsed.
	 *
	 * @param root
	 *            the tree to prune
	 */
	@Override
	public void prune(Tree root) {
		Iterator<Edge> childIterator = root.childIterator();
		while (childIterator.hasNext()) {
			pruneChild(childIterator.next().getChild());
		}
	}

	/**
	 * Recursively prunes {@code currentNode}. Its children are pruned first; afterwards, if every
	 * remaining child is a leaf, the node itself is considered for collapsing into a leaf.
	 */
	private void pruneChild(Tree currentNode) {
		if (currentNode.isLeaf()) {
			return;
		}

		// going down first, so that decisions are made from the fathers of leaves upwards
		Iterator<Edge> childIterator = currentNode.childIterator();
		while (childIterator.hasNext()) {
			pruneChild(childIterator.next().getChild());
		}
		if (childrenHaveChildren(currentNode)) {
			return;
		}

		ExampleSet currentNodeExampleSet = currentNode.getTrainingSet();
		int nodeExamples = currentNodeExampleSet.size();

		// calculating the error estimate of the leaves, each weighted by its share of the node's
		// examples, and collecting the labels the leaves predict
		double leafsErrorEstimate = 0;
		Set<String> classSet = new HashSet<String>();
		childIterator = currentNode.childIterator();
		while (childIterator.hasNext()) {
			Tree leafNode = childIterator.next().getChild();
			classSet.add(leafNode.getLabel());

			ExampleSet leafExampleSet = leafNode.getTrainingSet();
			int examples = leafExampleSet.size();
			if (examples == 0) {
				// An empty leaf contributes no errors. Without this guard the error rate would be
				// 0/0 = NaN, which turns the whole sum into NaN and silently blocks pruning below.
				continue;
			}
			double leafLabelIndex = leafExampleSet.getAttributes().getLabel().getMapping()
					.getIndex(leafNode.getLabel());
			double currentErrorRate = getErrorNumber(leafExampleSet, leafLabelIndex) / (double) examples;
			leafsErrorEstimate += pessimisticErrors(examples, currentErrorRate, confidenceLevel)
					* (((double) examples) / nodeExamples);
		}

		if (classSet.size() <= 1) {
			// all leaves carry the same label, so the split is not needed
			collapseToLeaf(currentNode, currentNodeExampleSet);
			return;
		}

		// calculating the error estimate for the current node as if it were a leaf predicting
		// the mode of its label
		double currentNodeLabel = prunedLabel(currentNodeExampleSet);
		double currentErrorRate = getErrorNumber(currentNodeExampleSet, currentNodeLabel) / (double) nodeExamples;
		double nodeErrorEstimate = pessimisticErrors(nodeExamples, currentErrorRate, confidenceLevel);

		// NOTE(review): the leaves' estimate above is weighted by (leaf size / node size), while
		// the node's estimate is not weighted; confirm this asymmetry is intended.
		// prune if the node's estimate is not worse than the leaves' estimate (with preference)
		if (nodeErrorEstimate - PRUNE_PREFERENCE <= leafsErrorEstimate) {
			collapseToLeaf(currentNode, currentNodeExampleSet);
		}
	}

	/** Removes all children of {@code node} and lets the leaf creator turn it into a leaf. */
	private void collapseToLeaf(Tree node, ExampleSet trainingSet) {
		node.removeChildren();
		leafCreator.changeTreeToLeaf(node, trainingSet);
	}

	/**
	 * Returns {@code true} if at least one child of {@code node} is not a leaf.
	 */
	private boolean childrenHaveChildren(Tree node) {
		Iterator<Edge> iterator = node.childIterator();
		while (iterator.hasNext()) {
			if (!iterator.next().getChild().isLeaf()) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Counts the examples whose label value differs from {@code label}.
	 *
	 * @param exampleSet
	 *            the examples to check
	 * @param label
	 *            the label value considered correct (presumably a nominal mapping index)
	 * @return the number of examples whose label value is not equal to {@code label}
	 */
	private int getErrorNumber(ExampleSet exampleSet, double label) {
		int errors = 0;
		for (Example example : exampleSet) {
			if (example.getLabel() != label) {
				errors++;
			}
		}
		return errors;
	}

	/**
	 * Recalculates the label statistics of {@code exampleSet} (updating the statistics stored on
	 * the set) and returns the label's {@link Statistics#MODE}.
	 *
	 * @param exampleSet
	 *            the examples of the node
	 * @return the mode of the label attribute
	 */
	public double prunedLabel(ExampleSet exampleSet) {
		Attribute labelAttribute = exampleSet.getAttributes().getLabel();
		exampleSet.recalculateAttributeStatistics(labelAttribute);
		return exampleSet.getStatistics(labelAttribute, Statistics.MODE);
	}

	/**
	 * Calculates the pessimistic number of errors, using some confidence level.
	 * <p>
	 * Three regimes are distinguished:
	 * <ul>
	 * <li>{@code errorRate < 1E-6}: {@code errorRate + N * (1 - confidenceLevel^(1/N))}</li>
	 * <li>{@code errorRate + 0.5 >= N}: {@code errorRate + 0.67 * (N - errorRate)}</li>
	 * <li>otherwise: {@code N} times an upper bound of the form
	 * {@code (x + z^2/2 + sqrt(z^2 * (x * (1 - x/N) + z^2/4))) / (N + z^2)} with
	 * {@code x = errorRate + 0.5} and {@code z = normalInverse(1 - confidenceLevel)}</li>
	 * </ul>
	 * NOTE(review): the {@code + 0.5} correction and the comparison against
	 * {@code numberOfExamples} treat the second argument as an error <em>count</em>, but the
	 * callers in this class pass an error <em>rate</em>. This looks like a C4.5-style estimate;
	 * confirm the intended unit before changing it, because doing so alters pruning results.
	 *
	 * @param numberOfExamples
	 *            number of examples {@code N}
	 * @param errorRate
	 *            observed errors (see the note above regarding rate vs. count)
	 * @param confidenceLevel
	 *            confidence level of the estimate
	 * @return the pessimistic error estimate
	 */
	public double pessimisticErrors(double numberOfExamples, double errorRate, double confidenceLevel) {
		if (errorRate < 1E-6) {
			// exp(log(cf) / N) == cf^(1/N)
			return errorRate + numberOfExamples * (1.0 - Math.exp(Math.log(confidenceLevel) / numberOfExamples));
		} else if (errorRate + 0.5 >= numberOfExamples) {
			return errorRate + 0.67 * (numberOfExamples - errorRate);
		} else {
			double coefficient = MathFunctions.normalInverse(1 - confidenceLevel);
			coefficient *= coefficient;
			double pessimisticRate = (errorRate + 0.5 + coefficient / 2.0d + Math.sqrt(coefficient
					* ((errorRate + 0.5) * (1 - (errorRate + 0.5) / numberOfExamples) + coefficient / 4.0d)))
					/ (numberOfExamples + coefficient);
			return (numberOfExamples * pessimisticRate);
		}
	}
}