package edu.stanford.nlp.ie.crf; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.Properties; import edu.stanford.nlp.util.PropertiesUtils; import edu.stanford.nlp.util.StringUtils; import edu.stanford.nlp.util.Timing; public class TestThreadedCRFClassifier { TestThreadedCRFClassifier(Properties props) { inputEncoding = props.getProperty("inputEncoding", "UTF-8"); } // number of threads to run the first specified classifier under private static final int DEFAULT_SIM_THREADS = 3; private static final int DEFAULT_MULTIPLE_THREADS = 2; private final String inputEncoding; static CRFClassifier loadClassifier(String loadPath, Properties props) { CRFClassifier crf = new CRFClassifier(props); crf.loadClassifierNoExceptions(loadPath, props); return crf; } String runClassifier(CRFClassifier crf, String testFile) { try { ByteArrayOutputStream output = new ByteArrayOutputStream(); crf.classifyAndWriteAnswers(testFile, output, crf.makeReaderAndWriter(), true); return output.toString(inputEncoding); } catch (IOException e) { throw new RuntimeException(e); } } class CRFThread extends Thread { private final CRFClassifier crf; private final String filename; private final String threadName; private String resultsString = ""; public String getResultsString() { return resultsString; } CRFThread(CRFClassifier crf, String filename, String threadName) { this.crf = crf; this.filename = filename; this.threadName = threadName; } @Override public void run() { Timing t = new Timing(); resultsString = runClassifier(crf, filename); long millis = t.stop(); System.out.println("Thread " + threadName + " took " + millis + "ms to tag file " + filename); } } /** * Sample command line: * <br> * java -mx4g edu.stanford.nlp.ie.crf.TestThreadedCRFClassifier * -crf1 ../stanford-releases/stanford-ner-models/hgc_175m_600.ser.gz * -crf2 ../stanford-releases/stanford-ner-models/dewac_175m_600.ser.gz * -testFile ../data/german-ner/deu.testa -inputEncoding iso-8859-1 */ public static void main(String[] args) { try { System.setOut(new PrintStream(System.out, true, "UTF-8")); System.setErr(new PrintStream(System.err, true, "UTF-8")); } catch (UnsupportedEncodingException e) { throw new RuntimeException(e); } runTest(StringUtils.argsToProperties(args)); } static public void runTest(Properties props) { TestThreadedCRFClassifier test = new TestThreadedCRFClassifier(props); test.runThreadedTest(props); } void runThreadedTest(Properties props) { // TODO: check params final String testFile = props.getProperty("testFile"); ArrayList<String> baseResults = new ArrayList<String>(); ArrayList<String> modelNames = new ArrayList<String>(); ArrayList<CRFClassifier> classifiers = new ArrayList<CRFClassifier>(); for (int i = 1; props.getProperty("crf" + Integer.toString(i)) != null; ++i) { String model = props.getProperty("crf" + Integer.toString(i)); CRFClassifier crf = loadClassifier(model, props); System.out.println("Loaded model " + model); modelNames.add(model); classifiers.add(crf); String results = runClassifier(crf, testFile); // must run twice to account for "transductive learning" results = runClassifier(crf, testFile); baseResults.add(results); System.out.println("Stored base results for " + model + "; length " + results.length()); } // test to make sure loading and running multiple classifiers // hasn't messed with previous results for (int i = 0; i < classifiers.size(); ++i) { CRFClassifier crf = classifiers.get(i); String model = modelNames.get(i); String base = baseResults.get(i); String repeated = runClassifier(crf, testFile); if (!base.equals(repeated)) { throw new RuntimeException("Repeated unthreaded results " + "not the same for " + model + " run on file " + testFile); } } // test the first classifier in several simultaneous threads int numThreads = PropertiesUtils.getInt(props, "simThreads", DEFAULT_SIM_THREADS); ArrayList<CRFThread> threads = new ArrayList<CRFThread>(); for (int i = 0; i < numThreads; ++i) { threads.add(new CRFThread(classifiers.get(0), testFile, "Simultaneous-" + i)); } for (int i = 0; i < numThreads; ++i) { threads.get(i).start(); } for (int i = 0; i < numThreads; ++i) { try { threads.get(i).join(); } catch (InterruptedException e) { throw new RuntimeException(e); } if (baseResults.get(0).equals(threads.get(i).getResultsString())) { System.out.println("Yay!"); } else { throw new RuntimeException("Results not equal when running " + modelNames.get(0) + " under " + numThreads + " simultaneous threads"); } } // test multiple classifiers (if given) in multiple threads each if (classifiers.size() > 1) { numThreads = PropertiesUtils.getInt(props, "multipleThreads", DEFAULT_MULTIPLE_THREADS); threads = new ArrayList<CRFThread>(); for (int i = 0; i < numThreads * classifiers.size(); ++i) { int classifierNum = i % classifiers.size(); int repeatNum = i / classifiers.size(); threads.add(new CRFThread(classifiers.get(classifierNum), testFile, ("Simultaneous-" + classifierNum + "-" + repeatNum))); } for (CRFThread thread : threads) { thread.start(); } for (int i = 0; i < threads.size(); ++i) { int classifierNum = i % classifiers.size(); int repeatNum = i / classifiers.size(); try { threads.get(i).join(); } catch (InterruptedException e) { throw new RuntimeException(e); } String base = baseResults.get(classifierNum); String threadResults = threads.get(i).getResultsString(); if (base.equals(threadResults)) { System.out.println("Yay!"); } else { throw new RuntimeException("Results not equal when running " + modelNames.get(classifierNum) + " under " + numThreads + " threads with " + classifiers.size() + " total classifiers"); } } } // if no exceptions thrown, great success System.out.println("Everything worked!"); } }