/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* GraftSplit.java
* Copyright (C) 2007 Geoff Webb & Janice Boughton
* a split object for nodes added to a tree during grafting.
* (used in classifier J48g).
*/
package weka.classifiers.trees.j48;
import weka.core.*;
/**
* Class implementing a split for nodes added to a tree during grafting.
*
* @author Janice Boughton (jrbought@infotech.monash.edu.au)
* @version $Revision 1.0 $
*/
public class GraftSplit
  extends ClassifierSplitModel
  implements Comparable {

  /** for serialization. */
  private static final long serialVersionUID = 722773260393182051L;

  /** test type code: attribute value &lt;= split point */
  private static final int TEST_LESS_OR_EQUAL = 0;

  /** test type code: attribute value &gt; split point */
  private static final int TEST_GREATER = 1;

  /** test type code: attribute value = split value */
  private static final int TEST_EQUAL = 2;

  /** test type code: attribute value != split value */
  private static final int TEST_NOT_EQUAL = 3;

  /** the distribution for graft values, from cases in atbop */
  private Distribution m_graftdistro;

  /** the attribute we are splitting on */
  private int m_attIndex;

  /** value of split point (if numeric attribute) */
  private double m_splitPoint;

  /** dominant class of the subset specified by m_testType */
  private int m_maxClass;

  /** dominant class of the subset not specified by m_testType */
  private int m_otherLeafMaxClass;

  /** laplace value of the subset specified by m_testType for m_maxClass */
  private double m_laplace;

  /** leaf for the subset specified by m_testType */
  private Distribution m_leafdistro;

  /**
   * type of test:
   * 0: &lt;= test
   * 1: &gt; test
   * 2: = test
   * 3: != test
   */
  private int m_testType;

  /**
   * Creates a lightweight split that only records the test and its laplace
   * value. No distribution, leaf distribution or subset count is set up by
   * this constructor.
   *
   * @param a the attribute to split on
   * @param v the value of a where split occurs
   * @param t the test type (0 is &lt;=, 1 is &gt;, 2 is =, 3 is !=)
   * @param c the class to label the leaf node pointed to by test as.
   * @param l the laplace value (needed when sorting GraftSplits)
   */
  public GraftSplit(int a, double v, int t, double c, double l) {
    m_attIndex = a;
    m_splitPoint = v;
    m_testType = t;
    m_maxClass = (int) c;
    m_laplace = l;
  }

  /**
   * Creates a binary split from a table of class counts and derives the
   * leaf distribution, dominant class and laplace value for the subset
   * selected by the test type.
   *
   * @param a the attribute to split on
   * @param v the value of a where split occurs
   * @param t the test type (0 is &lt;=, 1 is &gt;, 2 is =, 3 is !=)
   * @param oC the class to label the leaf node not pointed to by test as.
   * @param counts the distribution for this split, indexed [subset][class]
   * @throws Exception if something goes wrong
   */
  public GraftSplit(int a, double v, int t, double oC, double [][] counts)
    throws Exception {

    m_attIndex = a;
    m_splitPoint = v;
    m_testType = t;
    m_otherLeafMaxClass = (int) oC;

    // only deal with binary cuts (<= and >; = and !=)
    m_numSubsets = 2;

    // the subset the graft (and therefore m_leafdistro) refers to
    int subset = subsetOfInterest();

    // the distribution of the split itself is built directly on counts
    m_distribution = new Distribution(counts);

    // the leaf gets its own single-bag copy of the counts of that subset
    int numClasses = m_distribution.numClasses();
    double [][] leafCounts = new double[1][numClasses];
    System.arraycopy(counts[subset], 0, leafCounts[0], 0, numClasses);
    m_leafdistro = new Distribution(leafCounts);

    m_maxClass = m_distribution.maxClass(subset);

    // laplace value (assumes binary class) for the subset of interest
    m_laplace = (m_distribution.perClassPerBag(subset, m_maxClass) + 1.0)
      / (m_distribution.perBag(subset) + 2.0);
  }

  /**
   * deletes the cases in data that belong to leaf pointed to by
   * the test (i.e. the subset of interest). this is useful so
   * the instances belonging to that leaf aren't passed down the
   * other branch.
   *
   * @param data the instances to delete from
   */
  public void deleteGraftedCases(Instances data) {
    int graftedSubset = subsetOfInterest();

    // scan from the end: a deletion then never moves an instance that
    // still has to be examined, so no index fix-up is needed
    for (int i = data.numInstances() - 1; i >= 0; i--) {
      if (whichSubset(data.instance(i)) == graftedSubset) {
        data.delete(i);
      }
    }
  }

  /**
   * builds m_graftdistro using the passed data. Instances whose split
   * attribute is missing do not contribute to the distribution.
   *
   * @param data the instances to use when creating the distribution
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {
    int numClasses = data.numClasses();

    // distribution for the graft, not counting cases in atbop, only orig leaf
    m_graftdistro = new Distribution(2, numClasses);

    int subset = subsetOfInterest();          // the subset for m_leaf
    int otherSubset = (subset == 0) ? 1 : 0;  // the remaining branch

    // NOTE(review): whichSubset() yields -1 exactly for instances with a
    // missing split attribute, so those are skipped here. A fractional
    // redistribution of such instances may have been intended - confirm.
    for (int i = 0; i < data.numInstances(); i++) {
      Instance instance = data.instance(i);
      int bag = whichSubset(instance);
      if (bag == -1) {
        continue;
      }
      m_graftdistro.add(bag, instance);
    }

    // if there are no cases at a leaf, make sure the desired
    // class is chosen, by setting counts to 0.01
    if (m_graftdistro.perBag(subset) == 0) {
      double [] counts = new double[numClasses];
      counts[m_maxClass] = 0.01;
      m_graftdistro.add(subset, counts);
    }
    if (m_graftdistro.perBag(otherSubset) == 0) {
      double [] counts = new double[numClasses];
      counts[m_otherLeafMaxClass] = 0.01;
      m_graftdistro.add(otherSubset, counts);
    }
  }

  /**
   * @return the NoSplit object for the leaf pointed to by m_testType branch
   */
  public NoSplit getLeaf() {
    return new NoSplit(m_leafdistro);
  }

  /**
   * @return the NoSplit object for the leaf not pointed to by m_testType branch
   */
  public NoSplit getOtherLeaf() {
    // the bag (subset) that isn't pointed to by m_testType branch
    int bag = (subsetOfInterest() == 0) ? 1 : 0;

    int numClasses = m_graftdistro.numClasses();
    double [][] counts = new double[1][numClasses];
    double total = 0;
    for (int c = 0; c < numClasses; c++) {
      double count = m_graftdistro.perClassPerBag(bag, c);
      counts[0][c] = count;
      total += count;
    }

    // if empty, make sure proper class gets chosen
    if (total == 0) {
      counts[0][m_otherLeafMaxClass] += 0.01;
    }
    return new NoSplit(new Distribution(counts));
  }

  /**
   * Prints label for subset index of instances (eg class). The label
   * consists of the class value, followed in parentheses by the weight
   * (and, if greater than zero, the number of incorrect cases) taken from
   * m_graftdistro. For the subset of interest the same two figures taken
   * from the split distribution are appended after a "|".
   *
   * @param index the bag to dump label for
   * @param data to get attribute names and such
   * @return the label as a string
   * @exception Exception if something goes wrong
   */
  public final String dumpLabelG(int index, Instances data) throws Exception {
    boolean graftedLeaf = (index == subsetOfInterest());
    StringBuffer text = new StringBuffer();

    text.append(data.classAttribute()
                .value(graftedLeaf ? m_maxClass : m_otherLeafMaxClass));

    text.append(" (").append(Utils.roundDouble(m_graftdistro.perBag(index), 1));
    if (Utils.gr(m_graftdistro.numIncorrect(index), 0)) {
      text.append("/")
        .append(Utils.roundDouble(m_graftdistro.numIncorrect(index), 2));
    }

    // show the graft values, only if this is subsetOfInterest()
    if (graftedLeaf) {
      text.append("|").append(Utils.roundDouble(m_distribution.perBag(index), 2));
      if (Utils.gr(m_distribution.numIncorrect(index), 0)) {
        text.append("/")
          .append(Utils.roundDouble(m_distribution.numIncorrect(index), 2));
      }
    }

    text.append(")");
    return text.toString();
  }

  /**
   * Maps the test type onto a subset index: the = test maps to 0, the
   * != test maps to 1, any other test type is returned unchanged (so
   * &lt;= is 0 and &gt; is 1).
   *
   * @return the subset that is specified by the test type
   */
  public int subsetOfInterest() {
    switch (m_testType) {
      case TEST_EQUAL:
        return 0;
      case TEST_NOT_EQUAL:
        return 1;
      default:
        return m_testType;
    }
  }

  /**
   * @return the number of positive cases in the subset of interest
   */
  public double positivesForSubsetOfInterest() {
    return m_distribution.perClassPerBag(subsetOfInterest(), m_maxClass);
  }

  /**
   * @param subset the subset to get the positives for
   * @return the number of positive cases in the specified subset
   */
  public double positives(int subset) {
    int dominantClass = m_distribution.maxClass(subset);
    return m_distribution.perClassPerBag(subset, dominantClass);
  }

  /**
   * @return the number of instances in the subset of interest
   */
  public double totalForSubsetOfInterest() {
    return m_distribution.perBag(subsetOfInterest());
  }

  /**
   * @param subset the index of the bag to get the total for
   * @return the number of instances in the subset
   */
  public double totalForSubset(int subset) {
    return m_distribution.perBag(subset);
  }

  /**
   * Prints left side of condition satisfied by instances.
   *
   * @param data the data.
   * @return the name of the attribute that is split on
   */
  public String leftSide(Instances data) {
    return data.attribute(m_attIndex).name();
  }

  /**
   * @return the index of the attribute to split on
   */
  public int attribute() {
    return m_attIndex;
  }

  /**
   * Prints condition satisfied by instances in subset index.
   *
   * @param index the subset; 0 selects the = / &lt;= side, anything else
   *              the != / &gt; side
   * @param data the data containing the attribute information
   * @return the operator followed by the split value
   */
  public final String rightSide(int index, Instances data) {
    boolean firstSubset = (index == 0);

    if (data.attribute(m_attIndex).isNominal()) {
      String label = data.attribute(m_attIndex).value((int) m_splitPoint);
      return (firstSubset ? " = " : " != ") + label;
    }

    String cut = Utils.doubleToString(m_splitPoint, 6);
    return (firstSubset ? " <= " : " > ") + cut;
  }

  /**
   * Returns a string containing java source code equivalent to the test
   * made at this node. The instance being tested is called "i".
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {
    // a negative index stands for the "value is missing" test
    if (index < 0) {
      return "i[" + m_attIndex + "] == null";
    }

    if (data.attribute(m_attIndex).isNominal()) {
      String negation = (index == 0) ? "" : "!";
      String label = data.attribute(m_attIndex).value((int) m_splitPoint);
      return negation + "i[" + m_attIndex + "]"
        + ".equals(\"" + label + "\")";
    }

    String comparison = (index == 0) ? ".doubleValue() <= "
                                     : ".doubleValue() > ";
    return "((Double) i[" + m_attIndex + "])" + comparison + m_splitPoint;
  }

  /**
   * Returns, for an instance whose split attribute is missing, the share
   * of the total weight of m_graftdistro held by each subset.
   *
   * @param instance the instance to produce the weights for
   * @return a double array of weights, null if only belongs to one subset
   */
  public double [] weights(Instance instance) {
    if (!instance.isMissing(m_attIndex)) {
      return null;
    }

    double [] weights = new double[m_numSubsets];
    for (int i = 0; i < m_numSubsets; i++) {
      weights[i] = m_graftdistro.perBag(i) / m_graftdistro.total();
    }
    return weights;
  }

  /**
   * @param instance the instance for which to determine the subset
   * @return an int indicating the subset this instance belongs to:
   *         -1 if the split attribute is missing, 0 if the instance
   *         satisfies the = / &lt;= test, 1 otherwise
   */
  public int whichSubset(Instance instance) {
    if (instance.isMissing(m_attIndex)) {
      return -1;
    }

    boolean firstSubset;
    if (instance.attribute(m_attIndex).isNominal()) {
      // in the case of nominal, m_splitPoint is the = value, all else is !=
      firstSubset = (instance.value(m_attIndex) == m_splitPoint);
    } else {
      firstSubset = Utils.smOrEq(instance.value(m_attIndex), m_splitPoint);
    }
    return firstSubset ? 0 : 1;
  }

  /**
   * @return the value of the split point
   */
  public double splitPoint() {
    return m_splitPoint;
  }

  /**
   * @return the dominant class for the subset of interest
   */
  public int maxClassForSubsetOfInterest() {
    return m_maxClass;
  }

  /**
   * @return the laplace value for maxClass of subset of interest
   */
  public double laplaceForSubsetOfInterest() {
    return m_laplace;
  }

  /**
   * returns the test type
   * @return value of testtype
   */
  public int testType() {
    return m_testType;
  }

  /**
   * method needed for sorting a collection of GraftSplits by laplace value.
   * The comparison follows Double.compare, so it remains a consistent
   * total order even if a laplace value is NaN (NaN sorts last).
   *
   * @param g the graft split to compare to this one
   * @return -1, 0, or 1 if this GraftSplit laplace is &lt;, = or &gt; than that of g
   * @throws ClassCastException if g is not a GraftSplit
   */
  public int compareTo(Object g) {
    double otherLaplace = ((GraftSplit) g).laplaceForSubsetOfInterest();
    int order = Double.compare(m_laplace, otherLaplace);

    // Double.compare only promises the sign; keep the documented -1/0/1
    if (order > 0) {
      return 1;
    }
    if (order < 0) {
      return -1;
    }
    return 0;
  }

  /**
   * returns the probability for instance for the specified class
   *
   * @param classIndex the index of the class
   * @param instance the instance to get the probability for
   * @param theSubset the subset; a negative value means the subset is
   *                  unknown
   * @return the probability taken from the split distribution
   * @throws Exception if something goes wrong
   */
  public final double classProb(int classIndex, Instance instance,
                                int theSubset) throws Exception {

    // known subset: use its own probability unless the bag is empty
    if (theSubset > -1) {
      if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
        return m_distribution.prob(classIndex, theSubset);
      }
      return m_distribution.prob(classIndex);
    }

    // unknown subset: blend the per-bag probabilities by the bag weights
    double [] weights = weights(instance);
    if (weights == null) {
      return m_distribution.prob(classIndex);
    }

    double prob = 0;
    for (int i = 0; i < weights.length; i++) {
      prob += weights[i] * m_distribution.prob(classIndex, i);
    }
    return prob;
  }

  /**
   * method for returning information about this GraftSplit
   *
   * @param data instances for determining names of attributes and values
   * @return a string showing this GraftSplit's information
   */
  public String toString(Instances data) {
    String operator;
    switch (m_testType) {
      case TEST_LESS_OR_EQUAL:
        operator = " <= ";
        break;
      case TEST_GREATER:
        operator = " > ";
        break;
      case TEST_EQUAL:
        operator = " = ";
        break;
      default:
        operator = " != ";
        break;
    }

    String operand;
    if (data.attribute(m_attIndex).isNominal()) {
      operand = data.attribute(m_attIndex).value((int) m_splitPoint);
    } else {
      operand = Double.toString(m_splitPoint);
    }

    return data.attribute(m_attIndex).name() + operator + operand
      + " (" + Double.toString(m_laplace) + ") --> "
      + data.attribute(data.classIndex()).value(m_maxClass);
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.2 $");
  }
}