package LBJ2.classify;

import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

import LBJ2.parse.Parser;
import LBJ2.util.ClassUtils;
import LBJ2.util.TableFormat;
import LBJ2.learn.Learner;
import LBJ2.learn.Lexicon;


/**
  * This class is a program that can evaluate any <code>Classifier</code>
  * against an oracle <code>Classifier</code> on the objects returned from a
  * <code>Parser</code>.
  *
  * <p> Usage:
  * <blockquote>
  *   <code>
  *     java LBJ2.classify.TestDiscrete [-t &lt;n&gt;] &lt;classifier&gt;
  *       &lt;oracle&gt; &lt;parser&gt; &lt;input file&gt;
  *       [&lt;null label&gt; [&lt;null label&gt; ...]]
  *   </code>
  * </blockquote>
  *
  * <p> <b>Options:</b> The <code>-t &lt;n&gt;</code> option is similar to
  * the LBJ compiler's command line option of the same name.  When
  * <code>&lt;n&gt;</code> is greater than 0, a time stamp is printed to
  * <code>STDOUT</code> after every <code>&lt;n&gt;</code> examples are
  * processed.
  *
  * <p> <b>Input:</b> The first three command line parameters are fully
  * qualified class names, e.g. <code>myPackage.myClassifier</code>.  Next,
  * <code>&lt;input file&gt;</code> is passed (as a <code>String</code>) to
  * the constructor of <code>&lt;parser&gt;</code>.  The optional parameter
  * <code>&lt;null label&gt;</code> identifies one of the possible labels
  * produced by <code>&lt;oracle&gt;</code> as representing "no
  * classification".  It is used during the computation of overall
  * precision, recall, and F1 scores.  Finally, it is also assumed that
  * <code>&lt;classifier&gt;</code> is discrete, and that its
  * <code>discreteValue(Object)</code> method is implemented.
  *
  * <p> <b>Output:</b> First some timing information is presented.  The
  * first time reported is the time taken to load the specified classifier's
  * Java class into memory.  This reflects the time taken for LBJ to load
  * the classifier's internal representation <b>if</b> the classifier does
  * <b>not</b> make use of the <code>cachedin</code> keyword.  Next, the
  * time taken to evaluate the first example is reported.  It isn't
  * particularly informative unless the classifier <b>does</b> make use of
  * the <code>cachedin</code> keyword.  In this case, it reflects the time
  * LBJ takes to load the classifier's internal representation better than
  * the first time reported.  Finally, the average time taken to execute
  * the classifier's <code>discreteValue(Object)</code> method is reported.
  *
  * <p> After the timing information, an ASCII table is written to
  * <code>STDOUT</code> reporting precision, recall, and F<sub>1</sub>
  * scores itemized by the values that either the classifier or the oracle
  * produced during the test.  The two rightmost columns are named
  * <code>"LCount"</code> and <code>"PCount"</code> (standing for "labeled
  * count" and "predicted count" respectively), and they report the number
  * of times the oracle produced each label and the number of times the
  * classifier predicted each label respectively.  If a "null label" is
  * specified, overall precision, recall, and F<sub>1</sub> scores and a
  * total count of non-null-labeled examples are reported at the bottom of
  * the table.  In the last row, whether a "null label" is specified or
  * not, overall accuracy is reported in the precision column.  In the
  * count column, the total number of predictions (or labels, equivalently)
  * is reported.
 **/
public class TestDiscrete
{
  // NOTE(review): testing configuration lives in static fields, so this
  // program supports a single test run per JVM invocation.
  /** References the classifier that is to be tested. */
  private static Classifier classifier;
  /** References the oracle classifier to test against. */
  private static Classifier oracle;
  /** References the parser supplying the testing objects. */
  private static Parser parser;
  /** The number of examples processed in between time stamp messages. */
  private static int outputGranularity;


