/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.tree;

import java.util.Iterator;

import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;

/**
 * The tree model is the model created by all decision trees.
 * <p>
 * A prediction walks the tree from the root: at every inner node the first
 * edge whose split condition accepts the example is followed. The class
 * confidences written to the example are the relative class counts stored in
 * the node at which the walk stops.
 *
 * @author Sebastian Land
 */
public class TreeModel extends SimplePredictionModel {

	private static final long serialVersionUID = 4368631725370998591L;

	/** Root node of the tree; every prediction starts here. */
	private Tree root;

	/**
	 * Creates a model around the given tree.
	 *
	 * @param exampleSet the example set handed to the superclass constructor
	 * @param root       the root node of the tree
	 */
	public TreeModel(ExampleSet exampleSet, Tree root) {
		super(exampleSet);
		this.root = root;
	}

	/**
	 * Returns the root node of the tree.
	 *
	 * @return the root node given at construction time
	 */
	public Tree getRoot() {
		return this.root;
	}

	/**
	 * Predicts the label of the example by descending the tree from the root.
	 * As a side effect the class confidences of the example are set.
	 *
	 * @param example the example to classify
	 * @return the index of the predicted class in the label mapping
	 */
	@Override
	public double predict(Example example) throws OperatorException {
		return predict(example, root);
	}

	/**
	 * Recursive prediction step for the subtree rooted at {@code node}.
	 * <ul>
	 * <li>Leaf: confidences are set from the leaf's class counts and the index
	 * of the leaf's label is returned.</li>
	 * <li>Inner node: the first child whose split condition accepts the
	 * example is followed. If no condition matches, the class with the highest
	 * count in this node is predicted, or index 0 if the node has no class
	 * counts at all.</li>
	 * </ul>
	 *
	 * @param example the example to classify
	 * @param node    the node at which to continue
	 * @return the index of the predicted class in the label mapping
	 */
	private double predict(Example example, Tree node) {
		if (node.isLeaf()) {
			assignConfidences(example, node);
			return getLabel().getMapping().getIndex(node.getLabel());
		}

		Iterator<Edge> childIterator = node.childIterator();
		while (childIterator.hasNext()) {
			Edge edge = childIterator.next();
			if (edge.getCondition().test(example)) {
				return predict(example, edge.getChild());
			}
		}

		// nothing known from training --> use majority class in this node
		String majorityClass = assignConfidences(example, node);
		if (majorityClass != null) {
			return getLabel().getMapping().getIndex(majorityClass);
		}
		return 0;
	}

	/**
	 * Writes one confidence per class of the label mapping to the example,
	 * computed as the class count in {@code node} divided by the sum of all
	 * class counts in {@code node}.
	 * <p>
	 * If the counts sum to zero (for instance because the node has no class
	 * counts), every confidence is set to 0 rather than dividing 0 by 0, which
	 * would store NaN.
	 *
	 * @param example the example receiving the confidences
	 * @param node    the node whose class counts are used
	 * @return the name of the class with the highest count (the first one
	 *         encountered wins ties), or {@code null} if the node has no
	 *         class counts
	 */
	private String assignConfidences(Example example, Tree node) {
		int[] counts = new int[getLabel().getMapping().size()];
		int sum = 0;
		String majorityClass = null;
		int majorityCounter = -1;

		Iterator<String> classNames = node.getCounterMap().keySet().iterator();
		while (classNames.hasNext()) {
			String className = classNames.next();
			int count = node.getCount(className);
			counts[getLabel().getMapping().getIndex(className)] = count;
			sum += count;
			if (count > majorityCounter) {
				majorityCounter = count;
				majorityClass = className;
			}
		}

		for (int i = 0; i < counts.length; i++) {
			// guard against 0 / 0, which would yield NaN confidences
			double confidence = (sum == 0) ? 0.0d : ((double) counts[i]) / sum;
			example.setConfidence(getLabel().getMapping().mapIndex(i), confidence);
		}
		return majorityClass;
	}

	/**
	 * Returns the string representation of the root node.
	 */
	@Override
	public String toString() {
		return this.root.toString();
	}
}