// CMMClassifier -- a conditional maximum-entropy markov model, mainly used for NER. // Copyright (c) 2002-2014 The Board of Trustees of // The Leland Stanford Junior University. All Rights Reserved. // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License // as published by the Free Software Foundation; either version 2 // of the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program; if not, write to the Free Software Foundation, // Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. // // For more information, bug reports, fixes, contact: // Christopher Manning // Dept of Computer Science, Gates 1A // Stanford CA 94305-9010 // USA // Support/Questions: java-nlp-user@lists.stanford.edu // Licensing: java-nlp-support@lists.stanford.edu package edu.stanford.nlp.ie.ner; import java.io.BufferedInputStream; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Properties; import java.util.Set; import java.util.regex.Pattern; import edu.stanford.nlp.classify.Dataset; import edu.stanford.nlp.classify.LinearClassifier; import edu.stanford.nlp.classify.LinearClassifierFactory; import edu.stanford.nlp.classify.LogPrior; import edu.stanford.nlp.classify.NBLinearClassifierFactory; import edu.stanford.nlp.classify.ProbabilisticClassifier; import edu.stanford.nlp.classify.SVMLightClassifierFactory; import edu.stanford.nlp.ie.AbstractSequenceClassifier; import edu.stanford.nlp.ie.NERFeatureFactory; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.ling.BasicDatum; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.math.ArrayMath; import edu.stanford.nlp.math.SloppyMath; import edu.stanford.nlp.objectbank.ObjectBank; import edu.stanford.nlp.sequences.BeamBestSequenceFinder; import edu.stanford.nlp.sequences.Clique; import edu.stanford.nlp.sequences.DocumentReaderAndWriter; import edu.stanford.nlp.sequences.ExactBestSequenceFinder; import edu.stanford.nlp.sequences.FeatureFactory; import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter; import edu.stanford.nlp.sequences.SeqClassifierFlags; import edu.stanford.nlp.sequences.SequenceModel; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.HashIndex; import edu.stanford.nlp.util.Index; import edu.stanford.nlp.util.PaddedList; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.StringUtils; /** * Does Sequence Classification using a Conditional Markov Model. * It could be used for other purposes, but the provided features * are aimed at doing Named Entity Recognition. 
 * The code has functionality for different document encodings, but when
 * using the standard <code>ColumnDocumentReader</code>, input files are
 * expected to be one word per line with the columns indicating things like
 * the word, POS, chunk, and class.
 * <p/>
 * <b>Typical usage</b>
 * <p>For running a trained model with a provided serialized classifier: <p>
 * <code>
 * java -server -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -loadClassifier
 * conll.ner.gz -textFile samplesentences.txt
 * </code><p>
 * When specifying all parameters in a properties file (train, test, or
 * runtime):<p>
 * <code>
 * java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -prop propFile
 * </code><p>
 * To train and test a model from the command line:<p>
 * <code>java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier
 * -trainFile trainFile -testFile testFile -goodCoNLL > output </code>
 * <p/>
 * Features are defined by a {@link FeatureFactory}; the
 * {@link FeatureFactory} which is used by default is
 * {@link NERFeatureFactory}, and you should look there for feature templates.
 * Features are specified either by a Properties file (which is the
 * recommended method) or on the command line. These specifications are read
 * into a {@link SeqClassifierFlags} object, which the user need not know
 * much about, unless one wishes to add new features.
 * <p/>
 * CMMClassifier may also be used programmatically. When creating a new instance,
 * you <i>must</i> specify a properties file. The other way to get a CMMClassifier
 * is to deserialize one via {@link CMMClassifier#getClassifier(String)}. You may
 * then tag sentences using the assorted <code>test</code> or
 * <code>testSentence</code> methods.
 *
 * @author Dan Klein
 * @author Jenny Finkel
 * @author Christopher Manning
 * @author Shipra Dingare
 * @author Huy Nguyen
 * @author Sarah Spikes (sdspikes@cs.stanford.edu) - cleanup and filling in types
 */
public class CMMClassifier<IN extends CoreLabel> extends AbstractSequenceClassifier<IN> {

  private ProbabilisticClassifier<String, String> classifier;

  /** The set of empirically legal label sequences (of length (order) at most
   *  <code>flags.maxLeft</code>). Used to filter valid class sequences if
   *  <code>useObservedSequencesOnly</code> is set.
   */
  Set<List<String>> answerArrays;

  /** Default place to look in Jar file for classifier. */
  public static final String DEFAULT_CLASSIFIER = "/classifiers/ner-eng-ie.cmm-3-all2006.ser.gz";

  protected CMMClassifier() {
    super(new SeqClassifierFlags());
  }

  public CMMClassifier(Properties props) {
    super(props);
  }

  public CMMClassifier(SeqClassifierFlags flags) {
    super(flags);
  }

  /**
   * Returns the Set of entities recognized by this Classifier.
   *
   * @return The Set of entities recognized by this Classifier.
   */
  public Set<String> getTags() {
    Set<String> tags = Generics.newHashSet(classIndex.objectsList());
    tags.remove(flags.backgroundSymbol);
    return tags;
  }

  /**
   * Classify a {@link List} of {@link CoreLabel}s.
   *
   * @param document A {@link List} of {@link CoreLabel}s
   *                 to be classified.
   */
  @Override
  public List<IN> classify(List<IN> document) {
    if (flags.useSequences) {
      classifySeq(document);
    } else {
      classifyNoSeq(document);
    }
    return document;
  }
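  // An illustrative sketch (editor's addition, not part of the original API): the
  // typical programmatic round trip loads a serialized model and calls classify()
  // directly. The model path below is a hypothetical placeholder.
  //
  //   CMMClassifier<CoreLabel> cmm =
  //       CMMClassifier.getClassifierNoExceptions("ner.cmm.model.ser.gz");
  //   List<CoreLabel> sentence = ...;  // one tokenized sentence
  //   cmm.classify(sentence);          // sets AnswerAnnotation on each token
  //   for (CoreLabel tok : sentence) {
  //     System.out.println(tok.word() + '\t'
  //         + tok.get(CoreAnnotations.AnswerAnnotation.class));
  //   }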
  /**
   * Classify a List of {@link CoreLabel}s without using sequence information
   * (i.e. no Viterbi algorithm, just distribution over next class).
   *
   * @param document a List of {@link CoreLabel}s to be classified
   */
  private void classifyNoSeq(List<IN> document) {
    if (flags.useReverse) {
      Collections.reverse(document);
    }
    if (flags.lowerNewgeneThreshold) {
      // Used to raise recall for task 1B
      System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold);
      for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel wordInfo = document.get(i);
        Datum<String, String> d = makeDatum(document, i, featureFactories);
        Counter<String> scores = classifier.scoresOf(d);
        //String answer = BACKGROUND;
        String answer = flags.backgroundSymbol;
        // HN: The evaluation of scoresOf seems to result in some
        // kind of side effect.  Specifically, the symptom is that
        // if scoresOf is not evaluated at every position, the
        // answers are different
        if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
          for (String label : scores.keySet()) {
            if ("G".equals(label)) {
              System.err.println(wordInfo.word() + ':' + scores.getCount(label));
              if (scores.getCount(label) > flags.newgeneThreshold) {
                answer = label;
              }
            }
          }
        }
        wordInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
      }
    } else {
      for (int i = 0, listSize = document.size(); i < listSize; i++) {
        String answer = classOf(document, i);
        CoreLabel wordInfo = document.get(i);
        //System.err.println("XXX answer for " +
        //        wordInfo.word() + " is " + answer);
        wordInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
      }
      if (flags.justify && (classifier instanceof LinearClassifier)) {
        LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
        for (int i = 0, lsize = document.size(); i < lsize; i++) {
          CoreLabel lineInfo = document.get(i);
          System.err.print("@@ Position " + i + ": ");
          System.err.println(lineInfo.word() + " chose " + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
          lc.justificationOf(makeDatum(document, i, featureFactories));
        }
      }
    }
    if (flags.useReverse) {
      Collections.reverse(document);
    }
  }

  /**
   * Returns the most likely class for the word at the given position.
   */
  protected String classOf(List<IN> lineInfos, int pos) {
    Datum<String, String> d = makeDatum(lineInfos, pos, featureFactories);
    return classifier.classOf(d);
  }

  /**
   * Returns the value of the training objective on the given data: the
   * <i>negative</i> log conditional likelihood of the observed labels, plus a
   * quadratic prior penalty on the weights when the underlying classifier is
   * linear. (Despite the method name, this is a negated likelihood, so lower
   * values are better.)
   *
   * @return The penalized negative log conditional likelihood of the given data.
   */
  public double loglikelihood(List<IN> lineInfos) {
    double cll = 0.0;

    for (int i = 0; i < lineInfos.size(); i++) {
      Datum<String, String> d = makeDatum(lineInfos, i, featureFactories);
      Counter<String> c = classifier.logProbabilityOf(d);

      double total = Double.NEGATIVE_INFINITY;
      for (String s : c.keySet()) {
        total = SloppyMath.logAdd(total, c.getCount(s));
      }
      cll -= c.getCount(d.label()) - total;
    }
    // quadratic prior
    // HN: TODO: add other priors
    if (classifier instanceof LinearClassifier) {
      double sigmaSq = flags.sigma * flags.sigma;
      LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
      for (String feature : lc.features()) {
        for (String classLabel : classIndex) {
          double w = lc.weight(feature, classLabel);
          cll += w * w / 2.0 / sigmaSq;
        }
      }
    }
    return cll;
  }

  @Override
  public SequenceModel getSequenceModel(List<IN> document) {
    //System.err.println(flags.useReverse);

    if (flags.useReverse) {
      Collections.reverse(document);
    }

    // cdm Aug 2005: why is this next line needed?  Seems really ugly!!!  [2006: it broke things! removed]
    // document.add(0, new CoreLabel());

    SequenceModel ts = new Scorer<IN>(document,
        classIndex,
        this,
        (!flags.useTaggySequences ? (flags.usePrevSequences ?
1 : 0) : flags.maxLeft), (flags.useNextSequences ? 1 : 0), answerArrays); return ts; } /** * Classify a List of {@link CoreLabel}s using sequence information * (i.e. Viterbi or Beam Search). * * @param document A List of {@link CoreLabel}s to be classified */ private void classifySeq(List<IN> document) { if (document.isEmpty()) { return; } SequenceModel ts = getSequenceModel(document); // TagScorer ts = new PrevOnlyScorer(document, tagIndex, this, (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft), 0, answerArrays); int[] tags; //System.err.println("***begin test***"); if (flags.useViterbi) { ExactBestSequenceFinder ti = new ExactBestSequenceFinder(); tags = ti.bestSequence(ts); } else { BeamBestSequenceFinder ti = new BeamBestSequenceFinder(flags.beamSize, true, true); tags = ti.bestSequence(ts, document.size()); } //System.err.println("***end test***"); // used to improve recall in task 1b if (flags.lowerNewgeneThreshold) { System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold); int[] copy = new int[tags.length]; System.arraycopy(tags, 0, copy, 0, tags.length); // for each sequence marked as NEWGENE in the gazette // tag the entire sequence as NEWGENE and sum the score // if the score is greater than newgeneThreshold, accept int ngTag = classIndex.indexOf("G"); //int bgTag = classIndex.indexOf(BACKGROUND); int bgTag = classIndex.indexOf(flags.backgroundSymbol); for (int i = 0, dSize = document.size(); i < dSize; i++) { CoreLabel wordInfo =document.get(i); if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) { int start = i; int j; for (j = i; j < document.size(); j++) { wordInfo = document.get(j); if (!"NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) { break; } } int end = j; //int end = i + 1; int winStart = Math.max(0, start - 4); int winEnd = Math.min(tags.length, end + 4); // clear a window around the sequences for (j = winStart; j < winEnd; j++) { copy[j] = bgTag; } // score as nongene double bgScore = 0.0; for (j = start; j < end; j++) { double[] scores = ts.scoresOf(copy, j); scores = Scorer.recenter(scores); bgScore += scores[bgTag]; } // first pass, compute all of the scores ClassicCounter<Pair<Integer,Integer>> prevScores = new ClassicCounter<Pair<Integer,Integer>>(); for (j = start; j < end; j++) { // clear the sequence for (int k = start; k < end; k++) { copy[k] = bgTag; } // grow the sequence from j until the end for (int k = j; k < end; k++) { copy[k] = ngTag; // score the sequence double ngScore = 0.0; for (int m = start; m < end; m++) { double[] scores = ts.scoresOf(copy, m); scores = Scorer.recenter(scores); ngScore += scores[tags[m]]; } prevScores.incrementCount(new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k)), ngScore - bgScore); } } for (j = start; j < end; j++) { // grow the sequence from j until the end for (int k = j; k < end; k++) { double score = prevScores.getCount(new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k))); Pair<Integer, Integer> al = new Pair<Integer,Integer>(Integer.valueOf(j - 1), Integer.valueOf(k)); // adding a word to the left Pair<Integer, Integer> ar = new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k + 1)); // adding a word to the right Pair<Integer, Integer> sl = new Pair<Integer,Integer>(Integer.valueOf(j + 1), Integer.valueOf(k)); // subtracting word from left Pair<Integer, Integer> sr = new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k - 1)); // subtracting word from right // make sure the score 
is greater than all its neighbors (one add or subtract) if (score >= flags.newgeneThreshold && (!prevScores.containsKey(al) || score > prevScores.getCount(al)) && (!prevScores.containsKey(ar) || score > prevScores.getCount(ar)) && (!prevScores.containsKey(sl) || score > prevScores.getCount(sl)) && (!prevScores.containsKey(sr) || score > prevScores.getCount(sr))) { StringBuilder sb = new StringBuilder(); wordInfo = document.get(j); String docId = wordInfo.get(CoreAnnotations.IDAnnotation.class); String startIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class); wordInfo = document.get(k); String endIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class); for (int m = j; m <= k; m++) { wordInfo = document.get(m); sb.append(wordInfo.word()); sb.append(' '); } /*System.err.println(sb.toString()+"score:"+score+ " al:"+prevScores.getCount(al)+ " ar:"+prevScores.getCount(ar)+ " sl:"+prevScores.getCount(sl)+" sr:"+ prevScores.getCount(sr));*/ System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim()); } } } // restore the original tags for (j = winStart; j < winEnd; j++) { copy[j] = tags[j]; } i = end; } } } for (int i = 0, docSize = document.size(); i < docSize; i++) { CoreLabel lineInfo = document.get(i); String answer = classIndex.get(tags[i]); lineInfo.set(CoreAnnotations.AnswerAnnotation.class, answer); } if (flags.justify && classifier instanceof LinearClassifier) { LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier; if (flags.dump) { lc.dump(); } for (int i = 0, docSize = document.size(); i < docSize; i++) { CoreLabel lineInfo = document.get(i); System.err.print("@@ Position is: " + i + ": "); System.err.println(lineInfo.word() + ' ' + lineInfo.get(CoreAnnotations.AnswerAnnotation.class)); lc.justificationOf(makeDatum(document, i, featureFactories)); } } // document.remove(0); if (flags.useReverse) { Collections.reverse(document); } } // end testSeq /** * @param filename adaptation file * @param trainDataset original dataset (used in training) */ public void adapt(String filename, Dataset<String, String> trainDataset, DocumentReaderAndWriter<IN> readerWriter) { flags.ocrTrain = false; // ?? Do we need this? 
(Pi-Chuan Sat Nov 5 15:42:49 2005) ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerWriter); adapt(docs, trainDataset); } /** * @param featureLabels adaptation docs * @param trainDataset original dataset (used in training) */ public void adapt(ObjectBank<List<IN>> featureLabels, Dataset<String, String> trainDataset) { Dataset<String, String> adapt = getDataset(featureLabels, trainDataset); adapt(adapt); } /** * @param featureLabels retrain docs * @param featureIndex featureIndex of original dataset (used in training) * @param labelIndex labelIndex of original dataset (used in training) */ public void retrain(ObjectBank<List<IN>> featureLabels, Index<String> featureIndex, Index<String> labelIndex) { int fs = featureIndex.size(); // old dim int ls = labelIndex.size(); // old dim Dataset<String, String> adapt = getDataset(featureLabels, featureIndex, labelIndex); int prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier; LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize); double[][] weights = lc.weights(); // old dim Index<String> newF = adapt.featureIndex; Index<String> newL = adapt.labelIndex; int newFS = newF.size(); int newLS = newL.size(); double[] x = new double[newFS*newLS]; // new dim //System.err.println("old ["+fs+"]"+"["+ls+"]"); //System.err.println("new ["+newFS+"]"+"["+newLS+"]"); //System.err.println("new ["+newFS*newLS+"]"); for (int i = 0; i < fs; i++) { for (int j = 0; j < ls; j++) { String f = featureIndex.get(i); String l = labelIndex.get(j); int newi = newF.indexOf(f)*newLS+newL.indexOf(l); x[newi] = weights[i][j]; //if (newi == 144745*2) { //System.err.println("What??"+i+"\t"+j); //} } } //System.err.println("x[144745*2]"+x[144745*2]); weights = lcf.trainWeights(adapt, x); //System.err.println("x[144745*2]"+x[144745*2]); //System.err.println("weights[144745]"+"[0]="+weights[144745][0]); lc.setWeights(weights); /* int delme = 0; if (true) { for (double[] dd : weights) { delme++; for (double d : dd) { } } } System.err.println(weights[delme-1][0]); System.err.println("size of weights: "+delme); */ } public void retrain(ObjectBank<List<IN>> doc) { if (classifier == null) { throw new UnsupportedOperationException("Cannot retrain before you train!"); } Index<String> findex = ((LinearClassifier<String, String>)classifier).featureIndex(); Index<String> lindex = ((LinearClassifier<String, String>)classifier).labelIndex(); System.err.println("Starting retrain:\t# of original features"+findex.size()+", # of original labels"+lindex.size()); retrain(doc, findex, lindex); } @Override public void train(Collection<List<IN>> wordInfos, DocumentReaderAndWriter<IN> readerAndWriter) { Dataset<String, String> train = getDataset(wordInfos); //train.summaryStatistics(); //train.printSVMLightFormat(); // wordInfos = null; // cdm: I think this does no good as ptr exists in caller (could empty the list or better refactor so conversion done earlier?) 
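    // Editor's descriptive note (added): the code below first trains on the full
    // feature set; then, if flags.numTimesPruneFeatures > 0, it repeatedly drops
    // features whose weight spread across labels is at most flags.featureDiffThresh
    // (see getFeaturesAboveThreshhold) and retrains on the reduced Dataset,
    // temporarily swapping in the alternative quasi-Newton history size flags.QNsize2.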
train(train); for (int i = 0; i < flags.numTimesPruneFeatures; i++) { Index<String> featuresAboveThreshhold = getFeaturesAboveThreshhold(train, flags.featureDiffThresh); System.err.println("Removing features with weight below " + flags.featureDiffThresh + " and retraining..."); train = getDataset(train, featuresAboveThreshhold); int tmp = flags.QNsize; flags.QNsize = flags.QNsize2; train(train); flags.QNsize = tmp; } if (flags.doAdaptation && flags.adaptFile != null) { adapt(flags.adaptFile,train,readerAndWriter); } System.err.print("Built this classifier: "); if (classifier instanceof LinearClassifier) { String classString = ((LinearClassifier<String, String>)classifier).toString(flags.printClassifier, flags.printClassifierParam); System.err.println(classString); } else { String classString = classifier.toString(); System.err.println(classString); } } public Index<String> getFeaturesAboveThreshhold(Dataset<String, String> dataset, double thresh) { if (!(classifier instanceof LinearClassifier)) { throw new RuntimeException("Attempting to remove features based on weight from a non-linear classifier"); } Index<String> featureIndex = dataset.featureIndex; Index<String> labelIndex = dataset.labelIndex; Index<String> features = new HashIndex<String>(); Iterator<String> featureIt = featureIndex.iterator(); LinearClassifier<String, String> lc = (LinearClassifier<String, String>)classifier; LOOP: while (featureIt.hasNext()) { String f = featureIt.next(); Iterator<String> labelIt = labelIndex.iterator(); double smallest = Double.POSITIVE_INFINITY; double biggest = Double.NEGATIVE_INFINITY; while (labelIt.hasNext()) { String l = labelIt.next(); double weight = lc.weight(f, l); if (weight < smallest) { smallest = weight; } if (weight > biggest) { biggest = weight; } if (biggest - smallest > thresh) { features.add(f); continue LOOP; } } } return features; } /** * Build a Dataset from some data. Used for training a classifier. * * @param data This variable is a list of lists of CoreLabel. That is, * it is a collection of documents, each of which is represented * as a sequence of CoreLabel objects. * @return The Dataset which is an efficient encoding of the information * in a List of Datums */ public Dataset<String, String> getDataset(Collection<List<IN>> data) { return getDataset(data, null, null); } /** * Build a Dataset from some data. Used for training a classifier. * * By passing in extra featureIndex and classIndex, you can get a Dataset based on featureIndex and * classIndex * * @param data This variable is a list of lists of CoreLabel. That is, * it is a collection of documents, each of which is represented * as a sequence of CoreLabel objects. * @param classIndex if you want to get a Dataset based on featureIndex and * classIndex in an existing origDataset * @return The Dataset which is an efficient encoding of the information * in a List of Datums */ public Dataset<String, String> getDataset(Collection<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) { makeAnswerArraysAndTagIndex(data); int size = 0; for (List<IN> doc : data) { size += doc.size(); } System.err.println("Making Dataset..."); Dataset<String, String> train; if (featureIndex != null && classIndex != null) { System.err.println("Using feature/class Index from existing Dataset..."); System.err.println("(This is used when getting Dataset from adaptation set. 
We want to make the index consistent.)"); // pichuan
      train = new Dataset<String, String>(size, featureIndex, classIndex);
    } else {
      train = new Dataset<String, String>(size);
    }

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactories);
        //CoreLabel fl = doc.get(i);
        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  public Dataset<String, String> getBiasedDataset(ObjectBank<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) {
    makeAnswerArraysAndTagIndex(data);

    Index<String> origFeatIndex = new HashIndex<String>(featureIndex.objectsList()); // mg2009: TODO: check

    int size = 0;
    for (List<IN> doc : data) {
      size += doc.size();
    }

    System.err.println("Making Dataset...");
    Dataset<String, String> train = new Dataset<String, String>(size, featureIndex, classIndex);

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactories);
        Collection<String> newFeats = new ArrayList<String>();
        for (String f : d.asFeatures()) {
          if ( ! origFeatIndex.contains(f)) {
            newFeats.add(f);
          }
        }
        // System.err.println(d.label()+"\t"+d.asFeatures()+"\n\t"+newFeats);
        // d = new BasicDatum(newFeats, d.label());
        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * By passing in an extra origDataset, you can get a Dataset based on featureIndex and
   * classIndex in an existing origDataset.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @param origDataset if you want to get a Dataset based on featureIndex and
   *                    classIndex in an existing origDataset
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset<String, String> getDataset(ObjectBank<List<IN>> data, Dataset<String, String> origDataset) {
    if (origDataset == null) {
      return getDataset(data);
    }
    return getDataset(data, origDataset.featureIndex, origDataset.labelIndex);
  }
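  // An illustrative sketch (editor's addition): the origDataset overload above is
  // what keeps the feature and label indices aligned between the original training
  // data and later adaptation data. A hypothetical adaptation run, after normal
  // training, might look like this ("adapt.tsv" and trainDocs are placeholders):
  //
  //   Dataset<String, String> trainData = cmm.getDataset(trainDocs);
  //   // ... train, then adapt on new-domain data indexed consistently:
  //   cmm.adapt("adapt.tsv", trainData, cmm.makeReaderAndWriter());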
  /**
   * Build a Dataset from some data.
   *
   * @param oldData This {@link Dataset} represents data for which we wish to remove
   *                some features, specifically those features not in the
   *                {@link edu.stanford.nlp.util.Index} goodFeatures.
   * @param goodFeatures An {@link edu.stanford.nlp.util.Index} of features we wish to retain.
   * @return A new {@link Dataset} where each data point contains only features
   *         which were in goodFeatures.
   */
  public Dataset<String, String> getDataset(Dataset<String, String> oldData, Index<String> goodFeatures) {
    //public Dataset getDataset(List data, Collection goodFeatures) {
    //makeAnswerArraysAndTagIndex(data);

    int[][] oldDataArray = oldData.getDataArray();
    int[] oldLabelArray = oldData.getLabelsArray();
    Index<String> oldFeatureIndex = oldData.featureIndex;

    int[] oldToNewFeatureMap = new int[oldFeatureIndex.size()];

    int[][] newDataArray = new int[oldDataArray.length][];

    System.err.print("Building reduced dataset...");

    int size = oldFeatureIndex.size();
    int max = 0;
    for (int i = 0; i < size; i++) {
      oldToNewFeatureMap[i] = goodFeatures.indexOf(oldFeatureIndex.get(i));
      if (oldToNewFeatureMap[i] > max) {
        max = oldToNewFeatureMap[i];
      }
    }

    for (int i = 0; i < oldDataArray.length; i++) {
      int[] data = oldDataArray[i];
      size = 0;
      for (int oldF : data) {
        if (oldToNewFeatureMap[oldF] > 0) {
          size++;
        }
      }
      int[] newData = new int[size];
      int index = 0;
      for (int oldF : data) {
        int f = oldToNewFeatureMap[oldF];
        if (f > 0) {
          newData[index++] = f;
        }
      }
      newDataArray[i] = newData;
    }

    Dataset<String, String> train = new Dataset<String, String>(oldData.labelIndex, oldLabelArray, goodFeatures, newDataArray, newDataArray.length);

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  private void adapt(Dataset<String, String> adapt) {
    if (flags.classifierType.equalsIgnoreCase("SVM")) {
      throw new UnsupportedOperationException();
    }
    adaptMaxEnt(adapt);
  }

  private void adaptMaxEnt(Dataset<String, String> adapt) {
    if (classifier instanceof LinearClassifier) {
      // So far the adaptation is only done on Gaussian Prior. Haven't checked how it'll work on other kinds of priors.
-pichuan int prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); if (flags.useHuber) { throw new UnsupportedOperationException(); } else if (flags.useQuartic) { throw new UnsupportedOperationException(); } LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.adaptSigma, flags.epsilon, flags.QNsize); ((LinearClassifier<String, String>)classifier).adaptWeights(adapt,lcf); } else { throw new UnsupportedOperationException(); } } private void train(Dataset<String, String> train) { if (flags.classifierType.equalsIgnoreCase("SVM")) { trainSVM(train); } else { trainMaxEnt(train); } } private void trainSVM(Dataset<String, String> train) { SVMLightClassifierFactory<String, String> fact = new SVMLightClassifierFactory<String, String>(); classifier = fact.trainClassifier(train); } private void trainMaxEnt(Dataset<String, String> train) { int prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); if (flags.useHuber) { prior = LogPrior.LogPriorType.HUBER.ordinal(); } else if (flags.useQuartic) { prior = LogPrior.LogPriorType.QUARTIC.ordinal(); } LinearClassifier<String, String> lc; if (flags.useNB) { lc = new NBLinearClassifierFactory<String, String>(flags.sigma).trainClassifier(train); } else { LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize); if (flags.useQN) { lcf.useQuasiNewton(flags.useRobustQN); } else if(flags.useStochasticQN) { lcf.useStochasticQN(flags.initialGain,flags.stochasticBatchSize); } else if(flags.useSMD) { lcf.useStochasticMetaDescent(flags.initialGain, flags.stochasticBatchSize,flags.stochasticMethod,flags.SGDPasses); } else if(flags.useSGD) { lcf.useStochasticGradientDescent(flags.gainSGD,flags.stochasticBatchSize); } else if(flags.useSGDtoQN) { lcf.useStochasticGradientDescentToQuasiNewton(flags.initialGain, flags.stochasticBatchSize, flags.SGDPasses, flags.QNPasses, flags.SGD2QNhessSamples, flags.QNsize, flags.outputIterationsToFile); } else if(flags.useHybrid) { lcf.useHybridMinimizer(flags.initialGain, flags.stochasticBatchSize ,flags.stochasticMethod ,flags.hybridCutoffIteration ); } else { lcf.useConjugateGradientAscent(); } lc = lcf.trainClassifier(train); } this.classifier = lc; } private void trainSemiSup(Dataset<String, String> data, Dataset<String, String> biasedData, double[][] confusionMatrix) { int prior = LogPrior.LogPriorType.QUADRATIC.ordinal(); if (flags.useHuber) { prior = LogPrior.LogPriorType.HUBER.ordinal(); } else if (flags.useQuartic) { prior = LogPrior.LogPriorType.QUARTIC.ordinal(); } LinearClassifierFactory<String, String> lcf; lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize); if (flags.useQN) { lcf.useQuasiNewton(); } else{ lcf.useConjugateGradientAscent(); } this.classifier = (LinearClassifier<String, String>) lcf.trainClassifierSemiSup(data, biasedData, confusionMatrix, null); } // public void crossValidateTrainAndTest() throws Exception { // crossValidateTrainAndTest(flags.trainFile); // } // public void crossValidateTrainAndTest(String filename) throws Exception { // // wordshapes // for (int fold = flags.startFold; fold <= flags.endFold; fold++) { // System.err.println("fold " + fold + " of " + flags.endFold); // // train // List = makeObjectBank(filename); // List folds = split(data, flags.numFolds); // data = null; // List train = new ArrayList(); // for (int i = 0; i < 
flags.numFolds; i++) { // List docs = (List) folds.get(i); // if (i != fold) { // train.addAll(docs); // } // } // folds = null; // train(train); // train = null; // List test = new ArrayList(); // data = makeObjectBank(filename); // folds = split(data, flags.numFolds); // data = null; // for (int i = 0; i < flags.numFolds; i++) { // List docs = (List) folds.get(i); // if (i == fold) { // test.addAll(docs); // } // } // folds = null; // // test // test(test); // writeAnswers(test); // } // } // /** // * Splits the given train corpus into a train and a test corpus based on the fold number. // * 1 / numFolds documents are held out for test, with the offset determined by the fold number. // * // * @param data The original data // * @param numFolds The number of folds to split the data into // * @return A list of folds giving the new training set // */ // private List split(List data, int numFolds) { // List folds = new ArrayList(); // int foldSize = data.size() / numFolds; // int r = data.size() - (numFolds * foldSize); // int index = 0; // for (int i = 0; i < numFolds; i++) { // List fold = new ArrayList(); // int end = (i < r ? foldSize + 1 : foldSize); // for (int j = 0; j < end; j++) { // fold.add(data.get(index++)); // } // folds.add(fold); // } // return folds; // } @Override public void serializeClassifier(String serializePath) { System.err.print("Serializing classifier to " + serializePath + "..."); try { ObjectOutputStream oos = IOUtils.writeStreamFromString(serializePath); oos.writeObject(classifier); oos.writeObject(flags); oos.writeObject(featureFactories); oos.writeObject(classIndex); oos.writeObject(answerArrays); //oos.writeObject(WordShapeClassifier.getKnownLowerCaseWords()); oos.writeObject(knownLCWords); oos.close(); System.err.println("Done."); } catch (Exception e) { System.err.println("Error serializing to " + serializePath); e.printStackTrace(); } } /** * Used to load the default supplied classifier. **THIS FUNCTION * WILL ONLY WORK IF RUN INSIDE A JAR FILE** */ public void loadDefaultClassifier() { loadJarClassifier(DEFAULT_CLASSIFIER, null); } /** * Used to obtain the default classifier which is * stored inside a jar file. <i>THIS FUNCTION * WILL ONLY WORK IF RUN INSIDE A JAR FILE.</i> * * @return A Default CMMClassifier from a jar file */ public static CMMClassifier<? extends CoreLabel> getDefaultClassifier() { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadDefaultClassifier(); return cmm; } /** Load a classifier from the given Stream. * <i>Implementation note: </i> This method <i>does not</i> close the * Stream that it reads from. 
* * @param ois The ObjectInputStream to load the serialized classifier from * * @throws IOException If there are problems accessing the input stream * @throws ClassCastException If there are problems interpreting the serialized data * @throws ClassNotFoundException If there are problems interpreting the serialized data * */ @SuppressWarnings("unchecked") @Override public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException, ClassNotFoundException { classifier = (LinearClassifier<String, String>) ois.readObject(); flags = (SeqClassifierFlags) ois.readObject(); Object featureFactory = ois.readObject(); if (featureFactory instanceof List) { featureFactories = ErasureUtils.uncheckedCast(featureFactory); } else if (featureFactory instanceof FeatureFactory) { featureFactories = Generics.newArrayList(); featureFactories.add((FeatureFactory) featureFactory); } if (props != null) { flags.setProperties(props); } reinit(); classIndex = (Index<String>) ois.readObject(); answerArrays = (Set<List<String>>) ois.readObject(); knownLCWords = (Set<String>) ois.readObject(); } public static CMMClassifier<? extends CoreLabel> getClassifierNoExceptions(File file) { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifierNoExceptions(file); return cmm; } public static CMMClassifier<? extends CoreLabel> getClassifier(File file) throws IOException, ClassCastException, ClassNotFoundException { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifier(file); return cmm; } public static CMMClassifier<CoreLabel> getClassifierNoExceptions(String loadPath) { CMMClassifier<CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifierNoExceptions(loadPath); return cmm; } public static CMMClassifier<? extends CoreLabel> getClassifier(String loadPath) throws IOException, ClassCastException, ClassNotFoundException { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifier(loadPath); return cmm; } public static CMMClassifier<? extends CoreLabel> getClassifierNoExceptions(InputStream in) { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifierNoExceptions(new BufferedInputStream(in), null); return cmm; } public static CMMClassifier<? extends CoreLabel> getClassifier(InputStream in) throws IOException, ClassCastException, ClassNotFoundException { CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>(); cmm.loadClassifier(new BufferedInputStream(in)); return cmm; } /** This routine builds the <code>answerArrays</code> which give the * empirically legal label sequences (of length (order) at most * <code>flags.maxLeft</code>) and the <code>classIndex</code>, * which indexes known answer classes. 
   *
   * @param docs The training data: A List of List of CoreLabel
   */
  private void makeAnswerArraysAndTagIndex(Collection<List<IN>> docs) {
    if (answerArrays == null) {
      answerArrays = Generics.newHashSet();
    }
    if (classIndex == null) {
      classIndex = new HashIndex<String>();
    }

    for (List<IN> doc : docs) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      int leng = doc.size();
      for (int start = 0; start < leng; start++) {
        for (int diff = 1; diff <= flags.maxLeft && start + diff <= leng; diff++) {
          String[] seq = new String[diff];
          for (int i = start; i < start + diff; i++) {
            seq[i - start] = doc.get(i).get(CoreAnnotations.AnswerAnnotation.class);
          }
          answerArrays.add(Arrays.asList(seq));
        }
      }
      for (int i = 0; i < leng; i++) {
        CoreLabel wordInfo = doc.get(i);
        classIndex.add(wordInfo.get(CoreAnnotations.AnswerAnnotation.class));
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }
  }

  /** Make an individual Datum out of the data list info, focused at position
   *  loc.
   *
   *  @param info A List of IN objects
   *  @param loc The position in the info list to focus feature creation on
   *  @param featureFactories The factories that construct features out of the item
   *  @return A Datum (BasicDatum) representing this data instance
   */
  public Datum<String, String> makeDatum(List<IN> info, int loc, List<FeatureFactory<IN>> featureFactories) {
    PaddedList<IN> pInfo = new PaddedList<IN>(info, pad);

    Collection<String> features = new ArrayList<String>();
    for (FeatureFactory<IN> featureFactory : featureFactories) {
      List<Clique> cliques = featureFactory.getCliques();
      for (Clique c : cliques) {
        Collection<String> feats = featureFactory.getCliqueFeatures(pInfo, loc, c);
        feats = addOtherClasses(feats, pInfo, loc, c);
        features.addAll(feats);
      }
    }

    printFeatures(pInfo.get(loc), features);
    CoreLabel c = info.get(loc);
    return new BasicDatum<String, String>(features, c.get(CoreAnnotations.AnswerAnnotation.class));
  }

  /** This adds to the feature name the name of classes that are other than
   *  the current class that are involved in the clique.  In the CMM, these
   *  other classes become part of the conditioning feature, and only the
   *  class of the current position is being predicted.
   *
   *  @return A collection of features with extra class information put
   *          into the feature name.
   */
  private static Collection<String> addOtherClasses(Collection<String> feats, List<? extends CoreLabel> info,
                                       int loc, Clique c) {
    String addend = null;
    String pAnswer = info.get(loc - 1).get(CoreAnnotations.AnswerAnnotation.class);
    String p2Answer = info.get(loc - 2).get(CoreAnnotations.AnswerAnnotation.class);
    String p3Answer = info.get(loc - 3).get(CoreAnnotations.AnswerAnnotation.class);
    String p4Answer = info.get(loc - 4).get(CoreAnnotations.AnswerAnnotation.class);
    String p5Answer = info.get(loc - 5).get(CoreAnnotations.AnswerAnnotation.class);
    String nAnswer = info.get(loc + 1).get(CoreAnnotations.AnswerAnnotation.class);
    // cdm 2009: Is this really right? Do we not need to differentiate names that would collide???
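    // Editor's note (added, illustrative): the branches below append the
    // conditioned-on class(es) to each feature name. For example, with the
    // CpCp2C clique, pAnswer "PER" and p2Answer "O", a feature "W-Smith"
    // becomes "W-Smith|PER-O", so its weight is specific to that history of
    // previous classes rather than shared across all histories.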
if (c == FeatureFactory.cliqueCpC) { addend = '|' + pAnswer; } else if (c == FeatureFactory.cliqueCp2C) { addend = '|' + p2Answer; } else if (c == FeatureFactory.cliqueCp3C) { addend = '|' + p3Answer; } else if (c == FeatureFactory.cliqueCp4C) { addend = '|' + p4Answer; } else if (c == FeatureFactory.cliqueCp5C) { addend = '|' + p5Answer; } else if (c == FeatureFactory.cliqueCpCp2C) { addend = '|' + pAnswer + '-' + p2Answer; } else if (c == FeatureFactory.cliqueCpCp2Cp3C) { addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer; } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4C) { addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer; } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4Cp5C) { addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer + '-' + p5Answer; } else if (c == FeatureFactory.cliqueCnC) { addend = '|' + nAnswer; } else if (c == FeatureFactory.cliqueCpCnC) { addend = '|' + pAnswer + '-' + nAnswer; } if (addend == null) { return feats; } Collection<String> newFeats = Generics.newHashSet(); for (String feat : feats) { String newFeat = feat + addend; newFeats.add(newFeat); } return newFeats; } private static List<Pair<Pattern, Integer>> getThresholds(String filename) { try { BufferedReader in = new BufferedReader(new FileReader(filename)); List<Pair<Pattern, Integer>> thresholds = new ArrayList<Pair<Pattern, Integer>>(); String line; while ((line = in.readLine()) != null) { int i = line.lastIndexOf(' '); Pattern p = Pattern.compile(line.substring(0, i)); //System.err.println(":"+line.substring(0,i)+":"); Integer t = Integer.valueOf(line.substring(i + 1)); Pair<Pattern, Integer> pair = new Pair<Pattern, Integer>(p, t); thresholds.add(pair); } in.close(); return thresholds; } catch (Exception e) { throw new RuntimeException("Error reading threshold file", e); } } public void trainSemiSup() { DocumentReaderAndWriter<IN> readerAndWriter = makeReaderAndWriter(); String filename = flags.trainFile; String biasedFilename = flags.biasedTrainFile; ObjectBank<List<IN>> data = makeObjectBankFromFile(filename, readerAndWriter); ObjectBank<List<IN>> biasedData = makeObjectBankFromFile(biasedFilename, readerAndWriter); Index<String> featureIndex = new HashIndex<String>(); Index<String> classIndex = new HashIndex<String>(); Dataset<String, String> dataset = getDataset(data, featureIndex, classIndex); Dataset<String, String> biasedDataset = getBiasedDataset(biasedData, featureIndex, classIndex); double[][] confusionMatrix = new double[classIndex.size()][classIndex.size()]; for (int i = 0; i < confusionMatrix.length; i++) { Arrays.fill(confusionMatrix[i], 0.0); confusionMatrix[i][i] = 1.0; } String cm = flags.confusionMatrix; String[] bits = cm.split(":"); for (String bit : bits) { String[] bits1 = bit.split("\\|"); int i1 = classIndex.indexOf(bits1[0]); int i2 = classIndex.indexOf(bits1[1]); double d = Double.parseDouble(bits1[2]); confusionMatrix[i2][i1] = d; } for (double[] row : confusionMatrix) { ArrayMath.normalize(row); } for (int i = 0; i < confusionMatrix.length; i++) { for (int j = 0; j < i; j++) { double d = confusionMatrix[i][j]; confusionMatrix[i][j] = confusionMatrix[j][i]; confusionMatrix[j][i] = d; } } for (int i = 0; i < confusionMatrix.length; i++) { for (int j = 0; j < confusionMatrix.length; j++) { System.err.println("P("+classIndex.get(j)+ '|' +classIndex.get(i)+") = "+confusionMatrix[j][i]); } } trainSemiSup(dataset, biasedDataset, confusionMatrix); } static class Scorer<INN extends CoreLabel> implements SequenceModel { private 
CMMClassifier<INN> classifier = null; private int[] tagArray = null; private int[] backgroundTags = null; private Index<String> tagIndex = null; private List<INN> lineInfos = null; private int pre = 0; private int post = 0; private Set<List<String>> legalTags = null; private static final boolean VERBOSE = false; void buildTagArray() { int sz = tagIndex.size(); tagArray = new int[sz]; for (int i = 0; i < sz; i++) { tagArray[i] = i; } } @Override public int length() { return lineInfos.size() - pre - post; } @Override public int leftWindow() { return pre; } @Override public int rightWindow() { return post; } @Override public int[] getPossibleValues(int position) { // if (position == 0 || position == lineInfos.size() - 1) { // int[] a = new int[1]; // a[0] = tagIndex.indexOf(BACKGROUND); // return a; // } if (tagArray == null) { buildTagArray(); } if (position < pre) { return backgroundTags; } return tagArray; } @Override public double scoreOf(int[] sequence) { throw new UnsupportedOperationException(); } private double[] scoreCache = null; private int[] lastWindow = null; //private int lastPos = -1; @Override public double scoreOf(int[] tags, int pos) { if (false) { return scoresOf(tags, pos)[tags[pos]]; } if (lastWindow == null) { lastWindow = new int[leftWindow() + rightWindow() + 1]; Arrays.fill(lastWindow, -1); } boolean match = (pos == lastPos); for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) { if (i == pos || i < 0) { continue; } /*System.err.println("p:"+pos); System.err.println("lw:"+leftWindow()); System.err.println("i:"+i);*/ match &= tags[i] == lastWindow[i - pos + leftWindow()]; } if (!match) { scoreCache = scoresOf(tags, pos); for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) { if (i < 0) { continue; } lastWindow[i - pos + leftWindow()] = tags[i]; } lastPos = pos; } return scoreCache[tags[pos]]; } private int percent = -1; private int num = 0; private long secs = System.currentTimeMillis(); private long hit = 0; private long tot = 0; @Override public double[] scoresOf(int[] tags, int pos) { if (VERBOSE) { int p = (100 * pos) / length(); if (p > percent) { long secs2 = System.currentTimeMillis(); System.err.println(StringUtils.padLeft(p, 3) + "% " + ((secs2 - secs == 0) ? 0 : (num * 1000 / (secs2 - secs))) + " hits per sec, position=" + pos + ", legal=" + ((tot == 0) ? 
100 : ((100 * hit) / tot))); // + "% [hit=" + hit + ", tot=" + tot + "]"); percent = p; num = 0; secs = secs2; } tot++; } String[] answers = new String[1 + leftWindow() + rightWindow()]; String[] pre = new String[leftWindow()]; for (int i = 0; i < 1 + leftWindow() + rightWindow(); i++) { int absPos = pos - leftWindow() + i; if (absPos < 0) { continue; } answers[i] = tagIndex.get(tags[absPos]); CoreLabel li = lineInfos.get(absPos); li.set(CoreAnnotations.AnswerAnnotation.class, answers[i]); if (i < leftWindow()) { pre[i] = answers[i]; } } double[] scores = new double[tagIndex.size()]; //System.out.println("Considering: "+Arrays.asList(pre)); if (!legalTags.contains(Arrays.asList(pre)) && classifier.flags.useObservedSequencesOnly) { // System.out.println("Rejecting: " + Arrays.asList(pre)); // System.out.println(legalTags); Arrays.fill(scores, -1000);// Double.NEGATIVE_INFINITY; return scores; } num++; hit++; Counter<String> c = classifier.scoresOf(lineInfos, pos); //System.out.println("Pos "+pos+" hist "+Arrays.asList(pre)+" result "+c); //System.out.println(c); //if (false && flags.justify) { // System.out.println("Considering position " + pos + ", word is " + ((CoreLabel) lineInfos.get(pos)).word()); // //System.out.println("Datum is "+d.asFeatures()); // System.out.println("History: " + Arrays.asList(pre)); //} for (String s : c.keySet()) { int t = tagIndex.indexOf(s); if (t > -1) { int[] tA = getPossibleValues(pos); for (int j = 0; j < tA.length; j++) { if (tA[j] == t) { scores[j] = c.getCount(s); //if (false && flags.justify) { // System.out.println("Label " + s + " got score " + scores[j]); //} } } } } // normalize? if (classifier.normalize()) { ArrayMath.logNormalize(scores); } return scores; } static double[] recenter(double[] x) { double[] r = new double[x.length]; // double logTotal = Double.NEGATIVE_INFINITY; // for (int i = 0; i < x.length; i++) // logTotal = SloppyMath.logAdd(logTotal, x[i]); double logTotal = ArrayMath.logSum(x); for (int i = 0; i < x.length; i++) { r[i] = x[i] - logTotal; } return r; } /** * Build a Scorer. * * @param lineInfos List of INN data items to classify * @param classifier The trained Classifier * @param pre Number of previous tags that condition current tag * @param post Number of following tags that condition previous tag * (if pre and post are both nonzero, then you have a * dependency network tagger) */ Scorer(List<INN> lineInfos, Index<String> tagIndex, CMMClassifier<INN> classifier, int pre, int post, Set<List<String>> legalTags) { if (VERBOSE) { System.err.println("Built Scorer for " + lineInfos.size() + " words, clique pre=" + pre + " post=" + post); } this.pre = pre; this.post = post; this.lineInfos = lineInfos; this.tagIndex = tagIndex; this.classifier = classifier; this.legalTags = legalTags; backgroundTags = new int[]{tagIndex.indexOf(classifier.flags.backgroundSymbol)}; } } // end class Scorer private boolean normalize() { return flags.normalize; } static int lastPos = -1; public Counter<String> scoresOf(List<IN> lineInfos, int pos) { // if (pos != lastPos) { // System.err.print(pos+"."); // lastPos = pos; // } // System.err.print("!"); Datum<String, String> d = makeDatum(lineInfos, pos, featureFactories); return classifier.logProbabilityOf(d); } /** * Takes a {@link List} of {@link CoreLabel}s and prints the likelihood * of each possible label at each point. * TODO: Finish or delete this method! * * @param document A {@link List} of {@link CoreLabel}s. 
*/ @Override public void printProbsDocument(List<IN> document) { //ClassicCounter<String> c = scoresOf(document, 0); } /** Command-line version of the classifier. See the class * comments for examples of use, and SeqClassifierFlags * for more information on supported flags. */ public static void main(String[] args) throws Exception { StringUtils.printErrInvocationString("CMMClassifier", args); Properties props = StringUtils.argsToProperties(args); CMMClassifier<CoreLabel> cmm = new CMMClassifier<CoreLabel>(props); String testFile = cmm.flags.testFile; String textFile = cmm.flags.textFile; String loadPath = cmm.flags.loadClassifier; String serializeTo = cmm.flags.serializeTo; // cmm.crossValidateTrainAndTest(trainFile); if (loadPath != null) { cmm.loadClassifierNoExceptions(loadPath, props); } else if (cmm.flags.loadJarClassifier != null) { cmm.loadJarClassifier(cmm.flags.loadJarClassifier, props); } else if (cmm.flags.trainFile != null) { if (cmm.flags.biasedTrainFile != null) { cmm.trainSemiSup(); } else { cmm.train(); } } else { cmm.loadDefaultClassifier(); } if (serializeTo != null) { cmm.serializeClassifier(serializeTo); } if (testFile != null) { cmm.classifyAndWriteAnswers(testFile, cmm.makeReaderAndWriter()); } else if (cmm.flags.testFiles != null) { cmm.classifyAndWriteAnswers(cmm.flags.baseTestDir, cmm.flags.testFiles, cmm.makeReaderAndWriter()); } if (textFile != null) { DocumentReaderAndWriter<CoreLabel> readerAndWriter = new PlainTextDocumentReaderAndWriter<CoreLabel>(); cmm.classifyAndWriteAnswers(textFile, readerAndWriter); } } // end main public double weight(String feature, String label) { return ((LinearClassifier<String, String>)classifier).weight(feature, label); } public double[][] weights() { return ((LinearClassifier<String, String>)classifier).weights(); } @Override public List<IN> classifyWithGlobalInformation(List<IN> tokenSeq, final CoreMap doc, final CoreMap sent) { return classify(tokenSeq); } } // end class CMMClassifier
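// An illustrative sketch (editor's addition, not part of the original source):
// a minimal programmatic train-and-serialize round trip, assuming a properties
// file that sets trainFile and the desired feature flags. All file names here
// are hypothetical placeholders.
//
//   Properties props = StringUtils.argsToProperties(new String[]{"-prop", "ner.props"});
//   CMMClassifier<CoreLabel> cmm = new CMMClassifier<CoreLabel>(props);
//   cmm.train();                                // reads flags.trainFile
//   cmm.serializeClassifier("ner.cmm.ser.gz");
//   CMMClassifier<? extends CoreLabel> reloaded =
//       CMMClassifier.getClassifier("ner.cmm.ser.gz");  // may throw IOException etc.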