/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.classify; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.logging.Logger; import cc.mallet.classify.Boostable; import cc.mallet.classify.Classification; import cc.mallet.classify.Classifier; import cc.mallet.pipe.Pipe; import cc.mallet.types.Alphabet; import cc.mallet.types.FeatureVector; import cc.mallet.types.GainRatio; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.util.MalletLogger; import cc.mallet.util.Maths; /** * A C4.5 Decision Tree classifier. 
* <p>The tree is binary. Each internal {@link Node} compares the value of one
 * feature with a threshold. An instance goes to the left child when its value
 * is {@code <=} the threshold, and to the right child otherwise. An instance is
 * classified with the base label distribution stored at the leaf it reaches.
 *
 * @see C45Trainer
 * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
 */
public class C45 extends Classifier implements Boostable, Serializable {
  private static final Logger logger = MalletLogger.getLogger(C45.class.getName());

  Node m_root;

  /**
   * Creates a classifier around an already-grown tree.
   *
   * @param instancePipe the pipe associated with this classifier
   * @param root the root node of the tree
   */
  public C45(Pipe instancePipe, C45.Node root) {
    super(instancePipe);
    m_root = root;
  }

  /** @return the root node of the tree (may be {@code null}) */
  public Node getRoot() {
    return m_root;
  }

  /**
   * Walks from {@code node} down to the leaf that {@code fv} falls into,
   * using the same {@code <=} test as {@link Node#split()}.
   */
  private Node getLeaf(Node node, FeatureVector fv) {
    Node current = node;
    while (!current.isLeaf()) {
      GainRatio split = current.getGainRatio();
      current = fv.value(split.getMaxValuedIndex()) <= split.getMaxValuedThreshold()
          ? current.getLeftChild()
          : current.getRightChild();
    }
    return current;
  }