  /**
    * The entry point of this program.
* * @param args The command line parameters. **/ public static void main(String[] args) { long totalTime = -System.currentTimeMillis(); TestDiscrete tester = instantiate(args); totalTime += System.currentTimeMillis(); System.out.println("Classifier loaded in " + (totalTime / 1000.0) + " seconds."); testDiscrete(tester, classifier, oracle, parser, true, outputGranularity); } /** * Tests the given discrete classifier against the given oracle using the * given parser to provide the labeled testing data. This simplified * interface to * {@link #testDiscrete(TestDiscrete,Classifier,Classifier,Parser,boolean,int)} * assumes there are no null predictions and that output should not be * generated on <code>STDOUT</code>. * * @param classifier The classifier to be tested. * @param oracle The classifier to test against. * @param parser The parser supplying the labeled example objects. * @return A new <code>TestDiscrete</code> object filled with testing * statistics. **/ public static TestDiscrete testDiscrete(Classifier classifier, Classifier oracle, Parser parser) { return testDiscrete(new TestDiscrete(), classifier, oracle, parser, false, 0); } /** * Tests the given discrete classifier against the given oracle using the * given parser to provide the labeled testing data. If the parser returns * examples as <code>Object[]</code>s containing arrays of * <code>int</code>s and <code>double</code>s, as would be the case if * pre-extraction was performed, then it is assumed that this example array * already includes the label, so this is used directly and the oracle * classifier is ignored. In this case, it is also assumed that the given * discrete classifier is an instance of <code>Learner</code> and thus * a lexicon of label mappings can be retrieved from it. * * @param tester An object of this class that has already been told via * {@link #addNull(String)} which prediction values are * considered to be null predictions. * @param classifier The classifier to be tested. 
* @param oracle The classifier to test against. * @param parser The parser supplying the labeled example objects. * @param output Whether or not to produce output on * <code>STDOUT</code>. * @param outputGranularity * The number of examples processed in between time stamp * messages. * @return The same <code>TestDiscrete</code> object passed in the first * argument, after being filled with statistics. **/ public static TestDiscrete testDiscrete(TestDiscrete tester, Classifier classifier, Classifier oracle, Parser parser, boolean output, int outputGranularity) { int processed = 1; long totalTime = 0; Lexicon labelLexicon = null; Runtime runtime = null; boolean preExtraction = false; if (output && outputGranularity > 0) { runtime = Runtime.getRuntime(); System.out.println("0 examples tested at " + new Date()); System.out.println("Total memory before first example: " + runtime.totalMemory()); Object example = parser.next(); if (example == null) return tester; totalTime -= System.currentTimeMillis(); String prediction = classifier.discreteValue(example); totalTime += System.currentTimeMillis(); System.out.println("First example processed in " + (totalTime / 1000.0) + " seconds."); System.out.println("Total memory after first example: " + runtime.totalMemory()); String gold; if (example instanceof Object[] && ((Object[]) example)[0] instanceof int[]) { preExtraction = true; labelLexicon = ((Learner) classifier).getLabelLexicon(); gold = ((Feature) labelLexicon.lookupKey(((int[]) ((Object[]) example)[2])[0])) .getStringValue(); } else gold = oracle.discreteValue(example); tester.reportPrediction(prediction, gold); for (example = parser.next(); example != null; example = parser.next(), ++processed) { if (processed % outputGranularity == 0) System.out.println(processed + " examples tested at " + new Date()); totalTime -= System.currentTimeMillis(); prediction = classifier.discreteValue(example); totalTime += System.currentTimeMillis(); assert prediction != null : 
"Classifier returned null prediction for example " + example; if (preExtraction) gold = ((Feature) labelLexicon.lookupKey(((int[]) ((Object[]) example)[2])[0])) .getStringValue(); else gold = oracle.discreteValue(example); tester.reportPrediction(prediction, gold); } System.out.println(processed + " examples tested at " + new Date() + "\n"); } else { if (output) { runtime = Runtime.getRuntime(); System.out.println("Total memory before first example: " + runtime.totalMemory()); } Object example = parser.next(); if (example == null) return tester; totalTime -= System.currentTimeMillis(); String prediction = classifier.discreteValue(example); totalTime += System.currentTimeMillis(); if (output) { System.out.println("First example processed in " + (totalTime / 1000.0) + " seconds."); System.out.println("Total memory after first example: " + runtime.totalMemory()); } String gold; if (example instanceof Object[] && ((Object[]) example)[0] instanceof int[]) { preExtraction = true; labelLexicon = ((Learner) classifier).getLabelLexicon(); gold = ((Feature) labelLexicon.lookupKey(((int[]) ((Object[]) example)[2])[0])) .getStringValue(); } else gold = oracle.discreteValue(example); tester.reportPrediction(prediction, gold); for (example = parser.next(); example != null; example = parser.next(), ++processed) { totalTime -= System.currentTimeMillis(); prediction = classifier.discreteValue(example); totalTime += System.currentTimeMillis(); assert prediction != null : "Classifier returned null prediction for example " + example; if (preExtraction) gold = ((Feature) labelLexicon.lookupKey(((int[]) ((Object[]) example)[2])[0])) .getStringValue(); else gold = oracle.discreteValue(example); tester.reportPrediction(prediction, gold); } } if (output) { System.out.println("Average evaluation time: " + (totalTime / (1000.0 * processed)) + " seconds\n"); tester.printPerformance(System.out); } return tester; } /** * Given command line parameters representing the fully qualified names of * 
the classifier to be tested, the oracle classifier to test against, the * parser supplying the testing objects, and the input parameter to the * parser's constructor this method instantiates all three objects. * * @param args The command line. * @return A new tester object containing the "null" labels. **/ private static TestDiscrete instantiate(String[] args) { String classifierName = null, oracleName = null, parserName = null; String inputFile = null; TestDiscrete result = new TestDiscrete(); try { int offset = 0; if (args[0].charAt(0) == '-') { if (!args[0].equals("-t")) throw new Exception(); outputGranularity = Integer.parseInt(args[1]); offset = 2; } classifierName = args[offset]; oracleName = args[offset + 1]; parserName = args[offset + 2]; inputFile = args[offset + 3]; for (int i = offset + 4; i < args.length; ++i) result.addNull(args[i]); } catch (Exception e) { System.err.println( "usage:\n" + " java LBJ2.classify.TestDiscrete [-t <n>] <classifier> <oracle> \\\n" + " <parser> <input file> \\\n" + " [<null label> [<null label> ...]]"); System.exit(1); } classifier = ClassUtils.getClassifier(classifierName); oracle = ClassUtils.getClassifier(oracleName); parser = ClassUtils.getParser(parserName, new Class[]{ String.class }, new String[]{ inputFile }); return result; } /** The histogram of correct labels. */ protected HashMap goldHistogram; /** The histogram of predictions. */ protected HashMap predictionHistogram; /** The histogram of correct predictions. */ protected HashMap correctHistogram; /** * The set of "null" labels whose statistics are not included in overall * precision, recall, F1, or accuracy. **/ protected HashSet nullLabels; /** Default constructor. */ public TestDiscrete() { goldHistogram = new HashMap(); predictionHistogram = new HashMap(); correctHistogram = new HashMap(); nullLabels = new HashSet(); } /** * Whenever a prediction is made, report that prediction and the correct * label with this method. * * @param p The prediction. 
* @param l The correct label. **/ public void reportPrediction(String p, String l) { histogramAdd(goldHistogram, l, 1); histogramAdd(predictionHistogram, p, 1); if (p.equals(l)) histogramAdd(correctHistogram, p, 1); } /** * Report all the predictions in the argument's histograms. * * @param t Another object of this class. **/ public void reportAll(TestDiscrete t) { histogramAddAll(goldHistogram, t.goldHistogram); histogramAddAll(predictionHistogram, t.predictionHistogram); histogramAddAll(correctHistogram, t.correctHistogram); } /** * Returns the set of labels that have been reported so far. * * @return An array containing the labels that have been reported so far. **/ public String[] getLabels() { return (String[]) goldHistogram.keySet().toArray(new String[0]); } /** * Returns the set of predictions that have been reported so far. * * @return An array containing the predictions that have been reported so * far. **/ public String[] getPredictions() { return (String[]) predictionHistogram.keySet().toArray(new String[0]); } /** * Returns the set of all classes reported as either predictions or labels. * * @return An array containing all classes reported as either predictions * or labels. **/ public String[] getAllClasses() { HashSet result = new HashSet(goldHistogram.keySet()); result.addAll(predictionHistogram.keySet()); return (String[]) result.toArray(new String[0]); } /** * Adds a label to the set of "null" labels. * * @param n The label to add. **/ public void addNull(String n) { nullLabels.add(n); } /** * Removes a label from the set of "null" labels. * * @param n The label to remove. **/ public void removeNull(String n) { nullLabels.remove(n); } /** * Determines if a label is treated as a "null" label. * * @param n The label in question. * @return <code>true</code> iff <code>n</code> is one of the "null" * labels. **/ public boolean isNull(String n) { return nullLabels.contains(n); } /** Returns <code>true</code> iff there exist "null" labels. 
*/ public boolean hasNulls() { return nullLabels.size() > 0; } /** * Takes a histogram implemented as a map and increments the count for the * given key by the given amount. * * @param histogram The histogram. * @param key The key whose count should be incremented. * @param amount The amount by which to increment. **/ protected void histogramAdd(HashMap histogram, String key, int amount) { Integer I = (Integer) histogram.get(key); if (I == null) I = new Integer(0); histogram.put(key, new Integer(I.intValue() + amount)); } /** * Takes a histogram implemented as a map and retrieves the count for the * given key. * * @param histogram The histogram. * @param key The key whose count should be retrieved. * @return The count of the specified key. **/ protected int histogramGet(HashMap histogram, String key) { Integer I = (Integer) histogram.get(key); if (I == null) I = new Integer(0); return I.intValue(); } /** * Takes two histograms implemented as maps and adds the amounts found in * the second histogram to the amounts found in the first. * * @param h1 The first histogram, whose values will be modified. * @param h2 The second histogram, whose values will be added into the * first's. **/ protected void histogramAddAll(HashMap h1, HashMap h2) { for (Iterator I = h2.entrySet().iterator(); I.hasNext(); ) { Map.Entry e = (Map.Entry) I.next(); histogramAdd(h1, (String) e.getKey(), ((Integer) e.getValue()).intValue()); } } /** * Returns the number of times the requested label was reported. * * @param l The label in question. * @return The number of times <code>l</code> was reported. **/ public int getLabeled(String l) { return histogramGet(goldHistogram, l); } /** * Returns the number of times the requested prediction was reported. * * @param p The prediction in question. * @return The number of times <code>p</code> was reported. 
**/ public int getPredicted(String p) { return histogramGet(predictionHistogram, p); } /** * Returns the number of times the requested prediction was reported * correctly. * * @param p The prediction in question. * @return The number of times <code>p</code> was reported. **/ public int getCorrect(String p) { return histogramGet(correctHistogram, p); } /** * Returns the precision associated with the given prediction. * * @param p The given prediction. * @return The precision associated with <code>p</code>. **/ public double getPrecision(String p) { return getCorrect(p) / (double) getPredicted(p); } /** * Returns the recall associated with the given label. * * @param l The given label. * @return The precision associated with <code>l</code>. **/ public double getRecall(String l) { return getCorrect(l) / (double) getLabeled(l); } /** * Returns the F<sub>1</sub> score associated with the given label. * * @param l The given label. * @return The F<sub>1</sub> score associated with <code>l</code>. **/ public double getF1(String l) { return getF(1, l); } /** * Returns the F<sub>beta</sub> score associated with the given label. * F<sub>beta</sub> is defined as: * <blockquote> * <i>F<sub>beta</sub> = (beta<sup>2</sup> + 1) * P * R</i> * <i>/ (beta<sup>2</sup> * P + R)</i> * </blockquote> * * @param b The value of beta. * @param l The given label. * @return The F<sub>beta</sub> score associated with <code>l</code>. **/ public double getF(double b, String l) { double precision = getPrecision(l); double recall = getRecall(l); return (b * b + 1) * precision * recall / (b * b * precision + recall); } /** * Computes overall the overall statistics precision, recall, * F<sub>1</sub>, and accuracy. Note that these statistics are all * equivalent unless "null" labels have been added. * * @return An array in which the first element represents overall * precision, the second represents overall recall, then F1, and * finally accuracy. 
   **/
  public double[] getOverallStats() { return getOverallStats(1); }


