/** * */ package org.streaminer.stream.classifier.tree; import org.streaminer.stream.data.Data; import org.streaminer.stream.learner.LearnerUtils; import org.streaminer.stream.learner.Regressor; import java.util.HashMap; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author chris * */ public class RTree extends ModelTree<RegressionTreeStatistics,Double> implements Regressor<Data> { /** The unique class ID */ private static final long serialVersionUID = 4926545397273482368L; /* The logger for this class */ static Logger log = LoggerFactory.getLogger( RTree.class ); /* The split criterion for this model tree */ SplitCriterion<RegressionTreeStatistics> splitCriterion; /* One binary tree for each numerical attribute */ Map<String,BTreeNode> btrees = new HashMap<String,BTreeNode>(); BestSplitValueFinder splitValueFinder = new BestSplitValueFinder(); /** * @param name * @param parent */ public RTree(String name, TreeNode<RegressionTreeStatistics> parent) { super(name, parent, new ChernoffSplitCriterion<RegressionTreeStatistics>() ); } /** * @see stream.learner.AbstractClassifier#learn(java.lang.Object) */ @Override public void learn(Data item) { // find the leaf node for the given data item and update the // model within that leaf // ModelTree<RegressionTreeStatistics,Double> leaf = getLeaf( item ); leaf.getModel().learn( item ); // // update the node statistics, which determine whether to split this // node or not // if( LearnerUtils.isNumerical( getName(), item ) ){ Double value = LearnerUtils.getDouble( getName(), item ); leaf.getNodeInfo().update( value ); BTreeNode btree = btrees.get( leaf.getName() ); if( btree == null ){ btree = new BTreeNode( leaf.getName(), value ); btrees.put( leaf.getName(), btree ); } else btree.insert( value ); } else throw new RuntimeException( "Nominal values are not supported!" ); //leaf.getNodeInfo().update( item.get( getName() ).toString() ); // // compute the chernoff bound based on the node-statistics // boolean requiresSplit = splitCriterion.requiresSplit( leaf.getNodeInfo() ); Double splitValue = getBestSplitValue( btrees.get( leaf.getName() ) ); if( requiresSplit ){ // // split this leaf and create siblings // ModelTree<RegressionTreeStatistics,Double> parent = (ModelTree<RegressionTreeStatistics,Double>) leaf.getParent(); ModelTree<RegressionTreeStatistics,Double> replacement = new ModelTree<RegressionTreeStatistics,Double>( leaf.getName(), parent, splitCriterion ); replacement.add( new ModelTree<RegressionTreeStatistics,Double>( leaf.getName(), null, splitCriterion ) ); } else { // // update this leaf's statistics and train the associated model // //btrees.get( leaf.getName() ).getNodeInfo().update( value ); } } public Double getBestSplitValue( BTreeNode btree ){ splitValueFinder.reset(); btree.inOrder( (Visitor<BinaryTreeNode<RegressionTreeStatistics,Double>>) splitValueFinder ); log.info( "Best split value is: {} (SDR: {})", splitValueFinder.getValue(), splitValueFinder.getMaximum() ); return splitValueFinder.getValue(); } /** * @see stream.learner.AbstractClassifier#predict(java.lang.Object) */ @Override public Double predict(Data item) { ModelTree<RegressionTreeStatistics,Double> leaf = getLeaf( item ); return leaf.predict( item ); } /** * @see stream.learner.Learner#init() */ @Override public void init() { } /** * This class is a visitor for tree nodes and maintains a maximum standard * deviation reduction during its visits. * * @author Christian Bockermann <christian.bockermann@udo.edu> * */ class BestSplitValueFinder implements Visitor<BinaryTreeNode<RegressionTreeStatistics,Double>> { Double maxValue = null; Double maxSdr = Double.NEGATIVE_INFINITY; public void reset(){ maxValue = null; maxSdr = Double.NEGATIVE_INFINITY; } public Double getValue(){ return maxValue; } public Double getMaximum(){ return maxSdr; } /** * @see stream.learner.tree.Visitor#visit(stream.learner.tree.BinaryTreeNode) */ @Override public void visit( BinaryTreeNode<RegressionTreeStatistics,Double> node) { if( maxValue == null ){ maxValue = node.getValue(); log.info( "Found initial split value: {} (sdr: {})", maxValue, maxSdr ); return; } Double sdr = getStandardDeviationReduction( node ); if( sdr > maxSdr ){ maxValue = node.getValue(); maxSdr = sdr; log.info( "Found new best split value: {} (sdr: {})", maxValue, maxSdr ); } } public Double getStandardDeviationReduction( BinaryTreeNode<RegressionTreeStatistics,Double> node ){ Double sdr = 0.0d; Double sdT = node.getNodeInfo().getStandardDeviation(); Double t = node.getNodeInfo().getNumberOfExamples(); Double t1 = 0.0d; Double sdT1 = 0.0d; if( node.getLeft() != null ){ t1 = node.getLeft().getNodeInfo().getNumberOfExamples(); sdT1 = node.getLeft().getNodeInfo().getStandardDeviation(); } Double t2 = 0.0d; Double sdT2 = 0.0d; if( node.getRight() != null ){ t2 = node.getRight().getNodeInfo().getNumberOfExamples(); sdT2 = node.getRight().getNodeInfo().getStandardDeviation(); } sdr = sdT - ( t1 / t ) * sdT1 - (t2 / t) * sdT2; return sdr; } } }