/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.types; import java.awt.geom.Point2D; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.Hashtable; import java.util.Iterator; import java.util.logging.Logger; import cc.mallet.util.MalletLogger; import cc.mallet.util.Maths; /** * List of features along with their thresholds sorted in descending order of * the ratio of (1) information gained by splitting instances on the * feature at its associated threshold value, to (2) the split information.<p> * * The calculations performed do not take into consideration the instance weights.<p> * * To create an instance of GainRatio from an InstanceList, one must do the following:<p><tt> * * InstanceList ilist = ... * ... * GainRatio gr = GainRatio.createGainRatio(ilist); * </tt><p> * * J. R. Quinlan * "Improved Use of Continuous Attributes in C4.5" * ftp://ftp.cs.cmu.edu/project/jair/volume4/quinlan96a.ps * * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a> */ public class GainRatio extends RankedFeatureVector { private static final Logger logger = MalletLogger.getLogger (GainRatio.class.getName ()); private static final long serialVersionUID = 1L; public static final double log2 = Math.log(2); double[] m_splitPoints; double m_baseEntropy; LabelVector m_baseLabelDistribution; int m_numSplitPointsForBestFeature; int m_minNumInsts; /** * Calculates gain ratios for all (feature, split point) pairs * snd returns array of:<pre> * 1. gain ratios (each element is the max gain ratio of a feature * for those split points with at least average gain) * 2. the optimal split point for each feature * 3. the overall entropy * 4. the overall label distribution of the given instances * 5. the number of split points of the split feature. * </pre> */ protected static Object[] calcGainRatios(InstanceList ilist, int[] instIndices, int minNumInsts) { int numInsts = instIndices.length; Alphabet dataDict = ilist.getDataAlphabet(); LabelAlphabet targetDict = (LabelAlphabet) ilist.getTargetAlphabet(); double[] targetCounts = new double[targetDict.size()]; // Accumulate target label counts and make sure // the sum of each instance's target label is 1 for (int ii = 0; ii < numInsts; ii++) { Instance inst = ilist.get(instIndices[ii]); Labeling labeling = inst.getLabeling(); double labelWeightSum = 0; for (int ll = 0; ll < labeling.numLocations(); ll++) { int li = labeling.indexAtLocation(ll); double labelWeight = labeling.valueAtLocation(ll); labelWeightSum += labelWeight; targetCounts[li] += labelWeight; } assert(Maths.almostEquals(labelWeightSum, 1)); } // Calculate the base entropy Info(D) and the the // label distribution of the given instances double[] targetDistribution = new double[targetDict.size()]; double baseEntropy = 0; for (int ci = 0; ci < targetDict.size(); ci++) { double p = targetCounts[ci] / numInsts; targetDistribution[ci] = p; if (p > 0) baseEntropy -= p * Math.log(p) / log2; } LabelVector baseLabelDistribution = new LabelVector(targetDict, targetDistribution); double infoGainSum = 0; int totalNumSplitPoints = 0; double[] passTestTargetCounts = new double[targetDict.size()]; // Maps feature index -> Hashtable, and each table // maps (split point) -> (info gain, split ratio) Hashtable[] featureToInfo = new Hashtable[dataDict.size()]; // Go through each feature's split points in ascending order for (int fi = 0; fi < dataDict.size(); fi++) { if ((fi+1) % 1000 == 0) logger.info("at feature " + (fi+1) + " / " + dataDict.size()); featureToInfo[fi] = new Hashtable(); Arrays.fill(passTestTargetCounts, 0); // Sort instances on this feature's values instIndices = sortInstances(ilist, instIndices, fi); // Iterate through the sorted instances for (int ii = 0; ii < numInsts-1; ii++) { Instance inst = ilist.get(instIndices[ii]); Instance instPlusOne = ilist.get(instIndices[ii+1]); FeatureVector fv1 = (FeatureVector) inst.getData(); FeatureVector fv2 = (FeatureVector) instPlusOne.getData(); double lower = fv1.value(fi); double higher = fv2.value(fi); // Accumulate the label weights for instances passing the test Labeling labeling = inst.getLabeling(); for (int ll = 0; ll < labeling.numLocations(); ll++) { int li = labeling.indexAtLocation(ll); double labelWeight = labeling.valueAtLocation(ll); passTestTargetCounts[li] += labelWeight; } if (Maths.almostEquals(lower, higher) || inst.getLabeling().toString().equals(instPlusOne.getLabeling().toString())) continue; // For this (feature, spilt point) pair, calculate the // info gain of using this pair to split insts into those // with value of feature <= p versus > p totalNumSplitPoints++; double splitPoint = (lower + higher) / 2; double numPassInsts = ii+1; // If this split point creates a partition // with too few instances, ignore it double numFailInsts = numInsts - numPassInsts; if (numPassInsts < minNumInsts || numFailInsts < minNumInsts) continue; // If all instances pass or fail this test, it is useless double passProportion = numPassInsts / numInsts; if (Maths.almostEquals(passProportion, 0) || Maths.almostEquals(passProportion, 1)) continue; // Calculate the entropy of instances passing and failing the test double passEntropy = 0; double failEntropy = 0; double p; for (int ci = 0; ci < targetDict.size(); ci++) { if (numPassInsts > 0) { p = passTestTargetCounts[ci] / numPassInsts; if (p > 0) passEntropy -= p * Math.log(p) / log2; } if (numFailInsts > 0) { double failTestTargetCount = targetCounts[ci] - passTestTargetCounts[ci]; p = failTestTargetCount / numFailInsts; if (p > 0) failEntropy -= p * Math.log(p) / log2; } } // Calculate Gain(D, T), the information gained // by testing on this (feature, split-point) pair double gainDT = baseEntropy - passProportion * passEntropy - (1-passProportion) * failEntropy; infoGainSum += gainDT; // Calculate Split(D, T), the split information double splitDT = - passProportion * Math.log(passProportion) / log2 - (1-passProportion) * Math.log(1-passProportion) / log2; // Calculate the gain ratio double gainRatio = gainDT / splitDT; featureToInfo[fi].put(new Double(splitPoint), new Point2D.Double(gainDT, gainRatio)); } // End loop through sorted instances } // End loop through features // For each feature's split point with at least average gain, // get the maximum gain ratio and the associated split point // (using the info gain as tie breaker) double[] gainRatios = new double[dataDict.size()]; double[] splitPoints = new double[dataDict.size()]; int numSplitsForBestFeature = 0; // If all feature vectors are identical or no splits are worthy, return all 0s if (totalNumSplitPoints == 0 || Maths.almostEquals(infoGainSum, 0)) return new Object[] {gainRatios, splitPoints, new Double(baseEntropy), baseLabelDistribution, new Integer(numSplitsForBestFeature)}; double avgInfoGain = infoGainSum / totalNumSplitPoints; double maxGainRatio = 0; double gainForMaxGainRatio = 0; // tie breaker int xxx = 0; for (int fi = 0; fi < dataDict.size(); fi++) { double featureMaxGainRatio = 0; double featureGainForMaxGainRatio = 0; double bestSplitPoint = Double.NaN; for (Iterator iter = featureToInfo[fi].keySet().iterator(); iter.hasNext(); ) { Object key = iter.next(); Point2D.Double pt = (Point2D.Double) featureToInfo[fi].get(key); double splitPoint = ((Double) key).doubleValue(); double infoGain = pt.getX(); double gainRatio = pt.getY(); if (infoGain >= avgInfoGain) { if (gainRatio > featureMaxGainRatio || (gainRatio == featureMaxGainRatio && infoGain > featureGainForMaxGainRatio)) { featureMaxGainRatio = gainRatio; featureGainForMaxGainRatio = infoGain; bestSplitPoint = splitPoint; } } else xxx++; } assert(bestSplitPoint != Double.NaN); gainRatios[fi] = featureMaxGainRatio; splitPoints[fi] = bestSplitPoint; if (featureMaxGainRatio > maxGainRatio || (featureMaxGainRatio == maxGainRatio && featureGainForMaxGainRatio > gainForMaxGainRatio)) { maxGainRatio = featureMaxGainRatio; gainForMaxGainRatio = featureGainForMaxGainRatio; numSplitsForBestFeature = featureToInfo[fi].size(); } } logger.info("label distrib:\n" + baseLabelDistribution); logger.info("base entropy=" + baseEntropy + ", info gain sum=" + infoGainSum + ", total num split points=" + totalNumSplitPoints + ", avg info gain=" + avgInfoGain + ", num splits with < avg gain=" + xxx); return new Object[] {gainRatios, splitPoints, new Double(baseEntropy), baseLabelDistribution, new Integer(numSplitsForBestFeature)}; } public static int[] sortInstances(InstanceList ilist, int[] instIndices, int featureIndex) { ArrayList list = new ArrayList(); for (int ii = 0; ii < instIndices.length; ii++) { Instance inst = ilist.get(instIndices[ii]); FeatureVector fv = (FeatureVector) inst.getData(); list.add(new Point2D.Double(instIndices[ii], fv.value(featureIndex))); } Collections.sort(list, new Comparator() { public int compare(Object o1, Object o2) { Point2D.Double p1 = (Point2D.Double) o1; Point2D.Double p2 = (Point2D.Double) o2; if (p1.y == p2.y) { assert(p1.x != p2.x); return p1.x > p2.x ? 1 : -1; } else return p1.y > p2.y ? 1 : -1; } }); int[] sorted = new int[instIndices.length]; for (int i = 0; i < list.size(); i++) sorted[i] = (int) ((Point2D.Double) list.get(i)).getX(); return sorted; } /** * Constructs a GainRatio object. */ public static GainRatio createGainRatio(InstanceList ilist) { int[] instIndices = new int[ilist.size()]; for (int ii = 0; ii < instIndices.length; ii++) instIndices[ii] = ii; return createGainRatio(ilist, instIndices, 2); } /** * Constructs a GainRatio object */ public static GainRatio createGainRatio(InstanceList ilist, int[] instIndices, int minNumInsts) { Object[] objs = calcGainRatios(ilist, instIndices, minNumInsts); double[] gainRatios = (double[]) objs[0]; double[] splitPoints = (double[]) objs[1]; double baseEntropy = ((Double) objs[2]).doubleValue(); LabelVector baseLabelDistribution = (LabelVector) objs[3]; int numSplitPointsForBestFeature = ((Integer) objs[4]).intValue(); return new GainRatio(ilist.getDataAlphabet(), gainRatios, splitPoints, baseEntropy, baseLabelDistribution, numSplitPointsForBestFeature, minNumInsts); } protected GainRatio(Alphabet dataAlphabet, double[] gainRatios, double[] splitPoints, double baseEntropy, LabelVector baseLabelDistribution, int numSplitPointsForBestFeature, int minNumInsts) { super (dataAlphabet, gainRatios); m_splitPoints = splitPoints; m_baseEntropy = baseEntropy; m_baseLabelDistribution = baseLabelDistribution; m_numSplitPointsForBestFeature = numSplitPointsForBestFeature; m_minNumInsts = minNumInsts; } /** * @return the threshold of the (feature, threshold) * pair with with maximum gain ratio */ public double getMaxValuedThreshold() { return getThresholdAtRank(0); } /** * @return the threshold of the (feature, threshold) * pair with the given rank */ public double getThresholdAtRank(int rank) { int index = getIndexAtRank(rank); return m_splitPoints[index]; } public double getBaseEntropy () { return m_baseEntropy; } public LabelVector getBaseLabelDistribution () { return m_baseLabelDistribution; } public int getNumSplitPointsForBestFeature() { return m_numSplitPointsForBestFeature; } }