  /**
   * Classifies an instance by routing its feature vector to a leaf and
   * returning that leaf's base label distribution.
   *
   * @param instance an instance whose data is a {@link FeatureVector}
   * @return the classification produced by the reached leaf
   * @throws IllegalStateException if this tree has no root node
   */
  @Override
  public Classification classify(Instance instance) {
    if (m_root == null) {
      throw new IllegalStateException("Cannot classify: this C45 tree has no root node.");
    }
    FeatureVector fv = (FeatureVector) instance.getData();
    assert (instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
    Node leaf = getLeaf(m_root, fv);
    return new Classification(instance, this, leaf.getGainRatio().getBaseLabelDistribution());
  }

  /**
   * Prunes the tree bottom-up using the minimum description length (MDL)
   * criterion; see {@link Node#computeCostAndPrune()}. Does nothing for an
   * empty tree.
   *
   * <p>Must be called before {@link Node#stopGrowth()}, because computing the
   * MDL cost needs each node's instance list.
   */
  public void prune() {
    Node root = getRoot();
    if (root != null) {
      root.computeCostAndPrune();
    }
  }

  /**
   * Returns the number of internal (splitting) nodes in this tree. The root is
   * counted even when it is itself a leaf.
   *
   * <p>Leaves below the root are NOT counted, because
   * {@link Node#getNumDescendants()} counts only non-leaf descendants. A
   * single-leaf tree has size 1, and an empty tree (null root) has size 0.
   *
   * @return 0 for an empty tree, otherwise 1 plus the root's number of
   *         non-leaf descendants
   */
  public int getSize() {
    Node root = getRoot();
    if (root == null) {
      return 0;
    }
    return 1 + root.getNumDescendants();
  }

  /** Prints the tree to standard output; does nothing for an empty tree. */
  public void print() {
    Node root = getRoot();
    if (root != null) {
      root.print();
    }
  }

  /**
   * A node of a binary C4.5 tree.
   *
   * <p>Each node records the indices of the training instances that reach it,
   * and a {@link GainRatio} computed over those instances. When the node is
   * split, it partitions its instances on the GainRatio's best feature and
   * threshold.
   */
  public static class Node implements Serializable {
    private static final long serialVersionUID = 1L;

    // split statistics for this node's instances (best feature, threshold,
    // base label distribution, base entropy)
    GainRatio m_gainRatio;
    // the entire set of instances given to the root node; shared with the
    // children created by split(), and set to null by stopGrowth()
    InstanceList m_ilist;
    // indices into m_ilist of the instances at this node
    int[] m_instIndices;
    // data vocabulary
    Alphabet m_dataDict;
    // minimum number of instances allowed in this node
    int m_minNumInsts;
    Node m_parent, m_leftChild, m_rightChild;

    /**
     * Creates a node covering every instance in {@code ilist}.
     *
     * @param ilist the instances given to this node
     * @param parent the parent node, or {@code null} for the root
     * @param minNumInsts minimum number of instances allowed in a node
     */
    public Node(InstanceList ilist, Node parent, int minNumInsts) {
      this(ilist, parent, minNumInsts, null);
    }

    /**
     * Creates a node covering the instances of {@code ilist} at
     * {@code instIndices}, and computes its gain-ratio statistics.
     *
     * @param ilist the full instance list given to the root node
     * @param parent the parent node, or {@code null} for the root
     * @param minNumInsts minimum number of instances allowed in a node; passed
     *        to {@code GainRatio.createGainRatio}
     * @param instIndices indices into {@code ilist} of this node's instances,
     *        or {@code null} to use all of them
     */
    public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices) {
      if (instIndices == null) {
        instIndices = new int[ilist.size()];
        for (int ii = 0; ii < instIndices.length; ii++) {
          instIndices[ii] = ii;
        }
      }
      m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
      m_ilist = ilist;
      m_instIndices = instIndices;
      m_dataDict = m_ilist.getDataAlphabet();
      m_minNumInsts = minNumInsts;
      m_parent = parent;
      m_leftChild = m_rightChild = null;
    }

    /** The root has depth zero. */
    public int depth() {
      int depth = 0;
      for (Node p = m_parent; p != null; p = p.m_parent) {
        depth++;
      }
      return depth;
    }

    /** @return the number of instances at this node (not the number of nodes) */
    public int getSize() {
      return m_instIndices.length;
    }

    /** @return true if this node has no children */
    public boolean isLeaf() {
      return (m_leftChild == null && m_rightChild == null);
    }

    /** @return true if this node has no parent */
    public boolean isRoot() {
      return m_parent == null;
    }

    public Node getParent() {
      return m_parent;
    }

    public Node getLeftChild() {
      return m_leftChild;
    }

    public Node getRightChild() {
      return m_rightChild;
    }

    public GainRatio getGainRatio() {
      return m_gainRatio;
    }

    /**
     * @return the data-alphabet entry for this node's best-ranked feature, i.e.
     *         the feature an internal node splits on
     */
    public Object getSplitFeature() {
      return m_dataDict.lookupObject(m_gainRatio.getMaxValuedIndex());
    }

    /**
     * Builds a new instance list containing this node's instances.
     *
     * @throws IllegalStateException if {@link #stopGrowth()} has been called
     */
    public InstanceList getInstances() {
      checkNotFrozen("get instances");
      InstanceList ret = new InstanceList(m_ilist.getPipe());
      for (int index : m_instIndices) {
        ret.add(m_ilist.get(index));
      }
      return ret;
    }

    /**
     * Count the number of non-leaf descendant nodes
     */
    public int getNumDescendants() {
      if (isLeaf()) {
        return 0;
      }
      int count = 0;
      if (!getLeftChild().isLeaf()) {
        count += 1 + getLeftChild().getNumDescendants();
      }
      if (!getRightChild().isLeaf()) {
        count += 1 + getRightChild().getNumDescendants();
      }
      return count;
    }

    /**
     * Creates two children by partitioning this node's instances on the
     * GainRatio's best feature. Instances whose value is {@code <=} the
     * threshold go to the left child; the rest go to the right child.
     *
     * @throws IllegalStateException if {@link #stopGrowth()} has been called
     */
    public void split() {
      checkNotFrozen("split");
      int featureIndex = m_gainRatio.getMaxValuedIndex();
      double threshold = m_gainRatio.getMaxValuedThreshold();

      // First pass: choose a side for each instance and count the left side,
      // so both child index arrays can be allocated at their exact size.
      int numLeftChildren = 0;
      boolean[] toLeftChild = new boolean[m_instIndices.length];
      for (int i = 0; i < m_instIndices.length; i++) {
        FeatureVector fv = (FeatureVector) m_ilist.get(m_instIndices[i]).getData();
        toLeftChild[i] = fv.value(featureIndex) <= threshold;
        if (toLeftChild[i]) {
          numLeftChildren++;
        }
      }
      logger.info("leftChild.size=" + numLeftChildren
          + " rightChild.size=" + (m_instIndices.length - numLeftChildren));

      // Second pass: distribute the instance indices, preserving their order.
      int[] leftIndices = new int[numLeftChildren];
      int[] rightIndices = new int[m_instIndices.length - numLeftChildren];
      int li = 0, ri = 0;
      for (int i = 0; i < m_instIndices.length; i++) {
        if (toLeftChild[i]) {
          leftIndices[li++] = m_instIndices[i];
        } else {
          rightIndices[ri++] = m_instIndices[i];
        }
      }
      m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);
      m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);
    }

    /**
     * Prunes the subtree rooted at this node by minimum description length, and
     * returns the subtree's minimum cost.
     *
     * <p>Children are processed first (bottom-up). The cost of this node as a
     * leaf is {@code getMDL() + 1}. The cost of keeping the split is
     * {@code log2(number of split points for the best feature) + 1} plus both
     * children's minimum costs. The "+1" terms presumably encode the
     * leaf/internal flag. If the minimum equals the leaf cost (within
     * {@code Maths.almostEquals}), both children are discarded and this node
     * becomes a leaf.
     *
     * @return the minimum description length of this (possibly pruned) subtree
     * @throws IllegalStateException if {@link #stopGrowth()} has been called
     */
    public double computeCostAndPrune() {
      double costS = getMDL();

      if (isLeaf()) {
        return costS + 1;
      }

      double minCost1 = getLeftChild().computeCostAndPrune();
      double minCost2 = getRightChild().computeCostAndPrune();
      double costSplit = Math.log(m_gainRatio.getNumSplitPointsForBestFeature()) / GainRatio.log2;
      double minCostN = Math.min(costS + 1, costSplit + 1 + minCost1 + minCost2);

      if (Maths.almostEquals(minCostN, costS + 1)) {
        m_leftChild = m_rightChild = null;
      }

      return minCostN;
    }

    /**
     * Computes this node's description length (presumably the cost of encoding
     * the class labels of its instances):
     *
     * <pre>
     *   n * H + ((k - 1) / 2) * log2(n / 2) + log2(pi^(k/2) / Gamma(k/2))
     * </pre>
     *
     * Here n is {@link #getSize()}, k is the size of the target alphabet, and H
     * is the GainRatio's base entropy. {@code Math.log(x) / GainRatio.log2}
     * serves as log2(x), assuming {@code GainRatio.log2 == Math.log(2)}.
     *
     * <p>The value does not depend on the split feature or threshold. The cost
     * of the split is added separately in {@link #computeCostAndPrune()}.
     *
     * @throws IllegalStateException if {@link #stopGrowth()} has been called
     */
    public double getMDL() {
      checkNotFrozen("compute MDL");
      int numClasses = m_ilist.getTargetAlphabet().size();
      double mdl = getSize() * getGainRatio().getBaseEntropy();
      mdl += ((numClasses - 1) * Math.log(getSize() / 2.0)) / (2 * GainRatio.log2);
      double piPow = Math.pow(Math.PI, numClasses / 2.0);
      double gammaVal = Maths.gamma(numClasses / 2.0);
      mdl += Math.log(piPow / gammaVal) / GainRatio.log2;
      return mdl;
    }

    /**
     * Saves memory by allowing ilist to be garbage collected
     * (deletes this node's associated instance list, recursively for the whole
     * subtree). Afterwards {@link #split()}, {@link #getInstances()},
     * {@link #getMDL()} and {@link #computeCostAndPrune()} throw
     * {@link IllegalStateException}.
     */
    public void stopGrowth() {
      if (m_leftChild != null) {
        m_leftChild.stopGrowth();
      }
      if (m_rightChild != null) {
        m_rightChild.stopGrowth();
      }
      m_ilist = null;
    }

    /** @return the path condition leading to this node, e.g. {@code ("f" <= 0.5) && ("g" > 1.0)} */
    public String getName() {
      return getStringBufferName().toString();
    }

    /**
     * Builds the conjunction of the split tests on the path from the root to
     * this node, or {@code "root"} for the root itself.
     */
    public StringBuffer getStringBufferName() {
      StringBuffer sb = new StringBuffer();
      if (isRoot()) {
        return sb.append("root");
      }
      if (!m_parent.isRoot()) {
        sb.append(m_parent.getStringBufferName()).append(" && ");
      }
      sb.append("(\"");
      sb.append(m_dataDict.lookupObject(m_parent.getGainRatio().getMaxValuedIndex()).toString());
      sb.append("\"");
      sb.append(m_parent.getLeftChild() == this ? " <= " : " > ");
      sb.append(m_parent.getGainRatio().getMaxValuedThreshold());
      return sb.append(")");
    }

    /**
     * Prints the tree rooted at this node
     */
    public void print() {
      print("");
    }

    /**
     * Prints the subtree rooted at this node to standard output. Each nesting
     * level adds {@code "| "} to {@code prefix}.
     *
     * @param prefix text printed at the start of each line for this node
     */
    public void print(String prefix) {
      if (isLeaf()) {
        // Only the actual root is labelled "root:"; any other leaf printed
        // directly shows just its summary.
        System.out.println(prefix + (isRoot() ? "root:" : "") + majoritySummary());
        return;
      }
      String featName = m_dataDict.lookupObject(getGainRatio().getMaxValuedIndex()).toString();
      double threshold = getGainRatio().getMaxValuedThreshold();
      printBranch(prefix, prefix + "\"" + featName + "\" <= " + threshold + ":", m_leftChild);
      printBranch(prefix, prefix + "\"" + featName + "\" > " + threshold + ":", m_rightChild);
    }

    /**
     * Prints one branch condition. A leaf child's summary goes on the same
     * line; a non-leaf child's subtree is printed on the following lines.
     */
    private static void printBranch(String prefix, String condition, Node child) {
      System.out.print(condition);
      if (child.isLeaf()) {
        System.out.println(child.majoritySummary());
      } else {
        System.out.println();
        child.print(prefix + "| ");
      }
    }

    /**
     * Formats this node's majority label as {@code "label count/size"}.
     *
     * <p>The count is the majority label's probability times
     * {@link #getSize()}, rounded to the nearest integer. Plain truncation can
     * undercount; for example, {@code (1.0/49)*49} is slightly below 1 in
     * double arithmetic.
     */
    private String majoritySummary() {
      int bestLabelIndex = m_gainRatio.getBaseLabelDistribution().getBestIndex();
      int numMajorityLabel = (int) Math.round(
          m_gainRatio.getBaseLabelDistribution().value(bestLabelIndex) * getSize());
      return m_gainRatio.getBaseLabelDistribution().getBestLabel()
          + " " + numMajorityLabel + "/" + getSize();
    }

    /**
     * Throws if {@link #stopGrowth()} has released the instance list.
     *
     * @param action what the caller attempted, used in the message "Frozen. Cannot ACTION."
     */
    private void checkNotFrozen(String action) {
      if (m_ilist == null) {
        throw new IllegalStateException("Frozen. Cannot " + action + ".");
      }
    }
  }

  // Serialization

  // serialVersionUID is overridden to prevent innocuous changes in this
  // class from making the serialization mechanism think the external
  // format has changed.
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /** Writes a version tag, the instance pipe and the root node. */
  private void writeObject(ObjectOutputStream out) throws IOException {
    out.writeInt(CURRENT_SERIAL_VERSION);
    out.writeObject(getInstancePipe());
    out.writeObject(m_root);
  }

  /** Reads the fields written by {@link #writeObject}; rejects other versions. */
  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();
    if (version != CURRENT_SERIAL_VERSION) {
      throw new ClassNotFoundException("Mismatched C45 versions: wanted "
          + CURRENT_SERIAL_VERSION + ", got " + version);
    }
    instancePipe = (Pipe) in.readObject();
    m_root = (Node) in.readObject();
  }
}