/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import java.util.Iterator;
import java.util.Map.Entry;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.ExampleSetUtilities;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
/**
* The tree model is the model created by all decision trees.
*
* @author Sebastian Land
*/
public class TreeModel extends SimplePredictionModel {
private static final long serialVersionUID = 4368631725370998591L;
private Tree root;
public TreeModel(ExampleSet exampleSet, Tree root) {
super(exampleSet, ExampleSetUtilities.SetsCompareOption.ALLOW_SUPERSET,
ExampleSetUtilities.TypesCompareOption.ALLOW_SAME_PARENTS);
this.root = root;
}
public Tree getRoot() {
return this.root;
}
@Override
public double predict(Example example) throws OperatorException {
return predict(example, root);
}
private double predict(Example example, Tree node) {
if (node.isLeaf()) {
int[] counts = new int[getLabel().getMapping().size()];
int sum = 0;
for (Entry<String, Integer> entry : node.getCounterMap().entrySet()) {
int count = entry.getValue();
int index = getLabel().getMapping().getIndex(entry.getKey());
counts[index] = count;
sum += count;
}
for (int i = 0; i < counts.length; i++) {
example.setConfidence(getLabel().getMapping().mapIndex(i), (double) counts[i] / sum);
}
return getLabel().getMapping().getIndex(node.getLabel());
} else {
Iterator<Edge> childIterator = node.childIterator();
while (childIterator.hasNext()) {
Edge edge = childIterator.next();
SplitCondition condition = edge.getCondition();
if (condition.test(example)) {
return predict(example, edge.getChild());
}
}
// nothing known from training --> use majority class in this node
String majorityClass = null;
int majorityCounter = -1;
int[] counts = new int[getLabel().getMapping().size()];
int sum = 0;
for (Entry<String, Integer> entry : node.getSubtreeCounterMap().entrySet()) {
String className = entry.getKey();
int count = entry.getValue().intValue();
int index = getLabel().getMapping().getIndex(className);
counts[index] = count;
sum += count;
if (count > majorityCounter) {
majorityCounter = count;
majorityClass = className;
}
}
for (int i = 0; i < counts.length; i++) {
example.setConfidence(getLabel().getMapping().mapIndex(i), (double) counts[i] / sum);
}
if (majorityClass != null) {
return getLabel().getMapping().getIndex(majorityClass);
} else {
return 0;
}
}
}
@Override
public String toString() {
return this.root.toString();
}
}