package org.streaminer.stream.classifier.tree; import org.streaminer.stream.data.Data; import org.streaminer.stream.learner.Regressor; import org.streaminer.stream.model.PredictionModel; import java.util.LinkedHashMap; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * <p> * This class implements a regression tree model. A regression tree model is a * decision tree with regression models at its leaf nodes. * </p> * * @author Christian Bockermann <christian.bockermann@udo.edu> * */ public class RegressionTreeModel implements PredictionModel<Data, Double> { /** The unique class ID */ private static final long serialVersionUID = 5489550422815750513L; static final transient Logger log = LoggerFactory .getLogger(RegressionTreeModel.class); /** * root of regression tree */ private RegressionTreeNode root; /** * constructs new regression model * * @throws ClassNotFoundException * @throws IllegalAccessException * @throws InstantiationException */ @SuppressWarnings("unchecked") public RegressionTreeModel(Regressor<Data> regression) throws Exception { root = new LeafNode(null, false, regression, 0); } /** * Returns the leaf at which the path through the tree ends for the * specified {@link Example}. * * @param item * @return leaf if found, null instead */ public LeafNode getLeaf(Data item) { RegressionTreeNode currentNode = this.root; while (currentNode instanceof InnerNode) { InnerNode innerNode = (InnerNode) currentNode; currentNode = innerNode.traverseNode(item.get(innerNode .getFeature())); } if (currentNode instanceof LeafNode) { return (LeafNode) currentNode; } return null; } @Override public Double predict(Data item) { LeafNode leaf = this.getLeaf(item); if (leaf != null) { Double prediction = leaf.getRegressionModel().predict(item); return prediction; } return Double.NaN; } /** * @return root element of regression tree */ public RegressionTreeNode getRoot() { return this.root; } /** * sets root element of regression tree * * @param root * RegressionTreeNode to be new root of regression tree */ public void setRoot(RegressionTreeNode root) { this.root = root; } /** * @return String representation of this RegressionTreeModel */ @Override public String toString() { return this.root.toString(0); } }