/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import com.rapidminer.tools.math.MathFunctions;
/**
 * This class provides a pruner based on some heuristic statistics. It cuts the tree to reduce
 * overfitting. The pruning only uses the information of the tree structure (the label counters
 * kept in the nodes) and not the example sets that can be saved in tree nodes.
 *
 * <p>
 * The tree is processed bottom-up: every inner node whose children are all leaves is compared
 * against its children by means of {@link #pessimisticErrors(double, double, double)} and is
 * collapsed into a single leaf if that does not make the estimate worse.
 * </p>
 *
 * @author Sebastian Land, Ingo Mierswa, Gisa Schaefer
 */
public class TreebasedPessimisticPruner implements Pruner {

    /**
     * Slack subtracted from the estimate of the collapsed node before it is compared with the
     * estimate of its leaves, so that pruning wins ties and near-ties.
     */
    private static final double PRUNE_PREFERENCE = 0.001;

    /** Confidence level handed to {@link #pessimisticErrors(double, double, double)}. */
    private final double confidenceLevel;

    /**
     * Creates a new pruner.
     *
     * @param confidenceLevel
     *            the confidence level used for the pessimistic error estimates
     * @param leafCreator
     *            not used by this implementation; kept so that the constructor signature stays
     *            unchanged for existing callers
     */
    public TreebasedPessimisticPruner(double confidenceLevel, LeafCreator leafCreator) {
        this.confidenceLevel = confidenceLevel;
    }

    /**
     * Prunes all subtrees below the given root. The root node itself is never turned into a leaf
     * by this method; only its children are handed to the recursive pruning.
     *
     * @param root
     *            the root of the tree to prune
     */
    @Override
    public void prune(Tree root) {
        Iterator<Edge> childIterator = root.childIterator();
        while (childIterator.hasNext()) {
            pruneChild(childIterator.next().getChild());
        }
    }

    /**
     * Prunes the tree given by currentNode recursively.
     *
     * @param currentNode
     *            the root of the subtree to prune; leaves are left untouched
     */
    private void pruneChild(Tree currentNode) {
        if (currentNode.isLeaf()) {
            return;
        }

        // prune bottom-up: children first, so that freshly created leaves are visible below
        Iterator<Edge> childIterator = currentNode.childIterator();
        while (childIterator.hasNext()) {
            pruneChild(childIterator.next().getChild());
        }

        // only fathers of leaves are candidates for being collapsed
        if (childrenHaveChildren(currentNode)) {
            return;
        }

        int examplesCurrentNode = currentNode.getSubtreeFrequencySum();
        double leafsErrorEstimate = 0;
        Set<String> classSet = new HashSet<String>();

        // weighted sum of the pessimistic errors of the child leaves
        childIterator = currentNode.childIterator();
        while (childIterator.hasNext()) {
            Tree leafNode = childIterator.next().getChild();
            classSet.add(leafNode.getLabel());
            int examples = leafNode.getFrequencySum();
            if (examples <= 0) {
                // A leaf without examples has weight 0 in the sum. Evaluating it anyway would
                // compute 0 / 0.0 = NaN, which would propagate into the sum and make the
                // comparison below always false, i.e. silently disable pruning of this node.
                continue;
            }
            double currentErrorRate = getErrorNumber(leafNode, leafNode.getLabel()) / (double) examples;
            leafsErrorEstimate += pessimisticErrors(examples, currentErrorRate, confidenceLevel)
                    * ((double) examples / (double) examplesCurrentNode);
        }

        if (classSet.size() <= 1) {
            // all leaves predict the same label: the split is redundant
            changeToLeaf(currentNode);
            return;
        }

        // error estimate of the node if it were replaced by a single leaf
        String currentNodeLabel = prunedLabel(currentNode);
        double currentErrorRate = getErrorNumber(currentNode, currentNodeLabel) / (double) examplesCurrentNode;
        double nodeErrorEstimate = pessimisticErrors(examplesCurrentNode, currentErrorRate, confidenceLevel);

        // prune if the collapsed node is (almost) not worse than its leaves
        if (nodeErrorEstimate - PRUNE_PREFERENCE <= leafsErrorEstimate) {
            changeToLeaf(currentNode);
        }
    }

