/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * LADTree.java * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.trees; import weka.classifiers.*; import weka.core.Capabilities; import weka.core.Capabilities.Capability; import weka.core.*; import weka.classifiers.trees.adtree.ReferenceInstances; import java.util.*; import java.io.*; import weka.core.TechnicalInformation; import weka.core.TechnicalInformationHandler; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; /** <!-- globalinfo-start --> * Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see<br/> * <br/> * Geoffrey Holmes, Bernhard Pfahringer, Richard Kirkby, Eibe Frank, Mark Hall: Multiclass alternating decision trees. In: ECML, 161-172, 2001. 
* <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @inproceedings{Holmes2001, * author = {Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall}, * booktitle = {ECML}, * pages = {161-172}, * publisher = {Springer}, * title = {Multiclass alternating decision trees}, * year = {2001} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -B <number of boosting iterations> * Number of boosting iterations. * (Default = 10)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * @author Richard Kirkby * @version $Revision: 6036 $ */ public class LADTree extends AbstractClassifier implements Drawable, AdditionalMeasureProducer, TechnicalInformationHandler { /** * For serialization */ private static final long serialVersionUID = -4940716114518300302L; // Constant from LogitBoost protected double Z_MAX = 4; // Number of classes protected int m_numOfClasses; // Instances as reference instances protected ReferenceInstances m_trainInstances; // Root of the tree protected PredictionNode m_root = null; // To keep track of the order in which splits are added protected int m_lastAddedSplitNum = 0; // Indices for numeric attributes protected int[] m_numericAttIndices; // Variables to keep track of best options protected double m_search_smallestLeastSquares; protected PredictionNode m_search_bestInsertionNode; protected Splitter m_search_bestSplitter; protected Instances m_search_bestPathInstances; // A collection of splitter nodes protected FastVector m_staticPotentialSplitters2way; // statistics protected int m_nodesExpanded = 0; protected int m_examplesCounted = 0; // options protected int m_boostingIterations = 10; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { 
return "Class for generating a multi-class alternating decision tree using " + "the LogitBoost strategy. For more info, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall"); result.setValue(Field.TITLE, "Multiclass alternating decision trees"); result.setValue(Field.BOOKTITLE, "ECML"); result.setValue(Field.YEAR, "2001"); result.setValue(Field.PAGES, "161-172"); result.setValue(Field.PUBLISHER, "Springer"); return result; } /** helper classes ********************************************************************/ protected class LADInstance extends DenseInstance { public double[] fVector; public double[] wVector; public double[] pVector; public double[] zVector; public LADInstance(Instance instance) { super(instance); // copy the instance setDataset(instance.dataset()); // preserve dataset // set up vectors fVector = new double[m_numOfClasses]; wVector = new double[m_numOfClasses]; pVector = new double[m_numOfClasses]; zVector = new double[m_numOfClasses]; // set initial probabilities double initProb = 1.0 / ((double) m_numOfClasses); for (int i=0; i<m_numOfClasses; i++) { pVector[i] = initProb; } updateZVector(); updateWVector(); } public void updateWeights(double[] fVectorIncrement) { for (int i=0; i<fVector.length; i++) { fVector[i] += fVectorIncrement[i]; } updateVectors(fVector); } public void updateVectors(double[] newFVector) { updatePVector(newFVector); updateZVector(); updateWVector(); } public void updatePVector(double[] newFVector) { double max = 
newFVector[Utils.maxIndex(newFVector)]; for (int i=0; i<pVector.length; i++) { pVector[i] = Math.exp(newFVector[i] - max); } Utils.normalize(pVector); } public void updateWVector() { for (int i=0; i<wVector.length; i++) { wVector[i] = (yVector(i) - pVector[i]) / zVector[i]; } } public void updateZVector() { for (int i=0; i<zVector.length; i++) { if (yVector(i) == 1) { zVector[i] = 1.0 / pVector[i]; if (zVector[i] > Z_MAX) { // threshold zVector[i] = Z_MAX; } } else { zVector[i] = -1.0 / (1.0 - pVector[i]); if (zVector[i] < -Z_MAX) { // threshold zVector[i] = -Z_MAX; } } } } public double yVector(int index) { return (index == (int) classValue() ? 1.0 : 0.0); } public Object copy() { LADInstance copy = new LADInstance((Instance) super.copy()); System.arraycopy(fVector, 0, copy.fVector, 0, fVector.length); System.arraycopy(wVector, 0, copy.wVector, 0, wVector.length); System.arraycopy(pVector, 0, copy.pVector, 0, pVector.length); System.arraycopy(zVector, 0, copy.zVector, 0, zVector.length); return copy; } public String toString() { StringBuffer text = new StringBuffer(); text.append(" * F("); for (int i=0; i<fVector.length; i++) { text.append(Utils.doubleToString(fVector[i], 3)); if (i<fVector.length-1) text.append(","); } text.append(") P("); for (int i=0; i<pVector.length; i++) { text.append(Utils.doubleToString(pVector[i], 3)); if (i<pVector.length-1) text.append(","); } text.append(") W("); for (int i=0; i<wVector.length; i++) { text.append(Utils.doubleToString(wVector[i], 3)); if (i<wVector.length-1) text.append(","); } text.append(")"); return super.toString() + text.toString(); } } protected class PredictionNode implements Serializable, Cloneable{ private double[] values; private FastVector children; // any number of splitter nodes public PredictionNode(double[] newValues) { values = new double[m_numOfClasses]; setValues(newValues); children = new FastVector(); } public void setValues(double[] newValues) { System.arraycopy(newValues, 0, values, 0, 
m_numOfClasses); } public double[] getValues() { return values; } public FastVector getChildren() { return children; } public Enumeration children() { return children.elements(); } public void addChild(Splitter newChild) { // merges, adds a clone (deep copy) Splitter oldEqual = null; for (Enumeration e = children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); if (newChild.equalTo(split)) { oldEqual = split; break; } } if (oldEqual == null) { Splitter addChild = (Splitter) newChild.clone(); addChild.orderAdded = ++m_lastAddedSplitNum; children.addElement(addChild); } else { // do a merge for (int i=0; i<newChild.getNumOfBranches(); i++) { PredictionNode oldPred = oldEqual.getChildForBranch(i); PredictionNode newPred = newChild.getChildForBranch(i); if (oldPred != null && newPred != null) oldPred.merge(newPred); } } } public Object clone() { // does a deep copy (recurses through tree) PredictionNode clone = new PredictionNode(values); // should actually clone once merges are enabled! 
for (Enumeration e = children.elements(); e.hasMoreElements(); ) clone.children.addElement((Splitter)((Splitter) e.nextElement()).clone()); return clone; } public void merge(PredictionNode merger) { // need to merge linear models here somehow for (int i=0; i<m_numOfClasses; i++) values[i] += merger.values[i]; for (Enumeration e = merger.children(); e.hasMoreElements(); ) { addChild((Splitter)e.nextElement()); } } } /** splitter classes ******************************************************************/ protected abstract class Splitter implements Serializable, Cloneable { protected int attIndex; public int orderAdded; public abstract int getNumOfBranches(); public abstract int branchInstanceGoesDown(Instance i); public abstract Instances instancesDownBranch(int branch, Instances sourceInstances); public abstract String attributeString(); public abstract String comparisonString(int branchNum); public abstract boolean equalTo(Splitter compare); public abstract void setChildForBranch(int branchNum, PredictionNode childPredictor); public abstract PredictionNode getChildForBranch(int branchNum); public abstract Object clone(); } protected class TwoWayNominalSplit extends Splitter { //private int attIndex; private int trueSplitValue; private PredictionNode[] children; public TwoWayNominalSplit(int _attIndex, int _trueSplitValue) { attIndex = _attIndex; trueSplitValue = _trueSplitValue; children = new PredictionNode[2]; } public int getNumOfBranches() { return 2; } public int branchInstanceGoesDown(Instance inst) { if (inst.isMissing(attIndex)) return -1; else if (inst.value(attIndex) == trueSplitValue) return 0; else return 1; } public Instances instancesDownBranch(int branch, Instances instances) { ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1); if (branch == -1) { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (inst.isMissing(attIndex)) 
filteredInstances.addReference(inst); } } else if (branch == 0) { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.isMissing(attIndex) && inst.value(attIndex) == trueSplitValue) filteredInstances.addReference(inst); } } else { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.isMissing(attIndex) && inst.value(attIndex) != trueSplitValue) filteredInstances.addReference(inst); } } return filteredInstances; } public String attributeString() { return m_trainInstances.attribute(attIndex).name(); } public String comparisonString(int branchNum) { Attribute att = m_trainInstances.attribute(attIndex); if (att.numValues() != 2) return ((branchNum == 0 ? "= " : "!= ") + att.value(trueSplitValue)); else return ("= " + (branchNum == 0 ? att.value(trueSplitValue) : att.value(trueSplitValue == 0 ? 1 : 0))); } public boolean equalTo(Splitter compare) { if (compare instanceof TwoWayNominalSplit) { // test object type TwoWayNominalSplit compareSame = (TwoWayNominalSplit) compare; return (attIndex == compareSame.attIndex && trueSplitValue == compareSame.trueSplitValue); } else return false; } public void setChildForBranch(int branchNum, PredictionNode childPredictor) { children[branchNum] = childPredictor; } public PredictionNode getChildForBranch(int branchNum) { return children[branchNum]; } public Object clone() { // deep copy TwoWayNominalSplit clone = new TwoWayNominalSplit(attIndex, trueSplitValue); if (children[0] != null) clone.setChildForBranch(0, (PredictionNode) children[0].clone()); if (children[1] != null) clone.setChildForBranch(1, (PredictionNode) children[1].clone()); return clone; } } protected class TwoWayNumericSplit extends Splitter implements Cloneable { //private int attIndex; private double splitPoint; private PredictionNode[] children; public TwoWayNumericSplit(int _attIndex, double _splitPoint) { 
attIndex = _attIndex; splitPoint = _splitPoint; children = new PredictionNode[2]; } public TwoWayNumericSplit(int _attIndex, Instances instances) throws Exception { attIndex = _attIndex; splitPoint = findSplit(instances, attIndex); children = new PredictionNode[2]; } public int getNumOfBranches() { return 2; } public int branchInstanceGoesDown(Instance inst) { if (inst.isMissing(attIndex)) return -1; else if (inst.value(attIndex) < splitPoint) return 0; else return 1; } public Instances instancesDownBranch(int branch, Instances instances) { ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1); if (branch == -1) { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (inst.isMissing(attIndex)) filteredInstances.addReference(inst); } } else if (branch == 0) { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.isMissing(attIndex) && inst.value(attIndex) < splitPoint) filteredInstances.addReference(inst); } } else { for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.isMissing(attIndex) && inst.value(attIndex) >= splitPoint) filteredInstances.addReference(inst); } } return filteredInstances; } public String attributeString() { return m_trainInstances.attribute(attIndex).name(); } public String comparisonString(int branchNum) { return ((branchNum == 0 ? 
"< " : ">= ") + Utils.doubleToString(splitPoint, 3)); } public boolean equalTo(Splitter compare) { if (compare instanceof TwoWayNumericSplit) { // test object type TwoWayNumericSplit compareSame = (TwoWayNumericSplit) compare; return (attIndex == compareSame.attIndex && splitPoint == compareSame.splitPoint); } else return false; } public void setChildForBranch(int branchNum, PredictionNode childPredictor) { children[branchNum] = childPredictor; } public PredictionNode getChildForBranch(int branchNum) { return children[branchNum]; } public Object clone() { // deep copy TwoWayNumericSplit clone = new TwoWayNumericSplit(attIndex, splitPoint); if (children[0] != null) clone.setChildForBranch(0, (PredictionNode) children[0].clone()); if (children[1] != null) clone.setChildForBranch(1, (PredictionNode) children[1].clone()); return clone; } private double findSplit(Instances instances, int index) throws Exception { double splitPoint = 0; double bestVal = Double.MAX_VALUE, currVal, currCutPoint; int numMissing = 0; double[][] distribution = new double[3][instances.numClasses()]; // Compute counts for all the values for (int i = 0; i < instances.numInstances(); i++) { Instance inst = instances.instance(i); if (!inst.isMissing(index)) { distribution[1][(int)inst.classValue()] ++; } else { distribution[2][(int)inst.classValue()] ++; numMissing++; } } // Sort instances instances.sort(index); // Make split counts for each possible split and evaluate for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) { Instance inst = instances.instance(i); Instance instPlusOne = instances.instance(i + 1); distribution[0][(int)inst.classValue()] += inst.weight(); distribution[1][(int)inst.classValue()] -= inst.weight(); if (Utils.sm(inst.value(index), instPlusOne.value(index))) { currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0; currVal = ContingencyTables.entropyConditionedOnRows(distribution); if (Utils.sm(currVal, bestVal)) { splitPoint = currCutPoint; 
bestVal = currVal; } } } return splitPoint; } } /** * Sets up the tree ready to be trained. * * @param instances the instances to train the tree with * @exception Exception if training data is unsuitable */ public void initClassifier(Instances instances) throws Exception { // clear stats m_nodesExpanded = 0; m_examplesCounted = 0; m_lastAddedSplitNum = 0; m_numOfClasses = instances.numClasses(); // make sure training data is suitable if (instances.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } if (!instances.classAttribute().isNominal()) { throw new Exception("Class must be nominal!"); } // create training set (use LADInstance class) m_trainInstances = new ReferenceInstances(instances, instances.numInstances()); for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.classIsMissing()) { LADInstance adtInst = new LADInstance(inst); m_trainInstances.addReference(adtInst); adtInst.setDataset(m_trainInstances); } } // create the root prediction node m_root = new PredictionNode(new double[m_numOfClasses]); // pre-calculate what we can generateStaticPotentialSplittersAndNumericIndices(); } public void next(int iteration) throws Exception { boost(); } public void done() throws Exception {} /** * Performs a single boosting iteration. * Will add a new splitter node and two prediction nodes to the tree * (unless merging takes place). 
* * @exception Exception if try to boost without setting up tree first */ private void boost() throws Exception { if (m_trainInstances == null) throw new Exception("Trying to boost with no training data"); // perform the search searchForBestTest(); if (m_Debug) { System.out.println("Best split found: " + m_search_bestSplitter.getNumOfBranches() + "-way split on " + m_search_bestSplitter.attributeString() //+ "\nsmallestLeastSquares = " + m_search_smallestLeastSquares); + "\nBestGain = " + m_search_smallestLeastSquares); } if (m_search_bestSplitter == null) return; // handle empty instances // create the new nodes for the tree, updating the weights for (int i=0; i<m_search_bestSplitter.getNumOfBranches(); i++) { Instances applicableInstances = m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathInstances); double[] predictionValues = calcPredictionValues(applicableInstances); PredictionNode newPredictor = new PredictionNode(predictionValues); updateWeights(applicableInstances, predictionValues); m_search_bestSplitter.setChildForBranch(i, newPredictor); } // insert the new nodes m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter); if (m_Debug) { System.out.println("Tree is now:\n" + toString(m_root, 1) + "\n"); //System.out.println("Instances are now:\n" + m_trainInstances + "\n"); } // free memory m_search_bestPathInstances = null; } private void updateWeights(Instances instances, double[] newPredictionValues) { for (int i=0; i<instances.numInstances(); i++) ((LADInstance) instances.instance(i)).updateWeights(newPredictionValues); } /** * Generates the m_staticPotentialSplitters2way * vector to contain all possible nominal splits, and the m_numericAttIndices array to * index the numeric attributes in the training data. 
* */ private void generateStaticPotentialSplittersAndNumericIndices() { m_staticPotentialSplitters2way = new FastVector(); FastVector numericIndices = new FastVector(); for (int i=0; i<m_trainInstances.numAttributes(); i++) { if (i == m_trainInstances.classIndex()) continue; if (m_trainInstances.attribute(i).isNumeric()) numericIndices.addElement(new Integer(i)); else { int numValues = m_trainInstances.attribute(i).numValues(); if (numValues == 2) // avoid redundancy due to 2-way symmetry m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, 0)); else for (int j=0; j<numValues; j++) m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, j)); } } m_numericAttIndices = new int[numericIndices.size()]; for (int i=0; i<numericIndices.size(); i++) m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue(); } /** * Performs a search for the best test (splitter) to add to the tree, by looking * for the largest weight change. * * @exception Exception if search fails */ private void searchForBestTest() throws Exception { if (m_Debug) { System.out.println("Searching for best split..."); } m_search_smallestLeastSquares = 0.0; //Double.POSITIVE_INFINITY; searchForBestTest(m_root, m_trainInstances); } /** * Recursive function that carries out search for the best test (splitter) to add to * this part of the tree, by looking for the largest weight change. Will try 2-way * and/or multi-way splits depending on the current state. 
* * @param currentNode the root of the subtree to be searched, and the current node * being considered as parent of a new split * @param instances the instances that apply at this node * @exception Exception if search fails */ private void searchForBestTest(PredictionNode currentNode, Instances instances) throws Exception { // keep stats m_nodesExpanded++; m_examplesCounted += instances.numInstances(); // evaluate static splitters (nominal) for (Enumeration e = m_staticPotentialSplitters2way.elements(); e.hasMoreElements(); ) { evaluateSplitter((Splitter) e.nextElement(), currentNode, instances); } if (m_Debug) { //System.out.println("Instances considered are: " + instances); } // evaluate dynamic splitters (numeric) for (int i=0; i<m_numericAttIndices.length; i++) { evaluateNumericSplit(currentNode, instances, m_numericAttIndices[i]); } if (currentNode.getChildren().size() == 0) return; // keep searching goDownAllPaths(currentNode, instances); } /** * Continues general multi-class search by investigating every node in the * subtree under currentNode. * * @param currentNode the root of the subtree to be searched * @param instances the instances that apply at this node * @exception Exception if search fails */ private void goDownAllPaths(PredictionNode currentNode, Instances instances) throws Exception { for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int i=0; i<split.getNumOfBranches(); i++) searchForBestTest(split.getChildForBranch(i), split.instancesDownBranch(i, instances)); } } /** * Investigates the option of introducing a split under currentNode. If the * split creates a weight change that is larger than has already been found it will * update the search information to record this as the best option so far. 
* * @param split the splitter node to evaluate * @param currentNode the parent under which the split is to be considered * @param instances the instances that apply at this node * @exception Exception if something goes wrong */ private void evaluateSplitter(Splitter split, PredictionNode currentNode, Instances instances) throws Exception { double leastSquares = leastSquaresNonMissing(instances,split.attIndex); for (int i=0; i<split.getNumOfBranches(); i++) leastSquares -= leastSquares(split.instancesDownBranch(i, instances)); if (m_Debug) { //System.out.println("Instances considered are: " + instances); System.out.print(split.getNumOfBranches() + "-way split on " + split.attributeString() + " has leastSquares value of " + Utils.doubleToString(leastSquares,3)); } if (leastSquares > m_search_smallestLeastSquares) { if (m_Debug) { System.out.print(" (best so far)"); } m_search_smallestLeastSquares = leastSquares; m_search_bestInsertionNode = currentNode; m_search_bestSplitter = split; m_search_bestPathInstances = instances; } if (m_Debug) { System.out.print("\n"); } } private void evaluateNumericSplit(PredictionNode currentNode, Instances instances, int attIndex) { double[] splitAndLS = findNumericSplitpointAndLS(instances, attIndex); double gain = leastSquaresNonMissing(instances,attIndex) - splitAndLS[1]; if (m_Debug) { //System.out.println("Instances considered are: " + instances); System.out.print("Numeric split on " + instances.attribute(attIndex).name() + " has leastSquares value of " //+ Utils.doubleToString(splitAndLS[1],3)); + Utils.doubleToString(gain,3)); } if (gain > m_search_smallestLeastSquares) { if (m_Debug) { System.out.print(" (best so far)"); } m_search_smallestLeastSquares = gain; //splitAndLS[1]; m_search_bestInsertionNode = currentNode; m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndLS[0]);; m_search_bestPathInstances = instances; } if (m_Debug) { System.out.print("\n"); } } private double[] findNumericSplitpointAndLS(Instances 
instances, int attIndex) { double allLS = leastSquares(instances); // all instances in right subset double[] term1L = new double[m_numOfClasses]; double[] term2L = new double[m_numOfClasses]; double[] term3L = new double[m_numOfClasses]; double[] meanNumL = new double[m_numOfClasses]; double[] meanDenL = new double[m_numOfClasses]; double[] term1R = new double[m_numOfClasses]; double[] term2R = new double[m_numOfClasses]; double[] term3R = new double[m_numOfClasses]; double[] meanNumR = new double[m_numOfClasses]; double[] meanDenR = new double[m_numOfClasses]; double temp1, temp2, temp3; double[] classMeans = new double[m_numOfClasses]; double[] classTotals = new double[m_numOfClasses]; // fill up RHS for (int j=0; j<m_numOfClasses; j++) { for (int i=0; i<instances.numInstances(); i++) { LADInstance inst = (LADInstance) instances.instance(i); temp1 = inst.wVector[j] * inst.zVector[j]; term1R[j] += temp1 * inst.zVector[j]; term2R[j] += temp1; term3R[j] += inst.wVector[j]; meanNumR[j] += inst.wVector[j] * inst.zVector[j]; } } //leastSquares = term1 - (2.0 * u * term2) + (u * u * term3); double leastSquares; boolean newSplit; double smallestLeastSquares = Double.POSITIVE_INFINITY; double bestSplit = 0.0; double meanL, meanR; instances.sort(attIndex); for (int i=0; i<instances.numInstances()-1; i++) {// shift inst from right to left if (instances.instance(i+1).isMissing(attIndex)) break; if (instances.instance(i+1).value(attIndex) > instances.instance(i).value(attIndex)) newSplit = true; else newSplit = false; LADInstance inst = (LADInstance) instances.instance(i); leastSquares = 0.0; for (int j=0; j<m_numOfClasses; j++) { temp1 = inst.wVector[j] * inst.zVector[j]; temp2 = temp1 * inst.zVector[j]; temp3 = inst.wVector[j] * inst.zVector[j]; term1L[j] += temp2; term2L[j] += temp1; term3L[j] += inst.wVector[j]; term1R[j] -= temp2; term2R[j] -= temp1; term3R[j] -= inst.wVector[j]; meanNumL[j] += temp3; meanNumR[j] -= temp3; if (newSplit) { meanL = meanNumL[j] / term3L[j]; 
meanR = meanNumR[j] / term3R[j]; leastSquares += term1L[j] - (2.0 * meanL * term2L[j]) + (meanL * meanL * term3L[j]); leastSquares += term1R[j] - (2.0 * meanR * term2R[j]) + (meanR * meanR * term3R[j]); } } if (m_Debug && newSplit) System.out.println(attIndex + "/" + ((instances.instance(i).value(attIndex) + instances.instance(i+1).value(attIndex)) / 2.0) + " = " + (allLS - leastSquares)); if (newSplit && leastSquares < smallestLeastSquares) { bestSplit = (instances.instance(i).value(attIndex) + instances.instance(i+1).value(attIndex)) / 2.0; smallestLeastSquares = leastSquares; } } double[] result = new double[2]; result[0] = bestSplit; result[1] = smallestLeastSquares > 0 ? smallestLeastSquares : 0; return result; } private double leastSquares(Instances instances) { double numerator=0, denominator=0, w, t; double[] classMeans = new double[m_numOfClasses]; double[] classTotals = new double[m_numOfClasses]; for (int i=0; i<instances.numInstances(); i++) { LADInstance inst = (LADInstance) instances.instance(i); for (int j=0; j<m_numOfClasses; j++) { classMeans[j] += inst.zVector[j] * inst.wVector[j]; classTotals[j] += inst.wVector[j]; } } double numInstances = (double) instances.numInstances(); for (int j=0; j<m_numOfClasses; j++) { if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; } for (int i=0; i<instances.numInstances(); i++) for (int j=0; j<m_numOfClasses; j++) { LADInstance inst = (LADInstance) instances.instance(i); w = inst.wVector[j]; t = inst.zVector[j] - classMeans[j]; numerator += w * (t * t); denominator += w; } //System.out.println(numerator + " / " + denominator); return numerator > 0 ? 
numerator : 0;// / denominator; } private double leastSquaresNonMissing(Instances instances, int attIndex) { double numerator=0, denominator=0, w, t; double[] classMeans = new double[m_numOfClasses]; double[] classTotals = new double[m_numOfClasses]; for (int i=0; i<instances.numInstances(); i++) { LADInstance inst = (LADInstance) instances.instance(i); for (int j=0; j<m_numOfClasses; j++) { classMeans[j] += inst.zVector[j] * inst.wVector[j]; classTotals[j] += inst.wVector[j]; } } double numInstances = (double) instances.numInstances(); for (int j=0; j<m_numOfClasses; j++) { if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; } for (int i=0; i<instances.numInstances(); i++) for (int j=0; j<m_numOfClasses; j++) { LADInstance inst = (LADInstance) instances.instance(i); if(!inst.isMissing(attIndex)) { w = inst.wVector[j]; t = inst.zVector[j] - classMeans[j]; numerator += w * (t * t); denominator += w; } } //System.out.println(numerator + " / " + denominator); return numerator > 0 ? numerator : 0;// / denominator; } private double[] calcPredictionValues(Instances instances) { double[] classMeans = new double[m_numOfClasses]; double meansSum = 0; double multiplier = ((double) (m_numOfClasses-1)) / ((double) (m_numOfClasses)); double[] classTotals = new double[m_numOfClasses]; for (int i=0; i<instances.numInstances(); i++) { LADInstance inst = (LADInstance) instances.instance(i); for (int j=0; j<m_numOfClasses; j++) { classMeans[j] += inst.zVector[j] * inst.wVector[j]; classTotals[j] += inst.wVector[j]; } } double numInstances = (double) instances.numInstances(); for (int j=0; j<m_numOfClasses; j++) { if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; meansSum += classMeans[j]; } meansSum /= m_numOfClasses; for (int j=0; j<m_numOfClasses; j++) { classMeans[j] = multiplier * (classMeans[j] - meansSum); } return classMeans; } /** * Returns the class probability distribution for an instance. 
* * @param instance the instance to be classified * @return the distribution the tree generates for the instance */ public double[] distributionForInstance(Instance instance) { double[] predValues = new double[m_numOfClasses]; for (int i=0; i<m_numOfClasses; i++) predValues[i] = 0.0; double[] distribution = predictionValuesForInstance(instance, m_root, predValues); double max = distribution[Utils.maxIndex(distribution)]; for (int i=0; i<m_numOfClasses; i++) { distribution[i] = Math.exp(distribution[i] - max); } double sum = Utils.sum(distribution); if (sum > 0.0) Utils.normalize(distribution, sum); return distribution; } /** * Returns the class prediction values (votes) for an instance. * * @param inst the instance * @param currentNode the root of the tree to get the values from * @param currentValues the current values before adding the values contained in the * subtree * @return the class prediction values (votes) */ private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode, double[] currentValues) { double[] predValues = currentNode.getValues(); for (int i=0; i<m_numOfClasses; i++) currentValues[i] += predValues[i]; //for (int i=0; i<m_numOfClasses; i++) currentValues[i] = predValues[i]; for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); int branch = split.branchInstanceGoesDown(inst); if (branch >= 0) currentValues = predictionValuesForInstance(inst, split.getChildForBranch(branch), currentValues); } return currentValues; } /** model output functions ************************************************************/ /** * Returns a description of the classifier. 
* * @return a string containing a description of the classifier */ public String toString() { String className = getClass().getName(); if (m_root == null) return (className +" not built yet"); else { return (className + ":\n\n" + toString(m_root, 1) + "\nLegend: " + legend() + "\n#Tree size (total): " + numOfAllNodes(m_root) + "\n#Tree size (number of predictor nodes): " + numOfPredictionNodes(m_root) + "\n#Leaves (number of predictor nodes): " + numOfLeafNodes(m_root) + "\n#Expanded nodes: " + m_nodesExpanded + "\n#Processed examples: " + m_examplesCounted + "\n#Ratio e/n: " + ((double)m_examplesCounted/(double)m_nodesExpanded) ); } } /** * Traverses the tree, forming a string that describes it. * * @param currentNode the current node under investigation * @param level the current level in the tree * @return the string describing the subtree */ private String toString(PredictionNode currentNode, int level) { StringBuffer text = new StringBuffer(); text.append(": "); double[] predValues = currentNode.getValues(); for (int i=0; i<m_numOfClasses; i++) { text.append(Utils.doubleToString(predValues[i],3)); if (i<m_numOfClasses-1) text.append(","); } for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int j=0; j<split.getNumOfBranches(); j++) { PredictionNode child = split.getChildForBranch(j); if (child != null) { text.append("\n"); for (int k = 0; k < level; k++) { text.append("| "); } text.append("(" + split.orderAdded + ")"); text.append(split.attributeString() + " " + split.comparisonString(j)); text.append(toString(child, level + 1)); } } } return text.toString(); } /** * Returns graph describing the tree. 
* * @return the graph of the tree in dotty format * @exception Exception if something goes wrong */ public String graph() throws Exception { StringBuffer text = new StringBuffer(); text.append("digraph ADTree {\n"); //text.append("center=true\nsize=\"8.27,11.69\"\n"); graphTraverse(m_root, text, 0, 0); return text.toString() +"}\n"; } /** * Traverses the tree, graphing each node. * * @param currentNode the currentNode under investigation * @param text the string built so far * @param splitOrder the order the parent splitter was added to the tree * @param predOrder the order this predictor was added to the split * @exception Exception if something goes wrong */ protected void graphTraverse(PredictionNode currentNode, StringBuffer text, int splitOrder, int predOrder) throws Exception { text.append("S" + splitOrder + "P" + predOrder + " [label=\""); double[] predValues = currentNode.getValues(); for (int i=0; i<m_numOfClasses; i++) { text.append(Utils.doubleToString(predValues[i],3)); if (i<m_numOfClasses-1) text.append(","); } if (splitOrder == 0) // show legend in root text.append(" (" + legend() + ")"); text.append("\" shape=box style=filled]\n"); for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded + " [style=dotted]\n"); text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " + split.attributeString() + "\"]\n"); for (int i=0; i<split.getNumOfBranches(); i++) { PredictionNode child = split.getChildForBranch(i); if (child != null) { text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i + " [label=\"" + split.comparisonString(i) + "\"]\n"); graphTraverse(child, text, split.orderAdded, i); } } } } /** * Returns the legend of the tree, describing how results are to be interpreted. 
* * @return a string containing the legend of the classifier */ public String legend() { Attribute classAttribute = null; if (m_trainInstances == null) return ""; try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){}; if (m_numOfClasses == 1) { return ("-ve = " + classAttribute.value(0) + ", +ve = " + classAttribute.value(1)); } else { StringBuffer text = new StringBuffer(); for (int i=0; i<m_numOfClasses; i++) { if (i>0) text.append(", "); text.append(classAttribute.value(i)); } return text.toString(); } } /** option handling ******************************************************************/ /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numOfBoostingIterationsTipText() { return "The number of boosting iterations to use, which determines the size of the tree."; } /** * Gets the number of boosting iterations. * * @return the number of boosting iterations */ public int getNumOfBoostingIterations() { return m_boostingIterations; } /** * Sets the number of boosting iterations. * * @param b the number of boosting iterations to use */ public void setNumOfBoostingIterations(int b) { m_boostingIterations = b; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(1); newVector.addElement(new Option( "\tNumber of boosting iterations.\n" +"\t(Default = 10)", "B", 1,"-B <number of boosting iterations>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. 
Valid options are:<p> * * -B num <br> * Set the number of boosting iterations * (default 10) <p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String bString = Utils.getOption('B', options); if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString)); super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Gets the current settings of ADTree. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions() { String[] options = new String[2 + super.getOptions().length]; int current = 0; options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations(); System.arraycopy(super.getOptions(), 0, options, current, super.getOptions().length); while (current < options.length) options[current++] = ""; return options; } /** additional measures ***************************************************************/ /** * Calls measure function for tree size. * * @return the tree size */ public double measureTreeSize() { return numOfAllNodes(m_root); } /** * Calls measure function for leaf size. * * @return the leaf size */ public double measureNumLeaves() { return numOfPredictionNodes(m_root); } /** * Calls measure function for leaf size. * * @return the leaf size */ public double measureNumPredictionLeaves() { return numOfLeafNodes(m_root); } /** * Returns the number of nodes expanded. * * @return the number of nodes expanded during search */ public double measureNodesExpanded() { return m_nodesExpanded; } /** * Returns the number of examples "counted". * * @return the number of nodes processed during search */ public double measureExamplesCounted() { return m_examplesCounted; } /** * Returns an enumeration of the additional measure names. 
* * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(5); newVector.addElement("measureTreeSize"); newVector.addElement("measureNumLeaves"); newVector.addElement("measureNumPredictionLeaves"); newVector.addElement("measureNodesExpanded"); newVector.addElement("measureExamplesCounted"); return newVector.elements(); } /** * Returns the value of the named measure. * * @param measureName the name of the measure to query for its value * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.equals("measureTreeSize")) { return measureTreeSize(); } else if (additionalMeasureName.equals("measureNodesExpanded")) { return measureNodesExpanded(); } else if (additionalMeasureName.equals("measureNumLeaves")) { return measureNumLeaves(); } else if (additionalMeasureName.equals("measureNumPredictionLeaves")) { return measureNumPredictionLeaves(); } else if (additionalMeasureName.equals("measureExamplesCounted")) { return measureExamplesCounted(); } else {throw new IllegalArgumentException(additionalMeasureName + " not supported (ADTree)"); } } /** * Returns the number of prediction nodes in a tree. * * @param root the root of the tree being measured * @return tree size in number of prediction nodes */ protected int numOfPredictionNodes(PredictionNode root) { int numSoFar = 0; if (root != null) { numSoFar++; for (Enumeration e = root.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int i=0; i<split.getNumOfBranches(); i++) numSoFar += numOfPredictionNodes(split.getChildForBranch(i)); } } return numSoFar; } /** * Returns the number of leaf nodes in a tree. 
* * @param root the root of the tree being measured * @return tree leaf size in number of prediction nodes */ protected int numOfLeafNodes(PredictionNode root) { int numSoFar = 0; if (root.getChildren().size() > 0) { for (Enumeration e = root.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int i=0; i<split.getNumOfBranches(); i++) numSoFar += numOfLeafNodes(split.getChildForBranch(i)); } } else numSoFar = 1; return numSoFar; } /** * Returns the total number of nodes in a tree. * * @param root the root of the tree being measured * @return tree size in number of splitter + prediction nodes */ protected int numOfAllNodes(PredictionNode root) { int numSoFar = 0; if (root != null) { numSoFar++; for (Enumeration e = root.children(); e.hasMoreElements(); ) { numSoFar++; Splitter split = (Splitter) e.nextElement(); for (int i=0; i<split.getNumOfBranches(); i++) numSoFar += numOfAllNodes(split.getChildForBranch(i)); } } return numSoFar; } /** main functions ********************************************************************/ /** * Builds a classifier for a set of instances. * * @param instances the instances to train the classifier with * @exception Exception if something goes wrong */ public void buildClassifier(Instances instances) throws Exception { // set up the tree initClassifier(instances); // build the tree for (int T = 0; T < m_boostingIterations; T++) { boost(); } } public int predictiveError(Instances test) { int error = 0; for(int i = test.numInstances()-1; i>=0; i--) { Instance inst = test.instance(i); try { if (classifyInstance(inst) != inst.classValue()) error++; } catch (Exception e) { error++;} } return error; } /** * Merges two trees together. Modifies the tree being acted on, leaving tree passed * as a parameter untouched (cloned). Does not check to see whether training instances * are compatible - strange things could occur if they are not. 
*
   * @param mergeWith the tree to merge with
   * @exception Exception if merge could not be performed
   */
  public void merge(LADTree mergeWith) throws Exception {
    boolean eitherUnbuilt = (m_root == null) || (mergeWith.m_root == null);
    if (eitherUnbuilt) {
      throw new Exception("Trying to merge an uninitialized tree");
    }
    if (mergeWith.m_numOfClasses != m_numOfClasses) {
      throw new Exception("Trees not suitable for merge - "
                          + "different sized prediction nodes");
    }
    m_root.merge(mergeWith.m_root);
  }

  /**
   * Returns the type of graph this classifier represents.
   *
   * @return Drawable.TREE
   */
  public int graphType() {
    return Drawable.TREE;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6036 $");
  }

  /**
   * Returns default capabilities of the classifier: nominal, numeric and
   * date attributes with missing values, and a nominal class that may have
   * missing values.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    Capability[] supported = {
      // attributes
      Capability.NOMINAL_ATTRIBUTES,
      Capability.NUMERIC_ATTRIBUTES,
      Capability.DATE_ATTRIBUTES,
      Capability.MISSING_VALUES,
      // class
      Capability.NOMINAL_CLASS,
      Capability.MISSING_CLASS_VALUES
    };
    for (int i = 0; i < supported.length; i++) {
      result.enable(supported[i]);
    }
    return result;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    runClassifier(new LADTree(), argv);
  }
}