  /**
    * Computes the overall statistics precision, recall, F<sub>beta</sub>,
    * and accuracy.  Note that these statistics are all equivalent unless
    * "null" labels have been added.
    *
    * @param b The value of beta.
    * @return An array in which the first element represents overall
    *         precision, the second represents overall recall, then F1, and
    *         finally accuracy.
   **/
  public double[] getOverallStats(double b)
  {
    String[] allClasses = getAllClasses();
    int totalCorrect = 0;
    int totalPredicted = 0;
    // Counts restricted to non-"null" labels; only accumulated when "null"
    // labels exist.
    int notNullCorrect = 0;
    int notNullPredicted = 0;
    int notNullLabeled = 0;

    for (int i = 0; i < allClasses.length; ++i)
    {
      int correct = getCorrect(allClasses[i]);
      int predicted = getPredicted(allClasses[i]);
      int labeled = getLabeled(allClasses[i]);
      totalCorrect += correct;
      totalPredicted += predicted;

      if (hasNulls() && !isNull(allClasses[i]))
      {
        notNullCorrect += correct;
        notNullPredicted += predicted;
        notNullLabeled += labeled;
      }
    }

    // result = { precision, recall, F_beta, accuracy }.
    // NOTE(review): when no examples were reported, totalPredicted is 0 and
    // the division below yields NaN; printPerformance() guards this same
    // case with an explicit check — confirm whether callers expect NaN.
    double[] result = new double[4];
    result[3] = totalCorrect / (double) totalPredicted;

    if (hasNulls())
    {
      result[0] = notNullCorrect / (double) notNullPredicted;
      result[1] = notNullCorrect / (double) notNullLabeled;
      result[2] =
        (b * b + 1) * result[0] * result[1]
        / (b * b * result[0] + result[1]);
    }
    // Without "null" labels, precision, recall, and F are all the accuracy.
    else result[0] = result[1] = result[2] = result[3];

    return result;
  }


