/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.tree.criterions.ColumnCriterion;
import com.rapidminer.studio.internal.Resources;
/**
* Build a tree from an example set, possibly in parallel. During the tree building process the
* examples and attributes at a node are represented by numbers coming from numbering all examples
* and attributes at the beginning. In this numbering all nominal attributes come before all
* numerical attributes. During the tree growing process the nodes are split into smaller ones.
*
* This class should be extended to specify if and how the calculations should be parallelized. By
* implementing the method {@link #startTree} using {@link #splitNode} one can decide if and in
* which direction the process of splitting the nodes should be parallelized. By implementing the
* abstract method {@link #doStartSelectionInParallel()} one can decide if and when the start
* selection of the examples should be done in parallel. (Note that this only has an effect if there
* are numerical attributes.)
*
* @author Ingo Mierswa, Gisa Schaefer
*/
public abstract class AbstractParallelTreeBuilder {
final protected Operator operator;
final protected ColumnTerminator minLeafSizeTerminator;
final protected List<ColumnTerminator> otherTerminators;
final protected int minSizeForSplit;
final protected ColumnCriterion criterion;
protected BenefitCalculator benefitCalculator;
protected SelectionCreator selectionCreator;
final protected AttributePreprocessing preprocessing;
final protected Pruner pruner;
final protected ParallelDecisionTreeLeafCreator leafCreator = new ParallelDecisionTreeLeafCreator();
protected int numberOfPrepruningAlternatives = 0;
final protected boolean usePrePruning;
protected ColumnExampleTable columnTable;
final protected boolean parallelAllowed;
/**
* Initializes the fields.
*
*/
public AbstractParallelTreeBuilder(Operator operator, ColumnCriterion criterion,
List<ColumnTerminator> terminationCriteria, Pruner pruner, AttributePreprocessing preprocessing,
boolean prePruning, int numberOfPrepruningAlternatives, int minSizeForSplit, int minLeafSize,
boolean parallelAllowed) {
this.operator = operator;
this.minLeafSizeTerminator = new ColumnMinSizeTermination(minLeafSize);
if (terminationCriteria == null) {
throw new IllegalArgumentException("terminationCriteria must not be null!");
}
this.otherTerminators = terminationCriteria;
if (prePruning) {
this.otherTerminators.add(this.minLeafSizeTerminator);
}
this.usePrePruning = prePruning;
if (prePruning) {
this.numberOfPrepruningAlternatives = Math.max(0, numberOfPrepruningAlternatives);
}
this.minSizeForSplit = minSizeForSplit;
if (criterion == null) {
throw new IllegalArgumentException("criterion must not be null!");
}
this.criterion = criterion;
this.pruner = pruner;
this.preprocessing = preprocessing;
this.parallelAllowed = parallelAllowed;
}
/**
* Creates a copy of the example set in form of the {@link ColumnExampleTable}, starts the tree
* growing procedure and prunes the finished tree.
*
* @param exampleSet
* @return
* @throws OperatorException
*/
public Tree learnTree(ExampleSet exampleSet) throws OperatorException {
// preprocess example set before creating the table
exampleSet = preprocessExampleSet(exampleSet);
columnTable = new ColumnExampleTable(exampleSet, operator, parallelAllowed);
benefitCalculator = new BenefitCalculator(columnTable, criterion, operator);
selectionCreator = new SelectionCreator(columnTable);
Map<Integer, int[]> allSelectedExamples = createExampleStartSelection();
int[] selectedExamples = SelectionCreator.getArbitraryValue(allSelectedExamples);
int[] selectedAttributes = selectionCreator.createFullArray(columnTable.getTotalNumberOfRegularAttributes());
// grow tree
Tree root = new Tree(null);
if (shouldStop(selectedExamples, selectedAttributes, 0)) {
leafCreator.changeTreeToLeaf(root, columnTable, selectedExamples);
} else {
startTree(root, allSelectedExamples, selectedAttributes, 1);
}
// prune
if (pruner != null) {
pruner.prune(root);
}
return root;
}
/**
* Hook for preprocessing the example set before building the {@link ColumnExampleTable}.
*
* @param exampleSet
* @return
*/
protected ExampleSet preprocessExampleSet(ExampleSet exampleSet) {
return exampleSet;
}
/**
* Creates for every numerical attribute a sorted start selection, possibly in parallel.
*
* @return
* @throws OperatorException
*/
protected Map<Integer, int[]> createExampleStartSelection() throws OperatorException {
Map<Integer, int[]> allSelectedExamples;
if (doStartSelectionInParallel() && operator != null) {
allSelectedExamples = selectionCreator.getStartSelectionParallel(operator);
} else {
allSelectedExamples = selectionCreator.getStartSelection();
}
return allSelectedExamples;
}
/**
* Decides whether the start selection should be created in parallel.
*
* @return
*/
abstract boolean doStartSelectionInParallel();
/**
* Starts the tree building process for the given parameters.
*
* @param root
* @param allSelectedExamples
* @param selectedAttributes
* @param depth
* @throws OperatorException
*/
abstract void startTree(Tree root, Map<Integer, int[]> allSelectedExamples, int[] selectedAttributes, int depth)
throws OperatorException;
/**
* Splits the node given by the nodeData by calculating the attribute with the best benefit.
*
* @param nodeData
* @param attributeParallel
* if <code>true</code> the calculation of the benefits is done in parallel by
* attributes
* @return
* @throws OperatorException
*/
protected Collection<NodeData> splitNode(NodeData nodeData, boolean attributeParallel) throws OperatorException {
// check if operator was stopped
if (operator != null) {
Resources.getConcurrencyContext(operator).checkStatus();
}
Map<Integer, int[]> allSelectedExamples = nodeData.getAllSelectedExamples();
int[] selectedAttributes = nodeData.getSelectedAttributes();
Tree current = nodeData.getTree();
int depth = nodeData.getDepth();
// terminate
int[] selectedExamples = SelectionCreator.getArbitraryValue(allSelectedExamples);
if (shouldStop(selectedExamples, selectedAttributes, depth)) {
leafCreator.changeTreeToLeaf(current, columnTable, selectedExamples);
return Collections.emptyList();
}
// preprocessing
if (preprocessing != null && depth > 1) {
selectedAttributes = preprocessing.preprocess(selectedAttributes);
}
// calculate all benefits
List<ParallelBenefit> benefits = getBenefits(allSelectedExamples, selectedAttributes, attributeParallel);
// sort all benefits
Collections.sort(benefits);
// try at most k benefits and check if prepruning is fulfilled
boolean splitFound = false;
for (int a = 0; a < numberOfPrepruningAlternatives + 1; a++) {
// break if no benefits are left
if (benefits.size() <= 0) {
break;
}
// search current best
ParallelBenefit bestBenefit = benefits.remove(0);
// check if minimum gain was reached when using prepruning and if the benefit results in
// a split with more than one child
if (usePrePruning && bestBenefit.getBenefit() <= 0 || !usePrePruning
&& !(bestBenefit.getBenefit() > Double.NEGATIVE_INFINITY)) {
break;
}
// split by best attribute
int bestAttribute = bestBenefit.getAttributeNumber();
double bestSplitValue = bestBenefit.getSplitValue();
Collection<Map<Integer, int[]>> splits = selectionCreator.getSplits(allSelectedExamples, bestAttribute,
bestSplitValue);
// if all have minimum size --> remove nominal attribute and recursive call for each
// subset
if (isSplitOK(selectedAttributes, depth, splits)) {
int[] remainingAttributes = selectionCreator.updateRemainingAttributes(selectedAttributes, bestAttribute);
LinkedList<NodeData> children = new LinkedList<>();
int i = 0;
for (Map<Integer, int[]> split : splits) {
if (SelectionCreator.getArbitraryValue(split).length > 0) {
Tree child = new Tree(null);
addToParentTree(current, child, bestAttribute, bestSplitValue,
SelectionCreator.getArbitraryValue(split), i);
NodeData newNode = new NodeData(child, split, remainingAttributes, depth + 1);
children.add(newNode);
i++;
}
}
// end loop
return children;
}
// no valid split found - try again
}
// no split found --> change to leaf and return
if (!splitFound) {
leafCreator.changeTreeToLeaf(current, columnTable, selectedExamples);
}
return Collections.emptyList();
}
/**
* Checks if the tree building should stop. The terminators are checked and, when prepruning is
* activated, the minimal size for a split is checked as well.
*
* @param selectedExamples
* @param selectedAttributes
* @param depth
* @return
*/
protected boolean shouldStop(int[] selectedExamples, int[] selectedAttributes, int depth) {
if (usePrePruning && selectedExamples.length < minSizeForSplit) {
return true;
} else {
for (ColumnTerminator terminator : otherTerminators) {
if (terminator.shouldStop(selectedExamples, selectedAttributes, columnTable, depth)) {
return true;
}
}
return false;
}
}
/**
* For each attribute calculate the benefit for splitting there, possibly in parallel if
* attributeParalle is <code>true</code>.
*
* @param allSelectedExamples
* @param selectedAttributes
* @param attributeParallel
* @return
* @throws OperatorException
*/
protected List<ParallelBenefit> getBenefits(Map<Integer, int[]> allSelectedExamples, int[] selectedAttributes,
boolean attributeParallel) throws OperatorException {
List<ParallelBenefit> benefits;
if (attributeParallel && operator != null) {
benefits = benefitCalculator.calculateAllBenefitsParallel(allSelectedExamples, selectedAttributes);
} else {
benefits = benefitCalculator.calculateAllBenefits(allSelectedExamples, selectedAttributes);
}
return benefits;
}
/**
* Checks in the case of prepruning whether the minimal leaf size is satisfied.
*
* @param selectedAttributes
* @param depth
* @param splits
* @return
*/
private boolean isSplitOK(int[] selectedAttributes, int depth, Collection<Map<Integer, int[]>> splits) {
// check if children all have the minimum size
boolean splitOK = true;
if (usePrePruning) {
for (Map<Integer, int[]> splitinfo : splits) {
int[] split = SelectionCreator.getArbitraryValue(splitinfo);
if (split.length > 0 && minLeafSizeTerminator.shouldStop(split, selectedAttributes, columnTable, depth)) {
splitOK = false;
break;
}
}
}
return splitOK;
}
/**
* Adds the child tree to the parent tree via an edge describing the split.
*
* @param parent
* @param bestAttribute
* @param bestSplitValue
* @param counter
* @param split
* @param child
*/
private void addToParentTree(Tree parent, Tree child, int bestAttribute, double bestSplitValue, int[] split, int counter) {
SplitCondition condition = null;
if (columnTable.representsNominalAttribute(bestAttribute)) {
// find the attribute value we are splitting
Attribute best = columnTable.getNominalAttribute(bestAttribute);
final byte index = columnTable.getNominalAttributeColumn(bestAttribute)[split[0]];
String splitValueName;
// NaNs are represented by the number mapping size
if (index == best.getMapping().size()) {
splitValueName = null;
} else {
splitValueName = best.getMapping().mapIndex(index);
}
condition = new NominalSplitCondition(best, splitValueName);
} else {
if (counter == 0) {
condition = new LessEqualsSplitCondition(columnTable.getNumericalAttribute(bestAttribute), bestSplitValue);
} else if (counter == 1) {
condition = new GreaterSplitCondition(columnTable.getNumericalAttribute(bestAttribute), bestSplitValue);
} else {
condition = new NumericalMissingSplitCondition(columnTable.getNumericalAttribute(bestAttribute));
}
}
parent.addChild(child, condition);
}
/**
* Class to bundle the parameters of {@link AbstractParallelTreeBuilder#splitNode(NodeData)}.
*/
protected class NodeData {
Tree tree;
Map<Integer, int[]> allSelectedExamples;
int[] selectedAttributes;
int depth;
NodeData(Tree tree, Map<Integer, int[]> allSelectedExamples, int[] selectedAttributes, int depth) {
this.tree = tree;
this.allSelectedExamples = allSelectedExamples;
this.selectedAttributes = selectedAttributes;
this.depth = depth;
}
Tree getTree() {
return tree;
}
Map<Integer, int[]> getAllSelectedExamples() {
return allSelectedExamples;
}
int[] getSelectedAttributes() {
return selectedAttributes;
}
int getDepth() {
return depth;
}
}
}