/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.tree.criterions.Criterion;
/**
 * Builds a decision tree from an example set.
 *
 * <p>Nodes are split recursively. For every node the benefit of each attribute is computed
 * with the configured {@link Criterion}, the benefits are sorted, and the first one(s) are
 * tried as split candidates. Growth stops at a node when any configured {@link Terminator}
 * fires, and the node is then turned into a leaf by the {@link LeafCreator}. With pre-pruning
 * enabled, candidates with a non-positive benefit or with a non-empty child violating the
 * minimal leaf size are rejected. An optional {@link Pruner} post-processes the grown tree.</p>
 *
 * @author Ingo Mierswa
 */
public class TreeBuilder {

    /** Terminator enforcing the minimal leaf size; also used to validate children of a candidate split. */
    protected Terminator minLeafSizeTerminator;

    /** All termination criteria checked in {@link #shouldStop(ExampleSet, int)}, including the min leaf size one. */
    private final List<Terminator> otherTerminators;

    /** Nodes with fewer examples than this are not split (only checked when pre-pruning is enabled). */
    private final int minSizeForSplit;

    /** Criterion used to compute the benefit of nominal and numerical splits. */
    private final Criterion criterion;

    /** Finds the best split value for numerical attributes. */
    private final NumericalSplitter splitter;

    /** Optional preprocessing applied to a node's examples before benefits are computed; may be null. */
    protected SplitPreprocessing preprocessing = null;

    /** Optional post-pruning step applied to the grown tree; may be null. */
    private final Pruner pruner;

    /** Turns a node into a leaf when growth stops. */
    protected LeafCreator leafCreator = new DecisionTreeLeafCreator();

    /** Number of additional split candidates tried when the best one is rejected by pre-pruning. */
    protected int numberOfPrepruningAlternatives = 0;

    /** Whether pre-pruning checks (min split size, positive benefit, min leaf size of children) are applied. */
    protected boolean usePrePruning = true;

    /**
     * Creates a tree builder.
     *
     * @param criterion the criterion used to compute split benefits
     * @param terminationCriteria additional termination criteria; the list is copied and never
     *        modified (a {@code null} list is treated as empty)
     * @param pruner optional post-pruner, may be {@code null}
     * @param preprocessing optional split preprocessing, may be {@code null}
     * @param leafCreator creates the leaves of the tree
     * @param noPrePruning if {@code true}, pre-pruning is disabled
     * @param numberOfPrepruningAlternatives how many further candidates to try if the best split
     *        is rejected by pre-pruning; negative values are treated as 0
     * @param minSizeForSplit minimal node size required for a split (pre-pruning only)
     * @param minLeafSize minimal size of a leaf
     */
    public TreeBuilder(Criterion criterion,
            List<Terminator> terminationCriteria,
            Pruner pruner,
            SplitPreprocessing preprocessing,
            LeafCreator leafCreator,
            boolean noPrePruning,
            int numberOfPrepruningAlternatives,
            int minSizeForSplit,
            int minLeafSize) {
        this.minLeafSizeTerminator = new MinSizeTermination(minLeafSize);
        // Work on a private copy so the caller's list is not modified and unmodifiable lists are accepted.
        this.otherTerminators = new ArrayList<Terminator>();
        if (terminationCriteria != null) {
            this.otherTerminators.addAll(terminationCriteria);
        }
        this.otherTerminators.add(this.minLeafSizeTerminator);
        this.usePrePruning = !noPrePruning;
        this.numberOfPrepruningAlternatives = Math.max(0, numberOfPrepruningAlternatives);
        this.minSizeForSplit = minSizeForSplit;
        this.leafCreator = leafCreator;
        this.criterion = criterion;
        this.splitter = new NumericalSplitter(this.criterion);
        this.pruner = pruner;
        this.preprocessing = preprocessing;
    }

    /**
     * Grows a tree on the given example set and prunes it if a pruner is configured.
     *
     * @param exampleSet the training data; the root node receives a clone of it
     * @return the root of the learned tree
     * @throws OperatorException if benefit computation, preprocessing or splitting fails
     */
    public Tree learnTree(ExampleSet exampleSet) throws OperatorException {
        Tree root = new Tree((ExampleSet) exampleSet.clone());
        if (shouldStop(exampleSet, 0)) {
            leafCreator.changeTreeToLeaf(root, exampleSet);
        } else {
            buildTree(root, exampleSet, 1);
        }
        if (pruner != null) {
            pruner.prune(root);
        }
        return root;
    }

    /** This method calculates the benefit of the given attribute. This implementation
     * utilizes the defined {@link Criterion}. Subclasses might want to override this
     * method in order to calculate the benefit in other ways.
     *
     * @return the benefit, or {@code null} if no split value could be found for a numerical attribute */
    public Benefit calculateBenefit(ExampleSet trainingSet, Attribute attribute) throws OperatorException {
        if (attribute.isNominal()) {
            return new Benefit(criterion.getNominalBenefit(trainingSet, attribute), attribute);
        }
        // numerical attribute: the splitter signals "no split" with NaN
        double splitValue = splitter.getBestSplit(trainingSet, attribute);
        if (Double.isNaN(splitValue)) {
            return null;
        }
        return new Benefit(criterion.getNumericalBenefit(trainingSet, attribute, splitValue), attribute, splitValue);
    }