    /**
     * Checks if the children of the node have child nodes, i.e. are not leaves.
     *
     * @param node
     *            the node whose children are inspected
     * @return {@code true} if at least one child of the node is not a leaf, {@code false} if all
     *         children are leaves or the node has no children
     */
    private boolean childrenHaveChildren(Tree node) {
        Iterator<Edge> iterator = node.childIterator();
        while (iterator.hasNext()) {
            if (!iterator.next().getChild().isLeaf()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes the children of the node and adds the information a leaf must contain: the label
     * counts of the former subtree and the majority label.
     *
     * @param node
     *            the inner node to turn into a leaf
     */
    private void changeToLeaf(Tree node) {
        Map<String, Integer> counterMap = node.getSubtreeCounterMap();
        String label = majorityLabel(counterMap);
        for (Map.Entry<String, Integer> entry : counterMap.entrySet()) {
            // needed since the counterMap of node does not get changed by calling
            // getSubtreeCounterMap
            node.addCount(entry.getKey(), entry.getValue());
        }
        node.removeChildren();
        node.setLeaf(label);
    }

    /**
     * Counts how many examples represented in the node have a label different from label.
     *
     * @param node
     *            the node to inspect; for a leaf its own counters are used, otherwise the counters
     *            of the whole subtree
     * @param label
     *            the label regarded as correct
     * @return the summed count of all labels other than the given one
     */
    private int getErrorNumber(Tree node, String label) {
        Map<String, Integer> counterMap = node.isLeaf() ? node.getCounterMap() : node.getSubtreeCounterMap();
        int errors = 0;
        for (Map.Entry<String, Integer> entry : counterMap.entrySet()) {
            if (!label.equals(entry.getKey())) {
                errors += entry.getValue();
            }
        }
        return errors;
    }

    /**
     * Calculates the label a node would have if it became a leaf.
     *
     * @param node
     *            a node that is not a leaf
     * @return the majority label, or the empty string if no label has a positive count
     */
    public String prunedLabel(Tree node) {
        return majorityLabel(node.getSubtreeCounterMap());
    }

    /**
     * Determines the label with the highest count. Shared by {@link #changeToLeaf(Tree)} and
     * {@link #prunedLabel(Tree)} so that both always agree on the result.
     *
     * @param counterMap
     *            label counts
     * @return the label with the strictly highest count in iteration order of the map (the first
     *         one encountered wins ties), or the empty string if no count is greater than zero
     */
    private static String majorityLabel(Map<String, Integer> counterMap) {
        int maximum = 0;
        String label = "";
        for (Map.Entry<String, Integer> entry : counterMap.entrySet()) {
            int number = entry.getValue();
            if (number > maximum) {
                maximum = number;
                label = entry.getKey();
            }
        }
        return label;
    }

    /**
     * Calculates the pessimistic number of errors, using some confidence level.
     *
     * <p>
     * Three cases are distinguished: a (practically) zero error value, an error value that
     * reaches the number of examples (within 0.5), and the general case, which uses the squared
     * result of {@code MathFunctions.normalInverse(1 - confidenceLevel)} as coefficient.
     * </p>
     *
     * <p>
     * NOTE(review): within this class the second argument is always an error <em>rate</em>
     * (errors divided by examples), while the formulas compare and combine it with
     * {@code numberOfExamples} as if it were a count. The arithmetic is deliberately left
     * untouched because changing it would alter pruning results; confirm the intended semantics
     * before modifying it.
     * </p>
     *
     * @param numberOfExamples
     *            the number of examples covered by the node
     * @param errorRate
     *            the observed error value of the node (see note above)
     * @param confidenceLevel
     *            the confidence level of the estimate
     * @return the pessimistic error estimate
     */
    public double pessimisticErrors(double numberOfExamples, double errorRate, double confidenceLevel) {
        if (errorRate < 1E-6) {
            return errorRate + numberOfExamples * (1.0 - Math.exp(Math.log(confidenceLevel) / numberOfExamples));
        } else if (errorRate + 0.5 >= numberOfExamples) {
            return errorRate + 0.67 * (numberOfExamples - errorRate);
        } else {
            double coefficient = MathFunctions.normalInverse(1 - confidenceLevel);
            coefficient *= coefficient;
            double pessimisticRate = (errorRate + 0.5 + coefficient / 2.0d + Math.sqrt(coefficient
                    * ((errorRate + 0.5) * (1 - (errorRate + 0.5) / numberOfExamples) + coefficient / 4.0d)))
                    / (numberOfExamples + coefficient);
            return numberOfExamples * pessimisticRate;
        }
    }
}