/**
*
*/
package org.streaminer.stream.classifier.tree;
import org.streaminer.stream.classifier.Classifier;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.model.PredictionModel;
/**
* <p>
* This class implements a general model tree, which basically is a decision tree
* with prediction models at its leaf nodes.
* </p>
*
* @author Christian Bockermann &lt;christian.bockermann@udo.edu&gt;
*
*/
public class ModelTree<I extends NodeInfo,O>
extends TreeNode<I>
implements PredictionModel<Data,O>
{
/** The unique class ID */
private static final long serialVersionUID = 4855897886594236180L;
Classifier<Data,O> model;
SplitCriterion<I> splitCrierion;
/**
* @param name
* @param parent
*/
public ModelTree(String name, TreeNode<I> parent, SplitCriterion<I> splitCrit ) {
super(name, null, parent);
splitCrierion = splitCrit;
}
/**
* @return the model
*/
public Classifier<Data, O> getModel() {
return model;
}
/**
* @param model the model to set
*/
public void setModel(Classifier<Data, O> model) {
this.model = model;
}
/**
* @see stream.model.PredictionModel#predict(java.lang.Object)
*/
@Override
public O predict(Data item) {
ModelTree<I,O> leaf = getLeaf( item );
if( leaf != null )
return leaf.model.predict(item);
return null;
}
/**
* This method traverses the tree by testing the given data item at each
* inner node until reaching a leaf.
*
* @param item
* @return
*/
public ModelTree<I,O> getLeaf( Data item ){
if( this.isLeaf() )
return this;
if( this.value == null )
throw new RuntimeException( "Weird error! This node is not a leaf, but also does not contain a threshold value!" );
if( LearnerUtils.isNumerical( getName(), item ) ){
//
// This checks for the best matching successor based on numerical
// intervals obtained from each sibling
//
Double val = LearnerUtils.getDouble( getName(), item );
ModelTree<I,O> child = getChildFor( val );
if( child != null )
return child.getLeaf( item );
} else {
// check for matching child on nominal attribute value
//
ModelTree<I,O> child = getChildFor( item.get( getName() ).toString() );
if( child != null )
return child.getLeaf( item );
}
return null;
}
/**
* This method checks all siblings of the current node and returns the
* successor that matches the given value. This method handles the case
* of a real-valued condition.
*
* @param val
* @return
*/
@SuppressWarnings("unchecked")
public ModelTree<I,O> getChildFor( Double val ){
//
// the very first sibbling corresponds to ] -infinity; value ]
//
Double lower = Double.NEGATIVE_INFINITY;
for( int i = 0; i < children.size(); i++ ){
ModelTree<I,O> child = (ModelTree<I,O>) children.get(i);
Double upper = (Double) child.value;
if( lower < val && val <= upper ){
return child;
} else {
lower = upper;
}
}
return null;
}
/**
* This method returns the child for the given nominal value.
*
* @param value
* @return
*/
@SuppressWarnings("unchecked")
public ModelTree<I,O> getChildFor( String value ){
//
// check for matching child on nominal attribute value
//
for( int i = 0; i < children.size(); i++ ){
ModelTree<I,O> child = (ModelTree<I,O>) children.get(i);
if( child.value.equals( value ) )
return child;
}
return null;
}
}