/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.classify; import java.io.Serializable; import java.util.logging.Logger; import cc.mallet.pipe.Pipe; import cc.mallet.types.Alphabet; import cc.mallet.types.AugmentableFeatureVector; import cc.mallet.types.FeatureSelection; import cc.mallet.types.FeatureVector; import cc.mallet.types.InfoGain; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.Labeling; import cc.mallet.util.MalletLogger; /** Decision Tree classifier. @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */ public class DecisionTree extends Classifier implements Serializable //implements InduceFeatures { private static final long serialVersionUID = 1L; private static Logger logger = MalletLogger.getLogger(DecisionTree.class.getName()); Node root; public DecisionTree (Pipe instancePipe, DecisionTree.Node root) { super (instancePipe); this.root = root; } public Node getRoot () { return root; } private Node getLeaf (Node node, FeatureVector fv) { if (node.child0 == null) return node; else if (fv.value (node.featureIndex) != 0) return getLeaf (node.child1, fv); else return getLeaf (node.child0, fv); } public Classification classify (Instance instance) { FeatureVector fv = (FeatureVector) instance.getData (); assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ()); Node leaf = getLeaf (root, fv); return new Classification (instance, this, leaf.labeling); } // Entropy of 1.0 would say that it take "one bit" to indicate the correct class, // e.g. that there is a 50/50 split between two classes given a particular feature public double addFeaturesClassEntropyThreshold = 0.7; public void induceFeatures (InstanceList ilist, boolean withFeatureShrinkage, boolean inducePerClassFeatures) { if (inducePerClassFeatures) { int numClasses = ilist.getTargetAlphabet().size(); // int numFeatures = ilist.getDataAlphabet().size(); FeatureSelection[] pcfs = new FeatureSelection[numClasses]; for (int j = 0; j < numClasses; j++) pcfs[j] = (FeatureSelection) ilist.getPerLabelFeatureSelection()[j].clone(); for (int i = 0; i < ilist.size(); i++) { Object data = ilist.get(i).getData(); AugmentableFeatureVector afv = (AugmentableFeatureVector) data; root.induceFeatures (afv, null, pcfs, ilist.getFeatureSelection(), ilist.getPerLabelFeatureSelection(), withFeatureShrinkage, inducePerClassFeatures, addFeaturesClassEntropyThreshold); } } else { throw new UnsupportedOperationException ("Not yet implemented"); } } public static class Node implements Serializable { private static final long serialVersionUID = 1L; int featureIndex; // the feature on which the children (would) distinguish double infoGain; // the information gain of splitting on this feature InstanceList ilist; Alphabet dictionary; double labelEntropy; // the class label entropy of data in this (unsplit) node Labeling labeling; // the class label distribution in the node (unsplit) Node parent, child0, child1; String name; // xxx Also calculate some sort of inverted entropy for feature induction, // in order to find the one class needs a new feature with a negative weight. public Node (InstanceList ilist, Node parent, FeatureSelection fs) { InfoGain ig = new InfoGain (ilist); this.featureIndex = ig.getMaxValuedIndexIn (fs); this.infoGain = ig.value(featureIndex); this.ilist = ilist; this.dictionary = ilist.getDataAlphabet(); this.parent = parent; this.labeling = ig.getBaseLabelDistribution(); this.labelEntropy = ig.getBaseEntropy(); this.child0 = this.child1 = null; } /** The root has depth zero. */ public int depth () { int depth = 0; Node p = parent; while (p != null) { p = p.parent; depth++; } return depth; } public boolean isLeaf () { return (child0 == null && child1 == null); } public boolean isRoot () { return parent == null; } public Node getFeatureAbsentChild () { return child0; } public Node getFeaturePresentChild () { return child1; } public double getSplitInfoGain () { return infoGain; } public Object getSplitFeature () { return ilist.getDataAlphabet().lookupObject(featureIndex); } public void split (FeatureSelection fs) { if (ilist == null) throw new IllegalStateException ("Frozen. Cannot split."); InstanceList ilist0 = new InstanceList (ilist.getPipe()); InstanceList ilist1 = new InstanceList (ilist.getPipe()); for (int i = 0; i < ilist.size(); i++) { Instance instance = ilist.get(i); FeatureVector fv = (FeatureVector) instance.getData (); // xxx What test should this be? What to do with negative values? // Whatever is decided here should also go in InfoGain.calcInfoGains() if (fv.value (featureIndex) != 0) { //System.out.println ("list1 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i)); ilist1.add (instance, ilist.getInstanceWeight(i)); } else { //System.out.println ("list0 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i)); ilist0.add (instance, ilist.getInstanceWeight(i)); } } logger.info("child0="+ilist0.size()+" child1="+ilist1.size()); child0 = new Node (ilist0, this, fs); child1 = new Node (ilist1, this, fs); } // Saves memory by allowing ilist to be garbage collected public void stopGrowth () { if (child0 != null) { child0.stopGrowth(); child1.stopGrowth(); } ilist = null; } public void induceFeatures (AugmentableFeatureVector afv, FeatureSelection featuresAlreadyThere, FeatureSelection[] perClassFeaturesAlreadyThere, FeatureSelection newFeatureSelection, FeatureSelection[] perClassNewFeatureSelection, boolean withInteriorNodes, boolean addPerClassFeatures, double classEntropyThreshold) { if (!isRoot() && (isLeaf() || withInteriorNodes) && labelEntropy < classEntropyThreshold) { String name = getName(); logger.info("Trying to add feature "+name); //int conjunctionIndex = afv.getAlphabet().lookupIndex (name, false); if (addPerClassFeatures) { int classIndex = labeling.getBestIndex(); if (!perClassFeaturesAlreadyThere[classIndex].contains (name)) { afv.add (name, 1.0); perClassNewFeatureSelection[classIndex].add (name); } } else { throw new UnsupportedOperationException ("Not yet implemented."); } } boolean featurePresent = afv.value (featureIndex) != 0; if (child0 != null && !featurePresent) child0.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere, newFeatureSelection, perClassNewFeatureSelection, withInteriorNodes, addPerClassFeatures, classEntropyThreshold); if (child1 != null && featurePresent) child1.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere, newFeatureSelection, perClassNewFeatureSelection, withInteriorNodes, addPerClassFeatures, classEntropyThreshold); } public String getName () { // String prefix; if (parent == null) return "root"; else if (parent.parent == null) { if (parent.getFeaturePresentChild() == this) return dictionary.lookupObject(parent.featureIndex).toString(); else { assert (dictionary != null); assert (dictionary.lookupObject(parent.featureIndex) != null); return "!" + dictionary.lookupObject(parent.featureIndex).toString(); } } else { if (parent.getFeaturePresentChild() == this) return parent.getName() + "&" + dictionary.lookupObject(parent.featureIndex).toString(); else return parent.getName() + "&!" + dictionary.lookupObject(parent.featureIndex).toString(); } } public void print () { if (child0 == null) System.out.println (getName() + ": " + labeling.getBestLabel()); else { child0.print(); child1.print(); } } } }