package edu.stanford.nlp.loglinear.benchmarks;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer;
import edu.stanford.nlp.loglinear.learning.BacktrackingAdaGradOptimizer;
import edu.stanford.nlp.loglinear.learning.LogLikelihoodDifferentiableFunction;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorNamespace;
import edu.stanford.nlp.loglinear.model.GraphicalModel;

import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * Created on 8/26/15.
 * @author keenon
 * <p>
 * This loads the CoNLL dataset and 300-dimensional Google word embeddings and trains a model on the data using binary
 * and unary factors. It is a nice demonstration of why ConcatVector is a key datastructure: there is no need to
 * specify the number of words in advance anywhere, and data structures will happily resize with a minimum of GC
 * wastage.
 */
public class CoNLLBenchmark {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(CoNLLBenchmark.class);

  static final String DATA_PATH = "/u/nlp/data/ner/conll/";

  Map<String, double[]> embeddings = new HashMap<>();

  public static void main(String[] args) throws Exception {
    new CoNLLBenchmark().benchmarkOptimizer();
  }

  public void benchmarkOptimizer() throws Exception {
    List<CoNLLSentence> train = getSentences(DATA_PATH + "conll.iob.4class.train");
    List<CoNLLSentence> testA = getSentences(DATA_PATH + "conll.iob.4class.testa");
    List<CoNLLSentence> testB = getSentences(DATA_PATH + "conll.iob.4class.testb");

    List<CoNLLSentence> allData = new ArrayList<>();
    allData.addAll(train);
    allData.addAll(testA);
    allData.addAll(testB);

    // Collect the set of NER tags seen anywhere in the data
    Set<String> tagsSet = new HashSet<>();
    for (CoNLLSentence sentence : allData) {
      for (String nerTag : sentence.ner) {
        tagsSet.add(nerTag);
      }
    }
    List<String> tags = new ArrayList<>(tagsSet);

    embeddings = getEmbeddings(DATA_PATH + "google-300-trimmed.ser.gz", allData);

    log.info("Making the training set...");
    ConcatVectorNamespace namespace = new ConcatVectorNamespace();
    int trainSize = train.size();
    GraphicalModel[] trainingSet = new GraphicalModel[trainSize];
    for (int i = 0; i < trainSize; i++) {
      if (i % 10 == 0) {
        log.info(i + "/" + trainSize);
      }
      trainingSet[i] = generateSentenceModel(namespace, train.get(i), tags);
    }

    log.info("Training system...");
    AbstractBatchOptimizer opt = new BacktrackingAdaGradOptimizer();

    // This training call is basically what we want the benchmark for. It should take 99% of the wall clock time.
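    // (The 0.01 and 1.0e-5 arguments below are presumably the l2 regularization strength and the convergence
    // tolerance, and the trailing boolean presumably toggles quiet logging.)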
    ConcatVector weights = opt.optimize(trainingSet, new LogLikelihoodDifferentiableFunction(), namespace.newWeightsVector(), 0.01, 1.0e-5, false);

    log.info("Testing system...");

    // Evaluation method lifted from the CoNLL 2004 perl script

    Map<String, Double> correctChunk = new HashMap<>();
    Map<String, Double> foundCorrect = new HashMap<>();
    Map<String, Double> foundGuessed = new HashMap<>();
    double correct = 0.0;
    double total = 0.0;

    for (CoNLLSentence sentence : testA) {
      GraphicalModel model = generateSentenceModel(namespace, sentence, tags);
      int[] guesses = new CliqueTree(model, weights).calculateMAP();
      String[] nerGuesses = new String[guesses.length];
      for (int i = 0; i < guesses.length; i++) {
        nerGuesses[i] = tags.get(guesses[i]);
        if (nerGuesses[i].equals(sentence.ner.get(i))) {
          correct++;
          correctChunk.put(nerGuesses[i], correctChunk.getOrDefault(nerGuesses[i], 0.) + 1);
        }
        total++;
        foundCorrect.put(sentence.ner.get(i), foundCorrect.getOrDefault(sentence.ner.get(i), 0.) + 1);
        foundGuessed.put(nerGuesses[i], foundGuessed.getOrDefault(nerGuesses[i], 0.) + 1);
      }
    }

    log.info("\nSystem results:\n");
    log.info("Accuracy: " + (correct / total) + "\n");

    for (String tag : tags) {
      double precision = foundGuessed.getOrDefault(tag, 0.0) == 0 ? 0.0 : correctChunk.getOrDefault(tag, 0.0) / foundGuessed.get(tag);
      double recall = foundCorrect.getOrDefault(tag, 0.0) == 0 ? 0.0 : correctChunk.getOrDefault(tag, 0.0) / foundCorrect.get(tag);
      double f1 = (precision + recall == 0.0) ? 0.0 : (precision * recall * 2) / (precision + recall);
      log.info(tag + " (" + foundCorrect.getOrDefault(tag, 0.0).intValue() + ")");
      log.info("\tP:" + precision + " (" + correctChunk.getOrDefault(tag, 0.0).intValue() + "/" + foundGuessed.getOrDefault(tag, 0.0).intValue() + ")");
      log.info("\tR:" + recall + " (" + correctChunk.getOrDefault(tag, 0.0).intValue() + "/" + foundCorrect.getOrDefault(tag, 0.0).intValue() + ")");
      log.info("\tF1:" + f1);
    }
  }

  ////////////////////////////////////////////////////////////////////////////////////////////
  // GENERATING MODELS
  ////////////////////////////////////////////////////////////////////////////////////////////

  private static String getWordShape(String string) {
    if (string.toUpperCase().equals(string) && string.toLowerCase().equals(string)) return "no-case";
    if (string.toUpperCase().equals(string)) return "upper-case";
    if (string.toLowerCase().equals(string)) return "lower-case";
    if (string.length() > 1 && Character.isUpperCase(string.charAt(0)) && string.substring(1).toLowerCase().equals(string.substring(1))) return "capitalized";
    return "mixed-case";
  }

  public GraphicalModel generateSentenceModel(ConcatVectorNamespace namespace, CoNLLSentence sentence, List<String> tags) {
    GraphicalModel model = new GraphicalModel();

    for (int i = 0; i < sentence.token.size(); i++) {
      // Add the training label
      Map<String, String> metadata = model.getVariableMetaDataByReference(i);
      metadata.put(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE, "" + tags.indexOf(sentence.ner.get(i)));

      metadata.put("TOKEN", "" + sentence.token.get(i));
      metadata.put("POS", "" + sentence.pos.get(i));
      metadata.put("CHUNK", "" + sentence.npchunk.get(i));
      metadata.put("TAG", "" + sentence.ner.get(i));
    }

    CoNLLFeaturizer.annotate(model, tags, namespace, embeddings);

    assert (model.factors != null);
    for (GraphicalModel.Factor f : model.factors) {
      assert (f != null);
    }

    return model;
  }

  ////////////////////////////////////////////////////////////////////////////////////////////
  // LOADING DATA FROM FILES
  ////////////////////////////////////////////////////////////////////////////////////////////
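  // getSentences() below expects tab-separated lines with four columns per token: the token itself, its POS tag,
  // its NP chunk tag, and its NER tag; any IOB prefix on the NER tag is stripped, keeping only the class name.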
  public static class CoNLLSentence {
    public List<String> token = new ArrayList<>();
    public List<String> ner = new ArrayList<>();
    public List<String> pos = new ArrayList<>();
    public List<String> npchunk = new ArrayList<>();

    public CoNLLSentence(List<String> token, List<String> ner, List<String> pos, List<String> npchunk) {
      this.token = token;
      this.ner = ner;
      this.pos = pos;
      this.npchunk = npchunk;
    }
  }

  public List<CoNLLSentence> getSentences(String filename) throws IOException {
    List<CoNLLSentence> sentences = new ArrayList<>();
    List<String> tokens = new ArrayList<>();
    List<String> ner = new ArrayList<>();
    List<String> pos = new ArrayList<>();
    List<String> npchunk = new ArrayList<>();

    try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
      String line;
      while ((line = br.readLine()) != null) {
        String[] parts = line.split("\t");
        if (parts.length == 4) {
          tokens.add(parts[0]);
          pos.add(parts[1]);
          npchunk.add(parts[2]);
          String tag = parts[3];
          if (tag.contains("-")) {
            ner.add(tag.split("-")[1]);
          } else {
            ner.add(tag);
          }
          // A "." token closes off the current sentence
          if (parts[0].equals(".")) {
            sentences.add(new CoNLLSentence(tokens, ner, pos, npchunk));
            tokens = new ArrayList<>();
            ner = new ArrayList<>();
            pos = new ArrayList<>();
            npchunk = new ArrayList<>();
          }
        }
      }
    }

    return sentences;
  }

  @SuppressWarnings("unchecked")
  public Map<String, double[]> getEmbeddings(String cacheFilename, List<CoNLLSentence> sentences) throws IOException, ClassNotFoundException {
    File f = new File(cacheFilename);
    Map<String, double[]> trimmedSet;

    if (!f.exists()) {
      // No cache yet: load the full embedding file, keep only the tokens that actually appear in the data,
      // and serialize that trimmed map as a gzipped cache for future runs.
      trimmedSet = new HashMap<>();
      Map<String, double[]> massiveSet = loadEmbeddingsFromFile("../google-300.txt");
      log.info("Got massive embedding set size " + massiveSet.size());

      for (CoNLLSentence sentence : sentences) {
        for (String token : sentence.token) {
          if (massiveSet.containsKey(token)) {
            trimmedSet.put(token, massiveSet.get(token));
          }
        }
      }
      log.info("Got trimmed embedding set size " + trimmedSet.size());

      f.createNewFile();
      try (ObjectOutputStream oos = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(cacheFilename)))) {
        oos.writeObject(trimmedSet);
      }
      log.info("Wrote trimmed set to file");
    } else {
      try (ObjectInputStream ois = new ObjectInputStream(new GZIPInputStream(new FileInputStream(cacheFilename)))) {
        trimmedSet = (Map<String, double[]>) ois.readObject();
      }
    }

    return trimmedSet;
  }

  public Map<String, double[]> loadEmbeddingsFromFile(String filename) throws IOException {
    Map<String, double[]> embeddings = new HashMap<>();
    int readLines = 0;

    try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
      // The first line is read and discarded before the loop starts
      String line = br.readLine();
      while ((line = br.readLine()) != null) {
        String[] parts = line.split(" ");
        // Expect the token, 300 embedding values, and one trailing field that the loop below ignores
        if (parts.length == 302) {
          String token = parts[0];
          double[] embedding = new double[300];
          for (int i = 1; i < parts.length - 1; i++) {
            embedding[i - 1] = Double.parseDouble(parts[i]);
          }
          embeddings.put(token, embedding);
        }
        readLines++;
        if (readLines % 10000 == 0) {
          log.info("Read " + readLines + " lines");
        }
      }
    }

    return embeddings;
  }

}