  /**
    * Performance results are written to the given stream in the form of
    * precision, recall, and F1 statistics.
    *
    * @param out The stream to write to.
   **/
  public void printPerformance(PrintStream out)
  {
    String[] allClasses = getAllClasses();
    // Sort classes alphabetically, with "null" labels pushed to the end so
    // they appear just above the summary rows.
    final HashSet n = nullLabels;
    Arrays.sort(allClasses,
                new Comparator()
                {
                  public int compare(Object o1, Object o2)
                  {
                    String s1 = (String) o1;
                    String s2 = (String) o2;
                    int n1 = n.contains(s1) ? 1 : 0;
                    int n2 = n.contains(s2) ? 1 : 0;
                    if (n1 != n2) return n1 - n2;
                    return s1.compareTo(s2);
                  }
                });

    // One row per class plus an "Accuracy" row, plus an "Overall" row when
    // "null" labels exist.
    int rows = allClasses.length + 1;
    if (hasNulls()) ++rows;
    String[] rowLabels = new String[rows];
    System.arraycopy(allClasses, 0, rowLabels, 0, allClasses.length);
    rowLabels[rows - 1] = "Accuracy";
    if (hasNulls()) rowLabels[rows - 2] = "Overall";
    String[] columnLabels =
      new String[]{ "Label", "Precision", "Recall", "F1", "LCount",
                    "PCount" };

    int totalCorrect = 0;
    int totalPredicted = 0;
    int notNullCorrect = 0;
    int notNullPredicted = 0;
    int notNullLabeled = 0;
    Double[][] table = new Double[rows][];
    Double zero = new Double(0);

    for (int i = 0; i < allClasses.length; ++i)
    {
      int correct = getCorrect(allClasses[i]);
      int predicted = getPredicted(allClasses[i]);
      int labeled = getLabeled(allClasses[i]);
      totalCorrect += correct;
      totalPredicted += predicted;

      if (hasNulls() && !isNull(allClasses[i]))
      {
        notNullCorrect += correct;
        notNullPredicted += predicted;
        notNullLabeled += labeled;
      }

      // Columns: precision, recall, F1 (as percentages), LCount, PCount.
      // The zero guards below avoid division by zero; a cell stays 0 when
      // its denominator is 0.
      table[i] =
        new Double[]{ zero, zero, zero, new Double(labeled),
                      new Double(predicted) };
      if (predicted > 0)
        table[i][0] = new Double(100 * correct / (double) predicted);
      if (labeled > 0)
        table[i][1] = new Double(100 * correct / (double) labeled);
      if (correct > 0)
      {
        double p = table[i][0].doubleValue();
        double r = table[i][1].doubleValue();
        table[i][2] = new Double(2 * p * r / (p + r));
      }
    }

    int[] dashRows = null;

    if (hasNulls())
    {
      // The "Overall" row reports statistics over non-"null" labels only.
      table[rows - 2] =
        new Double[]{ zero, zero, zero, new Double(notNullLabeled),
                      new Double(notNullPredicted) };
      if (notNullPredicted > 0)
        table[rows - 2][0] =
          new Double(100 * notNullCorrect / (double) notNullPredicted);
      if (notNullLabeled > 0)
        table[rows - 2][1] =
          new Double(100 * notNullCorrect / (double) notNullLabeled);
      if (notNullCorrect > 0)
      {
        double p = table[rows - 2][0].doubleValue();
        double r = table[rows - 2][1].doubleValue();
        table[rows - 2][2] = new Double(2 * p * r / (p + r));
      }

      // Horizontal rules: above the table, before the "null" labels, and
      // before the summary rows.
      int nonNullLabels = allClasses.length - nullLabels.size();
      dashRows = new int[]{ 0, nonNullLabels, allClasses.length };
    }
    else dashRows = new int[]{ 0, allClasses.length };

    // The last row reports accuracy in the precision column and the total
    // prediction count in the count column.
    double accuracy =
      totalPredicted == 0 ? 0
                          : 100 * totalCorrect / (double) totalPredicted;
    table[rows - 1] =
      new Double[]{ new Double(accuracy), null, null, null,
                    new Double(totalPredicted) };

    TableFormat.printTableFormat(out, columnLabels, rowLabels, table,
                                 new int[]{ 3, 3, 3, 0, 0 }, dashRows);
  }
}