    /**
     * Returns whether growth should stop at a node holding the given examples.
     * With pre-pruning, nodes smaller than the minimal split size stop immediately; otherwise
     * every termination criterion (including the minimal leaf size) is consulted.
     */
    protected boolean shouldStop(ExampleSet exampleSet, int depth) {
        if (usePrePruning && exampleSet.size() < minSizeForSplit) {
            return true;
        }
        for (Terminator terminator : otherTerminators) {
            if (terminator.shouldStop(exampleSet, depth)) {
                return true;
            }
        }
        return false;
    }

    /** Computes the benefit of every regular attribute, skipping attributes without a valid split. */
    protected Vector<Benefit> calculateAllBenefits(ExampleSet trainingSet) throws OperatorException {
        Vector<Benefit> benefits = new Vector<Benefit>();
        for (Attribute attribute : trainingSet.getAttributes()) {
            Benefit currentBenefit = calculateBenefit(trainingSet, attribute);
            if (currentBenefit != null) {
                benefits.add(currentBenefit);
            }
        }
        return benefits;
    }

    /**
     * Recursively grows the subtree rooted at {@code current}.
     *
     * <p>Tries the first candidate after sorting the benefits, and up to
     * {@link #numberOfPrepruningAlternatives} further ones if pre-pruning rejects a candidate.
     * If no candidate is accepted, {@code current} becomes a leaf.</p>
     */
    protected void buildTree(Tree current, ExampleSet exampleSet, int depth) throws OperatorException {
        // terminate (beginning of recursive method!)
        if (shouldStop(exampleSet, depth)) {
            leafCreator.changeTreeToLeaf(current, exampleSet);
            return;
        }

        ExampleSet workingSet = exampleSet;
        if (preprocessing != null) {
            workingSet = preprocessing.preprocess(workingSet);
        }
        ExampleSet trainingSet = (ExampleSet) workingSet.clone();

        Vector<Benefit> benefits = calculateAllBenefits(workingSet);
        Collections.sort(benefits);

        // "a <= n" rather than "a < n + 1" so that n == Integer.MAX_VALUE does not overflow
        // and silently skip the loop.
        boolean splitFound = false;
        for (int a = 0; a <= numberOfPrepruningAlternatives && !benefits.isEmpty(); a++) {
            Benefit bestBenefit = benefits.remove(0);
            // pre-pruning: a split without positive benefit is rejected (still counts as an attempt)
            if (usePrePruning && bestBenefit.getBenefit() <= 0) {
                continue;
            }
            Attribute bestAttribute = bestBenefit.getAttribute();
            double bestSplitValue = bestBenefit.getSplitValue();
            SplittedExampleSet splitted = splitBy(trainingSet, bestAttribute, bestSplitValue);
            if (usePrePruning && !allChildrenLargeEnough(splitted, depth)) {
                continue;
            }
            growChildren(current, splitted, bestAttribute, bestSplitValue, depth);
            splitFound = true;
            break;
        }

        if (!splitFound) {
            leafCreator.changeTreeToLeaf(current, trainingSet);
        }
    }

    /** Splits the examples by the values of a nominal attribute, or at {@code splitValue} for a numerical one. */
    private SplittedExampleSet splitBy(ExampleSet trainingSet, Attribute attribute, double splitValue) throws OperatorException {
        if (attribute.isNominal()) {
            return SplittedExampleSet.splitByAttribute(trainingSet, attribute);
        }
        return SplittedExampleSet.splitByAttribute(trainingSet, attribute, splitValue);
    }

    /** Returns false if any non-empty subset of the split is rejected by the minimal leaf size terminator. */
    private boolean allChildrenLargeEnough(SplittedExampleSet splitted, int depth) {
        for (int i = 0; i < splitted.getNumberOfSubsets(); i++) {
            splitted.selectSingleSubset(i);
            if (splitted.size() > 0 && minLeafSizeTerminator.shouldStop(splitted, depth)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds one child per non-empty subset of the split to {@code current} and grows it recursively.
     * A nominal split attribute is removed from the subsets first so it is not used again below.
     */
    private void growChildren(Tree current, SplittedExampleSet splitted, Attribute splitAttribute,
            double splitValue, int depth) throws OperatorException {
        boolean nominal = splitAttribute.isNominal();
        if (nominal) {
            splitted.getAttributes().remove(splitAttribute);
        }
        for (int i = 0; i < splitted.getNumberOfSubsets(); i++) {
            splitted.selectSingleSubset(i);
            if (splitted.size() > 0) {
                Tree child = new Tree((ExampleSet) splitted.clone());
                SplitCondition condition;
                if (nominal) {
                    // all examples in this subset share the attribute value; read it from the first one
                    condition = new NominalSplitCondition(splitAttribute, splitted.getExample(0).getValueAsString(splitAttribute));
                } else if (i == 0) {
                    // numerical splits: subset 0 holds values <= split value, the other subset values > it
                    condition = new LessEqualsSplitCondition(splitAttribute, splitValue);
                } else {
                    condition = new GreaterSplitCondition(splitAttribute, splitValue);
                }
                current.addChild(child, condition);
                buildTree(child, splitted, depth + 1);
            }
        }
    }
}