package edu.stanford.nlp.classify.demo; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.ObjectOutputStream; import java.io.ObjectInputStream; import java.io.IOException; import edu.stanford.nlp.classify.Classifier; import edu.stanford.nlp.classify.ColumnDataClassifier; import edu.stanford.nlp.classify.LinearClassifier; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.objectbank.ObjectBank; import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.Pair; class ClassifierDemo { private static String where = ""; public static void main(String[] args) throws Exception { if (args.length > 0) { where = args[0] + File.separator; } System.out.println("Training ColumnDataClassifier"); ColumnDataClassifier cdc = new ColumnDataClassifier(where + "examples/cheese2007.prop"); cdc.trainClassifier(where + "examples/cheeseDisease.train"); System.out.println(); System.out.println("Testing predictions of ColumnDataClassifier"); for (String line : ObjectBank.getLineIterator(where + "examples/cheeseDisease.test", "utf-8")) { // instead of the method in the line below, if you have the individual elements // already you can use cdc.makeDatumFromStrings(String[]) Datum<String,String> d = cdc.makeDatumFromLine(line); System.out.printf("%s ==> %s (%.4f)%n", line, cdc.classOf(d), cdc.scoresOf(d).getCount(cdc.classOf(d))); } System.out.println(); System.out.println("Testing accuracy of ColumnDataClassifier"); Pair<Double, Double> performance = cdc.testClassifier(where + "examples/cheeseDisease.test"); System.out.printf("Accuracy: %.3f; macro-F1: %.3f%n", performance.first(), performance.second()); demonstrateSerialization(); demonstrateSerializationColumnDataClassifier(); } private static void demonstrateSerialization() throws IOException, ClassNotFoundException { System.out.println(); System.out.println("Demonstrating working with a serialized classifier"); ColumnDataClassifier cdc = new ColumnDataClassifier(where + "examples/cheese2007.prop"); Classifier<String,String> cl = cdc.makeClassifier(cdc.readTrainingExamples(where + "examples/cheeseDisease.train")); // Exhibit serialization and deserialization working. Serialized to bytes in memory for simplicity System.out.println(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(cl); oos.close(); byte[] object = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(object); ObjectInputStream ois = new ObjectInputStream(bais); LinearClassifier<String,String> lc = ErasureUtils.uncheckedCast(ois.readObject()); ois.close(); ColumnDataClassifier cdc2 = new ColumnDataClassifier(where + "examples/cheese2007.prop"); // We compare the output of the deserialized classifier lc versus the original one cl // For both we use a ColumnDataClassifier to convert text lines to examples System.out.println(); System.out.println("Making predictions with both classifiers"); for (String line : ObjectBank.getLineIterator(where + "examples/cheeseDisease.test", "utf-8")) { Datum<String,String> d = cdc.makeDatumFromLine(line); Datum<String,String> d2 = cdc2.makeDatumFromLine(line); System.out.printf("%s =origi=> %s (%.4f)%n", line, cl.classOf(d), cl.scoresOf(d).getCount(cl.classOf(d))); System.out.printf("%s =deser=> %s (%.4f)%n", line, lc.classOf(d2), lc.scoresOf(d).getCount(lc.classOf(d))); } } private static void demonstrateSerializationColumnDataClassifier() throws IOException, ClassNotFoundException { System.out.println(); System.out.println("Demonstrating working with a serialized classifier using serializeTo"); ColumnDataClassifier cdc = new ColumnDataClassifier(where + "examples/cheese2007.prop"); cdc.trainClassifier(where + "examples/cheeseDisease.train"); // Exhibit serialization and deserialization working. Serialized to bytes in memory for simplicity System.out.println(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); cdc.serializeClassifier(oos); oos.close(); byte[] object = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(object); ObjectInputStream ois = new ObjectInputStream(bais); ColumnDataClassifier cdc2 = ColumnDataClassifier.getClassifier(ois); ois.close(); // We compare the output of the deserialized classifier cdc2 versus the original one cl // For both we use a ColumnDataClassifier to convert text lines to examples System.out.println("Making predictions with both classifiers"); for (String line : ObjectBank.getLineIterator(where + "examples/cheeseDisease.test", "utf-8")) { Datum<String,String> d = cdc.makeDatumFromLine(line); Datum<String,String> d2 = cdc2.makeDatumFromLine(line); System.out.printf("%s =origi=> %s (%.4f)%n", line, cdc.classOf(d), cdc.scoresOf(d).getCount(cdc.classOf(d))); System.out.printf("%s =deser=> %s (%.4f)%n", line, cdc2.classOf(d2), cdc2.scoresOf(d).getCount(cdc2.classOf(d))); } } }