// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
// http://nlp.stanford.edu/downloads/crf-classifier.shtml
package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.ie.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations.AnswerAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.WordAnnotation;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.maxent.Convert;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.PaddedList;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
* Class for Sequence Classification using a Conditional Random Field model.
* The code has functionality for different document formats, but when
* using the standard {@link ColumnDocumentReaderAndWriter} for training
* or testing models, input files are expected to
* be one token per line with the columns indicating things like the word,
* POS, chunk, and answer class. The default for
* <code>ColumnDocumentReaderAndWriter</code> training data is 3 column input,
* with the columns containing a word, its POS, and its gold class, but
* this can be specified via the <code>map</code> property.
* <p/>
* When run on a file with <code>-textFile</code>,
* the file is assumed to be plain English text (or perhaps simple HTML/XML),
* and a reasonable attempt is made at English tokenization by
* {@link PlainTextDocumentReaderAndWriter}.
* <p/>
* <b>Typical command-line usage</b>
* <p>For running a trained model with a provided serialized classifier on a
* text file: <p>
* <code>
* java -mx500m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier
* conll.ner.gz -textFile samplesentences.txt
* </code><p>
* When specifying all parameters in a properties file (train, test, or
* runtime):<p>
* <code>
* java -mx1g edu.stanford.nlp.ie.crf.CRFClassifier -prop propFile
* </code><p>
* To train and test a simple NER model from the command line:<p>
* <code>java -mx1000m edu.stanford.nlp.ie.crf.CRFClassifier
* -trainFile trainFile -testFile testFile -macro > output </code>
* <p/>
* Features are defined by a {@link edu.stanford.nlp.sequences.FeatureFactory}.
* {@link NERFeatureFactory} is used by default, and
* you should look there for feature templates and properties or flags that
* will cause certain features to be used when training an NER classifier.
* There is also
* a {@link edu.stanford.nlp.wordseg.SighanFeatureFactory}, and various
* successors such as
* {@link edu.stanford.nlp.wordseg.ChineseSegmenterFeatureFactory},
* which are used for Chinese word segmentation.
* Features are specified either by a Properties file (which is the
* recommended method) or by flags on the command line. The flags are read
* into a {@link SeqClassifierFlags} object, which the
* user need not be concerned with, unless wishing to add new features.
* <p/>
* CRFClassifier may also be used programmatically. When creating a new
* instance, you <i>must</i> specify a Properties object. You may then
* call train methods to train a classifier, or load a classifier.
* The other way to get a CRFClassifier is to deserialize one via
* the static {@link CRFClassifier#getClassifier(String)} methods, which
* return a deserialized
* classifier. You may then tag (classify the items of) documents
* using either the assorted
* <code>classify()</code> or the assorted <code>classify</code> methods in
* {@link AbstractSequenceClassifier}.
* Probabilities assigned by the CRF can be interrogated using either the
* <code>printProbsDocument()</code> or
* <code>getCliqueTrees()</code> methods.
*
* @author Jenny Finkel
*/
public class CRFClassifier extends AbstractSequenceClassifier {
/** Label-sequence indices, one per clique size. The array has
 * <code>windowSize</code> entries; when all sequences are enumerated, entry
 * <code>i</code> holds the label sequences of length <code>i + 1</code>. */
Index<CRFLabel>[] labelIndices;
/** Parameter weights of the classifier. The first index is a feature's
 * position in <code>featureIndex</code>; the second is the position of a label
 * sequence in the <code>labelIndices</code> entry the feature is scored against. */
double[][] weights;
/** Index of the feature strings known to the model. */
Index<String> featureIndex;
/** For each feature (by its position in <code>featureIndex</code>), the 0-based
 * clique slot the feature was collected for. */
int[] map;
/** NOTE(review): presumably the optimizer used during training -- confirm;
 * the float branch of <code>train</code> declares a local variable with the same name. */
Minimizer minimizer;
/** Name of default serialized classifier resource to look for in a jar file.
*/
public static final String DEFAULT_CLASSIFIER = "ner-eng-ie.crf-3-all2008.ser.gz";
/** When true, the feature index is dumped to standard output after it is built. */
private static final boolean VERBOSE = false;
// List selftraindatums = new ArrayList();
/**
 * Creates a classifier configured with a default-constructed
 * {@link SeqClassifierFlags}.
 */
protected CRFClassifier() {
super(new SeqClassifierFlags());
}
/**
 * Creates a classifier, handing <code>props</code> to the
 * {@link AbstractSequenceClassifier} constructor.
 *
 * @param props Properties passed through unchanged to the superclass constructor
 */
public CRFClassifier(Properties props) {
super(props);
}
/**
 * Prunes the feature set, keeping only features whose weights differ across
 * labels by more than <code>threshold</code> (largest weight minus smallest
 * weight). Features with a single weight are always dropped.
 * <p/>
 * Replaces <code>featureIndex</code> and <code>map</code> with their pruned
 * versions; <code>weights</code> is left untouched.
 *
 * @param threshold Minimum spread (exclusive) a feature's weights must exceed
 *                  for the feature to be kept
 */
public void dropFeaturesBelowThreshold(double threshold) {
  Index<String> keptFeatures = new Index<String>();
  for (int f = 0; f < weights.length; f++) {
    double[] row = weights[f];
    double lo = row[0];
    double hi = row[0];
    boolean keep = false;
    // Stop scanning as soon as the spread seen so far is large enough.
    for (int col = 1; !keep && col < row.length; col++) {
      double w = row[col];
      if (w > hi) {
        hi = w;
      }
      if (w < lo) {
        lo = w;
      }
      keep = hi - lo > threshold;
    }
    if (keep) {
      keptFeatures.add(featureIndex.get(f));
    }
  }

  // Carry over the clique slot of every surviving feature.
  int keptCount = keptFeatures.size();
  int[] keptMap = new int[keptCount];
  for (int n = 0; n < keptCount; n++) {
    String feature = keptFeatures.get(n);
    keptMap[n] = map[featureIndex.indexOf(feature)];
  }
  featureIndex = keptFeatures;
  map = keptMap;
}
/**
 * Turns a single document into the integer arrays consumed by the CRF code.
 * <p/>
 * In the returned data array, the first index is the position in the document
 * (which is also the index of the clique/factor table), the second index is
 * the clique slot, and the third runs over the features found there; each
 * element is a feature's position in <code>featureIndex</code>. Features that
 * are not in <code>featureIndex</code> are silently left out. The returned
 * label array holds, per position, the <code>classIndex</code> position of
 * the token's answer.
 * <p/>
 * When <code>flags.useReverse</code> is set the document is reversed while it
 * is processed and reversed back before returning.
 *
 * @param document The document to convert
 * @return A Pair, where the first element is an int[][][] representing the data
 *         and the second element is an int[] representing the labels
 */
public Pair<int[][][],int[]> documentToDataAndLabels(List<? extends CoreLabel> document) {
  int length = document.size();
  int[][][] featureIds = new int[length][windowSize][];
  int[] goldLabels = new int[length];

  if (flags.useReverse) {
    Collections.reverse(document);
  }

  for (int pos = 0; pos < length; pos++) {
    List cliques = makeDatum(document, pos, featureFactory).asFeatures();
    int cliqueCount = cliques.size();
    for (int c = 0; c < cliqueCount; c++) {
      Collection<String> names = (Collection<String>) cliques.get(c);
      int[] known = new int[names.size()];
      int used = 0;
      for (String name : names) {
        int id = featureIndex.indexOf(name);
        // Negative ids belong to features removed by threshold cutoffs.
        if (id >= 0) {
          known[used] = id;
          used++;
        }
      }
      // Trim the array when some features were cut, to save memory.
      if (used < known.length) {
        int[] trimmed = new int[used];
        System.arraycopy(known, 0, trimmed, 0, used);
        known = trimmed;
      }
      featureIds[pos][c] = known;
    }
    goldLabels[pos] = classIndex.indexOf(document.get(pos).get(AnswerAnnotation.class));
  }

  if (flags.useReverse) {
    Collections.reverse(document);
  }

  return new Pair<int[][][],int[]>(featureIds, goldLabels);
}
/**
 * Reads the documents in <code>testFile</code> and calls
 * {@link #printLabelValue(List)} on each of them in turn.
 *
 * @param testFile Path of the file to read documents from
 * @throws Exception Declared for callers; see the methods invoked here
 */
public void printLabelInformation(String testFile) throws Exception {
  Iterator<List<CoreLabel>> docs = makeObjectBankFromFile(testFile).iterator();
  while (docs.hasNext()) {
    printLabelValue(docs.next());
  }
}
/**
 * Prints to standard output, for every token of <code>document</code>, ASCII
 * tables of the model weights attached to the known features firing at and
 * after that token. Each table row is one feature (suffixed with
 * <code>[-p]</code>, where <code>p</code> is the distance to the position the
 * feature was extracted at) and each column is one class from
 * <code>classIndex</code>.
 * <p/>
 * When <code>flags.useReverse</code> is set the document is reversed while it
 * is processed and reversed back before returning.
 *
 * @param document The document whose feature weights are printed
 */
public void printLabelValue(List<CoreLabel> document) {
// Work in model order; the original order is restored at the end.
if (flags.useReverse) {
Collections.reverse(document);
}
NumberFormat nf = new DecimalFormat();
// Column headers are the class names, in classIndex order.
List<String> classes = new ArrayList<String>();
for (int i = 0; i < classIndex.size(); i++) {
classes.add(classIndex.get(i));
}
String[] columnHeaders = classes.toArray(new String[classes.size()]);
//System.err.println("docSize:"+docSize);
for (int j = 0; j < document.size(); j++) {
System.out.println("--== "+document.get(j).get(WordAnnotation.class)+" ==--");
// lines and rowHeaders are created once per token and are not cleared between
// values of p, so every table printed below also repeats the rows collected
// for smaller p.
List<String[]> lines = new ArrayList<String[]>();
List<String> rowHeaders = new ArrayList<String>();
List<String> line = new ArrayList<String>();
// p is how far past token j the datum is built; positions beyond the end of
// the document are skipped.
for (int p = 0; p < labelIndices.length; p++) {
if (j+p >= document.size()) { continue; }
CRFDatum d = makeDatum(document, j+p, featureFactory);
List features = d.asFeatures();
// Only clique slots k >= p are visited; presumably so that l[l.length-1-p]
// below stays in range (label sequences in labelIndices[k] having k+1
// elements) -- TODO confirm.
for (int k = p, fSize = features.size(); k < fSize; k++) {
Collection<String> cliqueFeatures = (Collection<String>) features.get(k);
for (String feature : cliqueFeatures) {
int index = featureIndex.indexOf(feature);
// Features missing from featureIndex (negative index) are skipped.
if (index >= 0) {
// line.add(feature+"["+(-p)+"]");
rowHeaders.add(feature+ '[' + (-p) + ']');
// One bucket per entry of labelIndices[0].
double[] values = new double[labelIndices[0].size()];
// Sum the weight of every label sequence into the bucket of the class that
// the sequence assigns p places before its last element.
for (CRFLabel label : labelIndices[k]) {
int[] l = label.getLabel();
double v = weights[index][labelIndices[k].indexOf(label)];
values[l[l.length-1-p]] += v;
}
for (double value : values) {
line.add(nf.format(value));
}
lines.add(line.toArray(new String[line.size()]));
line = new ArrayList<String>();
}
}
}
// lines.add(Collections.<String>emptyList());
// One table is printed per value of p, followed by a blank line.
System.out.println(StringUtils.makeAsciiTable(lines.toArray(new String[lines.size()][0]),
rowHeaders.toArray(new String[rowHeaders.size()]),
columnHeaders,
0, 1, true));
System.out.println();
}
// System.err.println(edu.stanford.nlp.util.StringUtils.join(lines,"\n"));
}
// Undo the reversal done at the top.
if (flags.useReverse) {
Collections.reverse(document);
}
}
/**
 * Converts every document of an ObjectBank into feature and label arrays by
 * calling {@link #documentToDataAndLabels(List)} on each one, reports the
 * resulting sizes on standard error and then calls <code>printFeatures()</code>.
 * <p/>
 * In the returned data array the first index is the number of the document,
 * the second is the position in the document (also the index of the
 * clique/factor table), the third is the clique slot and the fourth runs over
 * the features found there. In the returned label array the first index is
 * the number of the document and the second the position in it.
 *
 * @param documents The documents to convert
 * @return A Pair, where the first element is an int[][][][] representing the data
 *         and the second element is an int[][] representing the labels.
 */
public Pair<int[][][][],int[][]> documentsToDataAndLabels(ObjectBank<List<CoreLabel>> documents) {
  List<int[][][]> dataPerDoc = new ArrayList<int[][][]>();
  List<int[]> labelsPerDoc = new ArrayList<int[]>();
  int datumCount = 0;

  for (List<CoreLabel> doc : documents) {
    Pair<int[][][],int[]> converted = documentToDataAndLabels(doc);
    dataPerDoc.add(converted.first());
    labelsPerDoc.add(converted.second());
    datumCount += doc.size();
  }

  System.err.println("numClasses: " + classIndex.size() + " " + classIndex);
  System.err.println("numDocuments: " + dataPerDoc.size());
  System.err.println("numDatums: " + datumCount);
  System.err.println("numFeatures: " + featureIndex.size());
  printFeatures();

  int[][][][] dataArray = dataPerDoc.toArray(new int[0][][][]);
  int[][] labelArray = labelsPerDoc.toArray(new int[0][]);
  return new Pair<int[][][][],int[][]>(dataArray, labelArray);
}
/**
 * Writes every entry of <code>featureIndex</code>, one per line, to the file
 * <code>feats-&lt;flags.printFeatures&gt;.txt</code>. Does nothing when
 * <code>flags.printFeatures</code> is null.
 * <p/>
 * The file is written in <code>flags.inputEncoding</code>, falling back to
 * UTF-8 (with a message on standard error) when that flag is null. I/O
 * failures while opening the file are reported via a stack trace and are not
 * propagated. The output file is always closed, even when writing fails
 * part-way through.
 */
private void printFeatures() {
  if (flags.printFeatures == null) {
    return;
  }
  String enc = flags.inputEncoding;
  if (enc == null) {
    System.err.println("flags.inputEncoding doesn't exist, Use UTF-8 as default");
    enc = "UTF-8";
  }
  OutputStream out = null;
  PrintWriter pw = null;
  try {
    out = new FileOutputStream("feats-" + flags.printFeatures + ".txt");
    // May throw UnsupportedEncodingException after the file has been opened.
    pw = new PrintWriter(new OutputStreamWriter(out, enc), true);
    for (int i = 0; i < featureIndex.size(); i++) {
      pw.println(featureIndex.get(i));
    }
  } catch (IOException ioe) {
    ioe.printStackTrace();
  } finally {
    if (pw != null) {
      // Closing the writer also closes the underlying file stream.
      pw.close();
    } else if (out != null) {
      // The writer was never created, so the raw stream must be closed directly.
      try {
        out.close();
      } catch (IOException ignored) {
        // Nothing was written; a failure to close cannot lose data.
      }
    }
  }
}
/** This routine builds the <code>labelIndices</code> which give the
 * empirically legal label sequences (of length (order) at most
 * <code>windowSize</code>)
 * and the <code>classIndex</code>,
 * which indexes known answer classes. It also builds
 * <code>featureIndex</code>, the index of all features kept for the model,
 * and <code>map</code>, which records the clique slot of each feature.
 * <p/>
 * When <code>flags.removeBackgroundSingletonFeatures</code> is set, a feature
 * of clique slot 0 or 1 that is only ever seen in a background context is
 * kept only once it has been seen a second time.
 *
 * @param ob The training data: Read from an ObjectBank, each
 * item in it is a List&lt;CoreLabel&gt;.
 */
private void makeAnswerArraysAndTagIndex(ObjectBank<List<CoreLabel>> ob) {
// featureIndices[i] collects the feature strings accepted for clique slot i.
HashSet<String>[] featureIndices = new HashSet[windowSize];
for (int i = 0; i < windowSize; i++) {
featureIndices[i] = new HashSet<String>();
}
labelIndices = new Index[windowSize];
for (int i = 0; i < labelIndices.length; i++) {
labelIndices[i] = new Index<CRFLabel>();
}
// Alias for the index of the longest label sequences (slot windowSize - 1).
Index<CRFLabel> labelIndex = labelIndices[windowSize - 1];
classIndex = new Index<String>();
//classIndex.add("O");
// The background symbol is added first, so it is the first class in the index.
classIndex.add(flags.backgroundSymbol);
// For clique slots 0 and 1: features seen so far exactly once, and only in a
// background context, that are still waiting for a second sighting.
HashSet[] seenBackgroundFeatures = new HashSet[2];
seenBackgroundFeatures[0] = new HashSet();
seenBackgroundFeatures[1] = new HashSet();
//int count = 0;
for (List<CoreLabel> doc : ob) {
//if (count % 100 == 0) {
//System.err.println(count);
//}
//count++;
// Work in model order; the original order is restored at the end of the loop body.
if (flags.useReverse) {
Collections.reverse(doc);
}
int docSize = doc.size();
//create the full set of labels in classIndex
//note: update to use addAll later
// This pass must finish before any datum is built, because makeDatum looks
// the answers of the whole window up in classIndex.
for (int j = 0; j < docSize; j++) {
String ans = doc.get(j).get(AnswerAnnotation.class);
classIndex.add(ans);
}
for (int j = 0; j < docSize; j++) {
CRFDatum<Serializable,CRFLabel> d = makeDatum(doc, j, featureFactory);
// Record the full-window label sequence observed at this position.
labelIndex.add(d.label());
List<Serializable> features = d.asFeatures();
for (int k = 0, fsize = features.size(); k < fsize; k++) {
Collection<String> cliqueFeatures = (Collection<String>) features.get(k);
if (k < 2 && flags.removeBackgroundSingletonFeatures) {
// The context is "background" when the answer at j is the background
// symbol and, for slot 1 at j > 0, the answer at j - 1 is as well.
String ans = doc.get(j).get(AnswerAnnotation.class);
boolean background = ans.equals(flags.backgroundSymbol);
if (k == 1 && j > 0 && background) {
ans = doc.get(j - 1).get(AnswerAnnotation.class);
background = ans.equals(flags.backgroundSymbol);
}
if (background) {
for (String f : cliqueFeatures) {
if (!featureIndices[k].contains(f)) {
if (seenBackgroundFeatures[k].contains(f)) {
// Second sighting: promote the feature to the accepted set.
seenBackgroundFeatures[k].remove(f);
featureIndices[k].add(f);
} else {
// First sighting: remember it, but do not accept it yet.
seenBackgroundFeatures[k].add(f);
}
}
}
} else {
// Seen in a non-background context: accept immediately.
seenBackgroundFeatures[k].removeAll(cliqueFeatures);
featureIndices[k].addAll(cliqueFeatures);
}
} else {
featureIndices[k].addAll(cliqueFeatures);
}
}
}
if (flags.useReverse) {
Collections.reverse(doc);
}
}
// String[] fs = new String[featureIndices[0].size()];
// for (Iterator iter = featureIndices[0].iterator(); iter.hasNext(); ) {
// System.err.println(iter.next());
// }
int numFeatures = 0;
for (int i = 0; i < windowSize; i++) {
numFeatures += featureIndices[i].size();
}
// Merge the per-slot feature sets into one index and record, for every
// feature, the clique slot it came from.
featureIndex = new Index<String>();
map = new int[numFeatures];
for (int i = 0; i < windowSize; i++) {
featureIndex.addAll(featureIndices[i]);
for (String str : featureIndices[i]) {
map[featureIndex.indexOf(str)] = i;
}
}
if (flags.useObservedSequencesOnly) {
// Only label sequences seen in the data are legal: derive the entries of the
// smaller slots by repeatedly shortening each observed full-window sequence.
for (int i = 0, liSize = labelIndex.size(); i < liSize; i++) {
CRFLabel label = labelIndex.get(i);
for (int j = windowSize - 2; j >= 0; j--) {
label = label.getOneSmallerLabel();
labelIndices[j].add(label);
}
}
} else {
// Otherwise every combination of classes is legal; this replaces all slots,
// including the observed full-window sequences collected above.
for (int i = 0; i < labelIndices.length; i++) {
labelIndices[i] = allLabels(i + 1, classIndex);
}
}
if (VERBOSE) {
for (int i = 0, fiSize = featureIndex.size(); i < fiSize; i++) {
System.out.println(i + ": " + featureIndex.get(i));
}
}
}
/**
 * Enumerates every label sequence of length <code>window</code> over the
 * classes of <code>classIndex</code>. Sequences are generated like an odometer
 * whose first element changes fastest, starting from the all-zero sequence.
 *
 * @param window Length of the label sequences to generate
 * @param classIndex Index whose size gives the number of classes
 * @return An Index holding one CRFLabel per generated sequence
 */
protected static Index<CRFLabel> allLabels(int window, Index classIndex) {
  // A new int[] is zero-filled, so this is the all-zero starting sequence.
  int[] current = new int[window];
  int numClasses = classIndex.size();
  Index<CRFLabel> sequences = new Index<CRFLabel>();
  boolean wrapped = false;
  while (!wrapped) {
    sequences.add(new CRFLabel(current));
    // Advance a private copy so the array handed to CRFLabel is never modified.
    current = current.clone();
    boolean carry = true;
    for (int digit = 0; carry && digit < current.length; digit++) {
      current[digit]++;
      if (current[digit] >= numClasses) {
        current[digit] = 0;
        // Overflow of the last element means every sequence has been produced.
        wrapped = (digit == current.length - 1);
      } else {
        carry = false;
      }
    }
  }
  return sequences;
}
/**
 * Builds the CRFDatum for one position of a document: the features of each
 * clique slot, as produced by the given factory, plus the label sequence of
 * the window that ends at that position. Positions outside the document are
 * padded with <code>pad</code>, whose answer is set to the background symbol.
 *
 * @param info The input data
 * @param loc The position to build a datum at
 * @param featureFactory The FeatureFactory to use to extract features
 * @return The constructed CRFDatum
 */
public CRFDatum<Serializable,CRFLabel> makeDatum(List<? extends CoreLabel> info, int loc, edu.stanford.nlp.sequences.FeatureFactory featureFactory) {
  pad.set(AnswerAnnotation.class, flags.backgroundSymbol);
  PaddedList<? extends CoreLabel> padded = new PaddedList<CoreLabel>((List<CoreLabel>)info, pad);

  // One feature list per clique slot; a clique already used for a smaller
  // slot is not used again.
  ArrayList featuresBySlot = new ArrayList();
  Collection<Clique> alreadyUsed = new HashSet<Clique>();
  for (int slot = 0; slot < windowSize; slot++) {
    Collection<Clique> fresh = featureFactory.getCliques(slot, 0);
    fresh.removeAll(alreadyUsed);
    alreadyUsed.addAll(fresh);
    List slotFeatures = new ArrayList();
    for (Clique clique : fresh) {
      slotFeatures.addAll(featureFactory.getCliqueFeatures(padded, loc, clique));
    }
    featuresBySlot.add(slotFeatures);
  }

  // Class indices of the windowSize answers ending at loc.
  int[] windowLabels = new int[windowSize];
  int windowStart = loc - windowSize + 1;
  for (int i = 0; i < windowSize; i++) {
    String answer = padded.get(windowStart + i).get(AnswerAnnotation.class);
    windowLabels[i] = classIndex.indexOf(answer);
  }

  return new CRFDatum<Serializable,CRFLabel>(featuresBySlot, new CRFLabel(windowLabels));
}
/**
 * A SequenceModel view of a calibrated CRFCliqueTree, scoring a tag at a
 * position by its conditional log probability given the preceding
 * <code>window - 1</code> tags. Sequence positions are offset by
 * <code>window - 1</code> relative to clique tree positions; the leading
 * positions may only take the background tag.
 */
public static class TestSequenceModel implements SequenceModel {

  private final int window;
  private final int numClasses;
  private final CRFCliqueTree cliqueTree;
  private final int[] tags;
  private final int[] backgroundTag;

  public TestSequenceModel(CRFCliqueTree cliqueTree) {
    this.cliqueTree = cliqueTree;
    this.window = cliqueTree.window();
    this.numClasses = cliqueTree.getNumClasses();
    // Every class index is a possible tag.
    int[] everyTag = new int[numClasses];
    for (int t = 0; t < everyTag.length; t++) {
      everyTag[t] = t;
    }
    this.tags = everyTag;
    this.backgroundTag = new int[]{cliqueTree.backgroundIndex()};
  }

  public int length() {
    return cliqueTree.length();
  }

  public int leftWindow() {
    return window - 1;
  }

  public int rightWindow() {
    return 0;
  }

  public int[] getPossibleValues(int pos) {
    return (pos < window - 1) ? backgroundTag : tags;
  }

  public double scoreOf(int[] tags, int pos) {
    int realPos = pos - window + 1;
    int[] previous = precedingTags(tags, realPos);
    return cliqueTree.condLogProbGivenPrevious(realPos, tags[pos], previous);
  }

  public double[] scoresOf(int[] tags, int pos) {
    int realPos = pos - window + 1;
    double[] scores = new double[numClasses];
    int[] previous = precedingTags(tags, realPos);
    for (int candidate = 0; candidate < numClasses; candidate++) {
      scores[candidate] = cliqueTree.condLogProbGivenPrevious(realPos, candidate, previous);
    }
    return scores;
  }

  public double scoreOf(int[] sequence) {
    throw new UnsupportedOperationException();
  }

  /** Copies the <code>window - 1</code> tags starting at <code>realPos</code>. */
  private int[] precedingTags(int[] tags, int realPos) {
    int[] previous = new int[window - 1];
    for (int i = 0; i < window - 1; i++) {
      previous[i] = tags[realPos + i];
    }
    return previous;
  }

} // end class TestSequenceModel
/**
 * Classifies a document with the inference method selected by the flags:
 * Gibbs sampling when <code>flags.doGibbs</code> is set, otherwise
 * {@link #classifyMaxEnt(List)} when <code>flags.crfType</code> is "maxent"
 * (ignoring case).
 *
 * @param document The document to classify
 * @return The classified document, or null if Gibbs inference failed with an
 *         exception (which is reported on standard error)
 * @throws RuntimeException if Gibbs is not requested and
 *         <code>flags.crfType</code> is not "maxent"
 */
@Override
public List<CoreLabel> classify(List<CoreLabel> document) {
  if (!flags.doGibbs) {
    if (flags.crfType.equalsIgnoreCase("maxent")) {
      return classifyMaxEnt(document);
    }
    throw new RuntimeException("Unsupported inference type: " + flags.crfType);
  }
  try {
    return classifyGibbs(document);
  } catch (Exception e) {
    System.err.println("Error running testGibbs inference!");
    e.printStackTrace();
    return null;
  }
}
/**
 * Builds the sequence model used for inference over a document: the document
 * is converted to feature arrays, a calibrated clique tree is computed from
 * them and the current weights, and the tree is wrapped in a
 * {@link TestSequenceModel}.
 *
 * @param doc The document to build a model for
 * @return A SequenceModel backed by the document's calibrated clique tree
 */
@Override
public SequenceModel getSequenceModel(List<? extends CoreLabel> doc) {
  int[][][] featureData = documentToDataAndLabels(doc).first();
  CRFCliqueTree tree = CRFCliqueTree.getCalibratedCliqueTree(weights, featureData, labelIndices, classIndex.size(), classIndex, flags.backgroundSymbol);
  return new TestSequenceModel(tree);
}
/** Runs standard sequence inference over a document, with either Viterbi or
 * Beam search as selected by <code>flags.inferenceType</code>. A null
 * <code>flags.inferenceType</code> is set to "Viterbi".
 *
 * @param document Document to classify. Classification happens in place:
 *                 the answer of every token of this document is overwritten.
 * @return The classified document (the same list that was passed in)
 * @throws RuntimeException if <code>flags.inferenceType</code> is neither
 *         "Viterbi" nor "Beam" (ignoring case)
 */
public List<CoreLabel> classifyMaxEnt(List<CoreLabel> document) {
  if (document.isEmpty()) {
    return document;
  }

  SequenceModel model = getSequenceModel(document);

  if (flags.inferenceType == null) {
    flags.inferenceType = "Viterbi";
  }
  String inference = flags.inferenceType;

  BestSequenceFinder finder;
  if (inference.equalsIgnoreCase("Viterbi")) {
    finder = new ExactBestSequenceFinder();
  } else if (inference.equalsIgnoreCase("Beam")) {
    finder = new BeamBestSequenceFinder(flags.beamSize);
  } else {
    throw new RuntimeException("Unknown inference type: "+inference+". Your options are Viterbi|Beam.");
  }
  int[] best = finder.bestSequence(model);

  // The sequence is in model order, so label the tokens in that order too.
  if (flags.useReverse) {
    Collections.reverse(document);
  }
  // The first windowSize - 1 entries of the sequence precede the document.
  int offset = windowSize - 1;
  int tokenCount = document.size();
  for (int pos = 0; pos < tokenCount; pos++) {
    CoreLabel token = document.get(pos);
    token.set(AnswerAnnotation.class, classIndex.get(best[pos + offset]));
  }
  if (flags.useReverse) {
    Collections.reverse(document);
  }

  return document;
}
/**
 * Classifies a document by annealed Gibbs sampling over the CRF combined with
 * one of the priors selected by the flags (NER, acquisitions or seminars).
 * The sampler starts from the Viterbi sequence when
 * <code>flags.initViterbi</code> is set and from a random sequence otherwise.
 * Classification happens in place: the answer of every token is overwritten.
 *
 * @param document The document to classify
 * @return The classified document (the same list that was passed in)
 * @throws RuntimeException if no prior is selected, or if
 *         <code>flags.annealingType</code> is not one of "linear", "exp" or
 *         "exponential" (ignoring case)
 */
public List<CoreLabel> classifyGibbs(List<CoreLabel> document) throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException, InstantiationException, IllegalAccessException, InvocationTargetException {
  System.err.println("Testing using Gibbs sampling.");
  int[][][] featureData = documentToDataAndLabels(document).first();

  // The priors are given the tokens in model order: a reversed copy when
  // reversing is on, otherwise the document itself.
  List<CoreLabel> modelOrderDoc;
  if (flags.useReverse) {
    Collections.reverse(document);
    modelOrderDoc = new ArrayList<CoreLabel>(document);
    Collections.reverse(document);
  } else {
    modelOrderDoc = document;
  }

  CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(weights, featureData, labelIndices, classIndex.size(), classIndex, flags.backgroundSymbol);
  SequenceModel treeModel = cliqueTree;
  SequenceListener treeListener = cliqueTree;

  EntityCachingAbstractSequencePrior prior;
  if (flags.useNERPrior) {
    prior = new EmpiricalNERPrior(flags.backgroundSymbol, classIndex, modelOrderDoc);
  } else if (flags.useAcqPrior) {
    prior = new AcquisitionsPrior(flags.backgroundSymbol, classIndex, modelOrderDoc);
  } else if (flags.useSemPrior) {
    prior = new SeminarsPrior(flags.backgroundSymbol, classIndex, modelOrderDoc);
  } else {
    throw new RuntimeException("no prior specified");
  }

  // Combine the CRF with the prior, both for scoring and for update notifications.
  SequenceModel model = new FactoredSequenceModel(treeModel, prior);
  SequenceListener listener = new FactoredSequenceListener(treeListener, prior);
  SequenceGibbsSampler sampler = new SequenceGibbsSampler(0, 0, listener);

  int[] sequence = new int[cliqueTree.length()];
  if (flags.initViterbi) {
    TestSequenceModel viterbiModel = new TestSequenceModel(cliqueTree);
    ExactBestSequenceFinder viterbi = new ExactBestSequenceFinder();
    int[] viterbiSequence = viterbi.bestSequence(viterbiModel);
    // Skip the windowSize - 1 leading entries that precede the document.
    System.arraycopy(viterbiSequence, windowSize-1, sequence, 0, sequence.length);
  } else {
    int[] randomSequence = SequenceGibbsSampler.getRandomSequence(model);
    System.arraycopy(randomSequence, 0, sequence, 0, sequence.length);
  }

  sampler.verbose = 0;

  String annealing = flags.annealingType;
  if (annealing.equalsIgnoreCase("linear")) {
    sequence = sampler.findBestUsingAnnealing(model, CoolingSchedule.getLinearSchedule(1.0, flags.numSamples), sequence);
  } else if (annealing.equalsIgnoreCase("exp") || annealing.equalsIgnoreCase("exponential")) {
    sequence = sampler.findBestUsingAnnealing(model, CoolingSchedule.getExponentialSchedule(1.0, flags.annealingRate, flags.numSamples), sequence);
  } else {
    throw new RuntimeException("No annealing type specified");
  }

  // The sequence is in model order, so label the tokens in that order too.
  if (flags.useReverse) {
    Collections.reverse(document);
  }
  int tokenCount = modelOrderDoc.size();
  for (int pos = 0; pos < tokenCount; pos++) {
    CoreLabel token = document.get(pos);
    if (token == null) {
      throw new RuntimeException("");
    }
    if (classIndex == null) {
      throw new RuntimeException("");
    }
    token.set(AnswerAnnotation.class, classIndex.get(sequence[pos]));
  }
  if (flags.useReverse) {
    Collections.reverse(document);
  }

  return document;
}
/**
 * Prints to standard output, one line per position of the document's clique
 * tree, the token's word followed by tab-separated <code>label=prob</code>
 * pairs giving the probability the CRF assigns to each class at that position.
 * <p/>
 * NOTE(review): the feature data is computed on the reversed document when
 * <code>flags.useReverse</code> is set, but the words are read from the
 * document in its original order, so the two look misaligned in that mode --
 * confirm.
 *
 * @param document A {@link List} of {@link CoreLabel}s.
 */
@Override
public void printProbsDocument(List<CoreLabel> document) {
  int[][][] featureData = documentToDataAndLabels(document).first();
  CRFCliqueTree tree = CRFCliqueTree.getCalibratedCliqueTree(weights, featureData, labelIndices, classIndex.size(), classIndex, flags.backgroundSymbol);

  for (int pos = 0; pos < tree.length(); pos++) {
    System.out.print(document.get(pos).word() + "\t");
    Iterator<String> labels = classIndex.iterator();
    while (labels.hasNext()) {
      String label = labels.next();
      double prob = tree.prob(pos, classIndex.indexOf(label));
      System.out.print(label + "=" + prob);
      // Tab between entries, newline after the last one.
      System.out.print(labels.hasNext() ? "\t" : "\n");
    }
  }
}
/**
 * Reads the documents in the given file and prints, for each of them, the
 * first-order label probabilities produced by
 * {@link #printFirstOrderProbsDocuments(ObjectBank)}. This gives a simple
 * way to examine the probability distributions of the CRF. See
 * <code>getCliqueTrees()</code> for more.
 * <p/>
 * As a side effect, <code>flags.ocrTrain</code> is set to false.
 *
 * @param filename The path to the specified file
 */
public void printFirstOrderProbs(String filename) {
  // only for the OCR data does this matter
  flags.ocrTrain = false;
  printFirstOrderProbsDocuments(makeObjectBankFromFile(filename));
}
/**
 * Calls {@link #printFirstOrderProbsDocument(List)} on every document of the
 * bank, printing an empty line on standard output after each document.
 *
 * @param documents An ObjectBank whose items are {@link List}s of {@link CoreLabel}s.
 */
public void printFirstOrderProbsDocuments(ObjectBank<List<CoreLabel>> documents) {
  Iterator<List<CoreLabel>> docs = documents.iterator();
  while (docs.hasNext()) {
    printFirstOrderProbsDocument(docs.next());
    System.out.println();
  }
}
/**
 * Want to make arbitrary probability queries? Then this is the method for you.
 * Reads the given file, breaks it into documents and builds a calibrated
 * CRFCliqueTree for each document. You can then ask the clique trees for
 * marginals and conditional probabilities of almost anything you want.
 * <p/>
 * As a side effect, <code>flags.ocrTrain</code> is set to false.
 *
 * @param filename The path of the file to read documents from
 * @return One clique tree per document, in the order the documents were read
 */
public List<CRFCliqueTree> getCliqueTrees(String filename) {
  // only for the OCR data does this matter
  flags.ocrTrain = false;
  List<CRFCliqueTree> trees = new ArrayList<CRFCliqueTree>();
  Iterator<List<CoreLabel>> docs = makeObjectBankFromFile(filename).iterator();
  while (docs.hasNext()) {
    trees.add(getCliqueTree(docs.next()));
  }
  return trees;
}
/**
 * Converts a document to feature arrays and computes its calibrated clique
 * tree from them and the current weights.
 *
 * @param document The document to build a clique tree for
 * @return The calibrated clique tree of the document
 */
private CRFCliqueTree getCliqueTree(List<CoreLabel> document) {
  int[][][] featureData = documentToDataAndLabels(document).first();
  int numClasses = classIndex.size();
  return CRFCliqueTree.getCalibratedCliqueTree(weights, featureData, labelIndices, numClasses, classIndex, flags.backgroundSymbol);
}
/**
 * Prints to standard output, one line per position of the document's clique
 * tree, the token's word followed by tab-separated probabilities. At the
 * first position these are <code>label=prob</code> entries for each single
 * class; at every later position they are <code>label1_label=prob</code>
 * entries for each pair of classes, as returned by the clique tree for the
 * label pair <code>{label1, label}</code> at that position.
 *
 * @param document A {@link List} of {@link CoreLabel}s.
 */
public void printFirstOrderProbsDocument(List<CoreLabel> document) {
  CRFCliqueTree tree = getCliqueTree(document);

  for (int pos = 0; pos < tree.length(); pos++) {
    System.out.print(document.get(pos).word() + "\t");
    Iterator<String> current = classIndex.iterator();
    while (current.hasNext()) {
      String label = current.next();
      int labelId = classIndex.indexOf(label);

      if (pos == 0) {
        // No preceding label at the first position: print single-label probabilities.
        double prob = tree.prob(pos, labelId);
        System.out.print(label + "=" + prob);
        System.out.print(current.hasNext() ? "\t" : "\n");
        continue;
      }

      Iterator<String> preceding = classIndex.iterator();
      while (preceding.hasNext()) {
        String precedingLabel = preceding.next();
        int precedingId = classIndex.indexOf(precedingLabel);
        double prob = tree.prob(pos, new int[]{precedingId, labelId});
        System.out.print(precedingLabel + "_" + label + "=" + prob);
        // The line ends only after the last pair of the last label.
        boolean moreEntries = current.hasNext() || preceding.hasNext();
        System.out.print(moreEntries ? "\t" : "\n");
      }
    }
  }
}
/**
 * Trains this classifier from documents.
 * <p>
 * Training runs {@code flags.numTimesPruneFeatures + 1} passes. Each pass converts the
 * documents to feature/label arrays, optionally appends previously processed data named by
 * {@code flags.loadProcessedData}, minimizes the CRF log conditional objective (float or
 * double variant, chosen by {@code flags.useFloat}) and stores the resulting weights in
 * {@code this.weights}. After every pass except the last, features are dropped via
 * {@code dropFeaturesBelowThreshold(flags.featureDiffThresh)} before retraining.
 *
 * @param docs An objectbank representation of documents.
 * @throws RuntimeException if the temporary feature index file or the initial weights file
 *     cannot be written or read; the underlying exception is attached as the cause
 */
@Override
public void train(ObjectBank<List<CoreLabel>> docs) {
  makeAnswerArraysAndTagIndex(docs);
  for (int i = 0; i <= flags.numTimesPruneFeatures; i++) {
    Pair dataAndLabels = documentsToDataAndLabels(docs);
    if (flags.numTimesPruneFeatures == i) {
      docs = null; // hopefully saves memory
    }
    // save feature index to disk and read in later
    File featIndexFile = null;
    if (flags.saveFeatureIndexToDisk) {
      try {
        System.err.println("Writing feature index to temporary file.");
        featIndexFile = IOUtils.writeObjectToTempFile(featureIndex, "featIndex" + i+ ".tmp");
        featureIndex = null;
      } catch (IOException e) {
        throw new RuntimeException("Could not open temporary feature index file for writing.", e);
      }
    }
    // first index is the number of the document
    // second index is position in the document also the index of the clique/factor table
    // third index is the number of elements in the clique/window these features are for (starting with last element)
    // fourth index is position of the feature in the array that holds them
    // element in data[i][j][k][m] is the index of the mth feature occurring in position k of the jth clique of the ith document
    int[][][][] data = (int[][][][]) dataAndLabels.first();
    // first index is the number of the document
    // second index is the position in the document
    // element in labels[i][j] is the index of the correct label (if it exists) at position j in document i
    int[][] labels = (int[][]) dataAndLabels.second();
    if (flags.loadProcessedData != null) {
      List processedData = loadProcessedData(flags.loadProcessedData);
      if (processedData != null) {
        // enlarge the data and labels array
        int[][][][] allData = new int[data.length + processedData.size()][][][];
        int[][] allLabels = new int[labels.length + processedData.size()][];
        System.arraycopy(data, 0, allData, 0, data.length);
        System.arraycopy(labels, 0, allLabels, 0, labels.length);
        // add to the data and labels array
        addProcessedData(processedData, allData, allLabels, data.length);
        data = allData;
        labels = allLabels;
      }
    }
    if (flags.useFloat) {
      CRFLogConditionalObjectiveFloatFunction func = new CRFLogConditionalObjectiveFloatFunction(data, labels, featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma);
      func.crfType = flags.crfType;
      // NOTE: this local deliberately shadows the minimizer field; the float path always uses QN.
      QNMinimizer minimizer;
      if (flags.interimOutputFreq != 0) {
        FloatFunction monitor = new ResultStoringFloatMonitor(flags.interimOutputFreq, flags.serializeTo);
        minimizer = new QNMinimizer(monitor);
      } else {
        minimizer = new QNMinimizer();
      }
      // the first pass and the retraining passes use separately configured QN memory sizes
      if (i == 0) {
        minimizer.setM(flags.QNsize);
      } else {
        minimizer.setM(flags.QNsize2);
      }
      float[] initialWeights;
      if (flags.initialWeights == null) {
        initialWeights = func.initial();
      } else {
        InputStream rawIn = null;
        DataInputStream dis = null;
        try {
          System.err.println("Reading initial weights from file " + flags.initialWeights);
          rawIn = new FileInputStream(flags.initialWeights);
          dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(rawIn)));
          initialWeights = Convert.readFloatArr(dis);
        } catch (IOException e) {
          throw new RuntimeException("Could not read from float initial weight file " + flags.initialWeights, e);
        } finally {
          // close the outermost stream that was successfully created
          closeWeightStreamQuietly(dis != null ? dis : rawIn);
        }
      }
      System.err.println("numWeights: " + initialWeights.length);
      float[] weights = minimizer.minimize(func, (float) flags.tolerance, initialWeights);
      this.weights = ArrayMath.floatArrayToDoubleArray(func.to2D(weights));
    } else {
      // Disabled experiment, kept for reference: estimate initial weights on a 1% sample.
      /*double[] estimate = null;
      if(flags.estimateInitial){
      int[][][][] approxData = new int[data.length/100][][][];
      int[][] approxLabels = new int[data.length/100][];
      Random generator = new Random(1);
      for(int k=0;k<approxData.length; k++){
      int thisInd = generator.nextInt(data.length);
      approxData[k] = data[ thisInd];
      approxLabels[k] = labels[ thisInd];
      }
      CRFLogConditionalObjectiveFunction approxFunc = new CRFLogConditionalObjectiveFunction(approxData, approxLabels, featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma);
      approxFunc.crfType = flags.crfType;
      minimizer = new QNMinimizer(10);
      if (flags.initialWeights == null) {
      estimate = approxFunc.initial();
      } else {
      try {
      System.err.println("Reading initial weights from file " + flags.initialWeights);
      DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(new FileInputStream(flags.initialWeights))));
      estimate = Convert.readDoubleArr(dis);
      } catch (IOException e) {
      throw new RuntimeException("Could not read from double initial weight file " + flags.initialWeights);
      }
      }
      estimate = minimizer.minimize(approxFunc, 1e-2, estimate);
      }
      */
      CRFLogConditionalObjectiveFunction func = new CRFLogConditionalObjectiveFunction(data, labels, featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma);
      func.crfType = flags.crfType;
      minimizer = getMinimizer(i);
      double[] initialWeights;
      if (flags.initialWeights == null) {
        initialWeights = func.initial();
      } else {
        InputStream rawIn = null;
        DataInputStream dis = null;
        try {
          System.err.println("Reading initial weights from file " + flags.initialWeights);
          rawIn = new FileInputStream(flags.initialWeights);
          dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(rawIn)));
          initialWeights = Convert.readDoubleArr(dis);
        } catch (IOException e) {
          throw new RuntimeException("Could not read from double initial weight file " + flags.initialWeights, e);
        } finally {
          // close the outermost stream that was successfully created
          closeWeightStreamQuietly(dis != null ? dis : rawIn);
        }
      }
      System.err.println("numWeights: " + initialWeights.length);
      if (flags.testObjFunction) {
        // Diagnostic mode: the JVM exits with status 1 whether or not the test passes.
        StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
        if(tester.testSumOfBatches(initialWeights,1e-4)){
          System.err.println("Testing complete... exiting");
          System.exit(1);
        } else {
          System.err.println("Testing failed....exiting");
          System.exit(1);
        }
      }
      double[] weights = minimizer.minimize(func, flags.tolerance, initialWeights);
      this.weights = func.to2D(weights);
    }
    // restore the feature index that was parked on disk at the top of this pass
    if (flags.saveFeatureIndexToDisk) {
      try {
        System.err.println("Reading temporary feature index file.");
        featureIndex = (Index<String>) IOUtils.readObjectFromFile(featIndexFile);
      } catch (Exception e) {
        throw new RuntimeException("Could not open temporary feature index file for reading.", e);
      }
    }
    if (i != flags.numTimesPruneFeatures) {
      dropFeaturesBelowThreshold(flags.featureDiffThresh);
      System.err.println("Removing features with weight below " + flags.featureDiffThresh + " and retraining...");
    }
  }
}

/** Closes an initial-weights stream, ignoring close failures; a null stream is a no-op. */
private static void closeWeightStreamQuietly(Closeable stream) {
  if (stream == null) {
    return;
  }
  try {
    stream.close();
  } catch (IOException ignored) {
    // The weights have already been read (or reading already failed); nothing more to report.
  }
}
/**
 * Returns the minimizer for feature-prune iteration 0, i.e. the first training pass.
 *
 * @return whatever {@code getMinimizer(int)} yields for iteration 0
 */
protected Minimizer getMinimizer() {
  final int firstPruneIteration = 0;
  return getMinimizer(firstPruneIteration);
}
/**
 * Chooses a minimizer according to the optimizer flags and stores it in the
 * {@code minimizer} field.
 * <p>
 * The flags are consulted in this order: {@code useQN}, {@code useSGDtoQN}, {@code useSMD},
 * {@code useSGD}, {@code useScaledSGD}. If none is set the field is left untouched, so a
 * minimizer assigned earlier is returned as-is.
 *
 * @param featurePruneIteration index of the current prune/retrain pass; pass 0 uses
 *     {@code flags.QNsize} as the QN memory, later passes use {@code flags.QNsize2}
 * @return the selected (or previously assigned) minimizer, never null
 * @throws RuntimeException if no minimizer has been assigned
 */
protected Minimizer getMinimizer(int featurePruneIteration) {
  if (flags.useQN) {
    final int memory = (featurePruneIteration == 0) ? flags.QNsize : flags.QNsize2;
    if (flags.interimOutputFreq == 0) {
      minimizer = new QNMinimizer(memory, flags.useRobustQN);
    } else {
      Function progressMonitor = new ResultStoringMonitor(flags.interimOutputFreq, flags.serializeTo);
      minimizer = new QNMinimizer(progressMonitor, memory, flags.useRobustQN);
    }
  } else if (flags.useSGDtoQN) {
    minimizer = new SGDToQNMinimizer(flags);
  } else if (flags.useSMD) {
    minimizer = new SMDMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod, flags.SGDPasses);
  } else if (flags.useSGD) {
    minimizer = new SGDMinimizer(flags.initialGain, flags.stochasticBatchSize);
  } else if (flags.useScaledSGD) {
    minimizer = new ScaledSGDMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.SGDPasses, flags.scaledSGDMethod);
  }
  if (minimizer != null) {
    return minimizer;
  }
  throw new RuntimeException("No minimizer assigned!");
}
/**
 * Builds a sequence of CRFDatums from the preprocessed feature arrays of one document.
 * <p>
 * The result starts with up to {@code windowSize - 1} feature-less context datums for the
 * positions immediately before {@code beginPosition} (clipped at the document start),
 * followed by one datum per position in {@code beginPosition..endPosition} whose features
 * are looked up in {@code featureIndex}. Every datum is labelled with the
 * {@code AnswerAnnotation} of the word at its position.
 *
 * @param allData feature indices of the document, indexed by position, clique, feature
 * @param beginPosition first position (inclusive) to create a real datum for
 * @param endPosition last position (inclusive) to create a real datum for
 * @param labeledWordInfos the words of the document, carrying the answer labels
 * @return the context datums followed by the real datums, in position order
 */
protected List<CRFDatum> extractDatumSequence(int[][][] allData, int beginPosition, int endPosition, List<CoreLabel> labeledWordInfos) {
  List<CRFDatum> datums = new ArrayList<CRFDatum>();
  // The leading context may not reach back past the start of the document.
  final int contextStart = Math.max(0, beginPosition - windowSize + 1);

  // Leading context: placeholder datums that carry a label but no features.
  // TODO: is there any better way to do this?
  for (int pos = contextStart; pos < beginPosition; pos++) {
    List emptyCliques = new ArrayList();
    for (int clique = 0; clique < windowSize; clique++) {
      emptyCliques.add(Collections.EMPTY_SET);
    }
    CRFDatum<Serializable,String> placeholder = new CRFDatum<Serializable,String>(emptyCliques, labeledWordInfos.get(pos).get(AnswerAnnotation.class));
    datums.add(placeholder);
  }

  // The real datums: translate each feature index back into its feature string.
  for (int pos = beginPosition; pos <= endPosition; pos++) {
    List cliques = new ArrayList();
    for (int clique = 0; clique < windowSize; clique++) {
      Collection<String> featureNames = new ArrayList<String>();
      for (int featureId : allData[pos][clique]) {
        featureNames.add(featureIndex.get(featureId));
      }
      cliques.add(featureNames);
    }
    CRFDatum<Serializable,String> real = new CRFDatum<Serializable,String>(cliques, labeledWordInfos.get(pos).get(AnswerAnnotation.class));
    datums.add(real);
  }
  return datums;
}
/**
 * Copies previously processed documents into the data and labels arrays.
 * <p>
 * Document {@code i} of {@code processedData} is written to row {@code i + offset}. For each
 * datum the label is stored as its index in {@code classIndex} and each clique's features
 * are stored as their indices in {@code featureIndex}, so a document contributes exactly as
 * many labels as it has datums.
 * <p>
 * If converting a feature fails, diagnostics are printed to stderr and the method returns
 * early, leaving the remaining entries unfilled.
 *
 * @param processedData a List of documents, each a List of CRFDatums
 * @param data destination feature array; must have room for rows
 *     {@code offset .. offset + processedData.size() - 1}
 * @param labels destination label array, sized like {@code data}
 * @param offset index of the first destination row
 */
protected void addProcessedData(List<List<CRFDatum>> processedData, int[][][][] data, int[][] labels, int offset) {
  for (int i = 0, pdSize = processedData.size(); i < pdSize; i++) {
    int dataIndex = i + offset;
    List<CRFDatum> document = processedData.get(i);
    int dsize = document.size();
    labels[dataIndex] = new int[dsize];
    data[dataIndex] = new int[dsize][][];
    for (int j = 0; j < dsize; j++) {
      CRFDatum crfDatum = document.get(j);
      // add label
      labels[dataIndex][j] = classIndex.indexOf((String) crfDatum.label());
      // add features
      List<Collection<String>> cliques = crfDatum.asFeatures();
      int csize = cliques.size();
      data[dataIndex][j] = new int[csize][];
      for (int k = 0; k < csize; k++) {
        Collection<String> features = cliques.get(k);
        data[dataIndex][j][k] = new int[features.size()];
        int m = 0;
        try {
          for (String feature : features) {
            if (featureIndex == null) {
              // The message refers to the feature index; the lookup below then fails
              // and is reported by the catch block.
              System.out.println("Feature is NULL!");
            }
            data[dataIndex][j][k][m] = featureIndex.indexOf(feature);
            m++;
          }
        } catch (Exception e) {
          e.printStackTrace();
          System.err.printf("[index=%d, j=%d, k=%d, m=%d]\n", dataIndex, j, k, m);
          System.err.println("data.length " + data.length);
          System.err.println("data[dataIndex].length " + data[dataIndex].length);
          System.err.println("data[dataIndex][j].length " + data[dataIndex][j].length);
          // Report the length of the feature array itself (previously repeated the clique count).
          System.err.println("data[dataIndex][j][k].length " + data[dataIndex][j][k].length);
          System.err.println("data[dataIndex][j][k][m] " + data[dataIndex][j][k][m]);
          return;
        }
      }
    }
  }
}
/**
 * Serializes the given datums to a file using Java object serialization.
 * <p>
 * This is best-effort: an I/O failure does not propagate to the caller. It is reported on
 * stderr (stack trace followed by "failed.") instead of being silently swallowed.
 *
 * @param datums the list to serialize; must not be null
 * @param filename path of the file to (over)write
 */
protected static void saveProcessedData(List datums, String filename) {
  System.err.print("Saving processsed data of size " + datums.size() + " to serialized file...");
  FileOutputStream fileOut = null;
  ObjectOutputStream oos = null;
  boolean saved = false;
  try {
    fileOut = new FileOutputStream(filename);
    oos = new ObjectOutputStream(fileOut);
    oos.writeObject(datums);
    // Close inside the try: a failed flush means the data was not saved.
    oos.close();
    saved = true;
  } catch (IOException e) {
    e.printStackTrace();
  } finally {
    if (!saved) {
      // Release whichever stream got created; the failure itself was already reported.
      OutputStream pending = oos;
      if (pending == null) {
        pending = fileOut;
      }
      if (pending != null) {
        try {
          pending.close();
        } catch (IOException ignored) {
          // already on the failure path
        }
      }
    }
  }
  System.err.println(saved ? "done." : "failed.");
}
/**
 * Reads a list of processed datums from a file written with Java object serialization.
 * <p>
 * This is best-effort: if the file cannot be read or deserialized, the stack trace is
 * printed and an empty list is returned. A file whose payload is {@code null} also yields
 * an empty list.
 *
 * @param filename path of the serialized file
 * @return the deserialized list, or an empty list on failure; never null
 */
protected static List loadProcessedData(String filename) {
  System.err.print("Loading processed data from serialized file...");
  List result = Collections.EMPTY_LIST;
  FileInputStream fileIn = null;
  ObjectInputStream ois = null;
  try {
    fileIn = new FileInputStream(filename);
    ois = new ObjectInputStream(fileIn);
    List loaded = (List) ois.readObject();
    if (loaded != null) {
      result = loaded;
    }
  } catch (Exception e) {
    e.printStackTrace();
  } finally {
    // Close the outermost stream that exists, so the file handle is released even when
    // the ObjectInputStream could not be constructed (e.g. corrupt header).
    InputStream pending = ois;
    if (pending == null) {
      pending = fileIn;
    }
    if (pending != null) {
      try {
        pending.close();
      } catch (IOException ignored) {
        // nothing useful can be done about a failed close after reading
      }
    }
  }
  System.err.println("done. Got " + result.size() + " datums.");
  return result;
}
/**
 * Loads a classifier from the gzipped text format written by
 * {@link #serializeTextClassifier(String)}.
 * <p>
 * Sections are read in this order: label indices, class index, feature index, the
 * {@code <flags>} block, the feature factory class name, the window size and finally the
 * weights. A summary of what was read is echoed to stderr with a {@code DEBUG:} prefix.
 * The file is always closed before this method returns or throws.
 *
 * @param text path of the gzipped text classifier file
 * @param props not consulted by this method; the flags come solely from the file
 * @throws IOException if the file cannot be opened or read
 * @throws ClassNotFoundException if the feature factory class named in the file is not found
 * @throws InstantiationException if the feature factory cannot be instantiated
 * @throws IllegalAccessException if the feature factory's constructor is not accessible
 * @throws RuntimeException if the content deviates from the expected format, including a
 *     file that ends too early
 */
public void loadTextClassifier(String text, Properties props) throws ClassCastException, IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
  System.err.println("Loading Text Classifier from "+text);
  FileInputStream fileIn = new FileInputStream(text);
  BufferedReader br = null;
  try {
    br = new BufferedReader(new InputStreamReader(new GZIPInputStream(fileIn)));

    // --- label indices ---
    // first line should be this format:
    // labelIndices.length=\t%d
    String line = nextTextClassifierLine(br);
    String[] toks = line.split("\\t");
    if (!toks[0].equals("labelIndices.length=")) { throw new RuntimeException("format error"); }
    int size = Integer.parseInt(toks[1]);
    labelIndices = new Index[size];
    for (int labelIndicesIdx = 0; labelIndicesIdx < size; labelIndicesIdx++) {
      // header of each index: labelIndices[0].size()=\t%d
      line = nextTextClassifierLine(br);
      toks = line.split("\\t");
      if (! (toks[0].startsWith("labelIndices[") && toks[0].endsWith("].size()="))) {
        throw new RuntimeException("format error");
      }
      int labelIndexSize = Integer.parseInt(toks[1]);
      labelIndices[labelIndicesIdx] = new Index<CRFLabel>();
      int count = 0;
      while(count<labelIndexSize) {
        // entry: position \t space-separated label ints; positions must be consecutive
        line = nextTextClassifierLine(br);
        toks = line.split("\\t");
        int idx = Integer.parseInt(toks[0]);
        if (count!=idx) { throw new RuntimeException("format error"); }
        String[] crflabelstr = toks[1].split(" ");
        int[] crflabel = new int[crflabelstr.length];
        for (int i=0; i < crflabelstr.length; i++) {
          crflabel[i] = Integer.parseInt(crflabelstr[i]);
        }
        CRFLabel crfL = new CRFLabel(crflabel);
        labelIndices[labelIndicesIdx].add(crfL);
        count++;
      }
    }
    // echo what was read
    System.err.printf("DEBUG: labelIndices.length=\t%d\n",labelIndices.length);
    for(int i = 0; i < labelIndices.length; i++) {
      System.err.printf("DEBUG: labelIndices[%d].size()=\t%d\n", i, labelIndices[i].size());
      for(int j = 0; j < labelIndices[i].size(); j++) {
        int[] label = labelIndices[i].get(j).getLabel();
        List<Integer> list = new ArrayList<Integer>();
        for(int l : label) {
          list.add(l);
        }
        System.err.printf("DEBUG: %d\t%s\n", j, StringUtils.join(list, " "));
      }
    }

    // --- class index ---
    line = nextTextClassifierLine(br);
    toks = line.split("\\t");
    if (!toks[0].equals("classIndex.size()=")) { throw new RuntimeException("format error"); }
    int classIndexSize = Integer.parseInt(toks[1]);
    classIndex = new Index<String>();
    int count = 0;
    while(count<classIndexSize) {
      line = nextTextClassifierLine(br);
      toks = line.split("\\t");
      int idx = Integer.parseInt(toks[0]);
      if (count!=idx) { throw new RuntimeException("format error"); }
      classIndex.add(toks[1]);
      count++;
    }
    System.err.printf("DEBUG: classIndex.size()=\t%d\n", classIndex.size());
    for(int i = 0; i < classIndex.size(); i++) {
      System.err.printf("DEBUG: %d\t%s\n", i, classIndex.get(i));
    }

    // --- feature index ---
    line = nextTextClassifierLine(br);
    toks = line.split("\\t");
    if (!toks[0].equals("featureIndex.size()=")) { throw new RuntimeException("format error"); }
    int featureIndexSize = Integer.parseInt(toks[1]);
    featureIndex = new Index<String>();
    count = 0;
    while(count<featureIndexSize) {
      line = nextTextClassifierLine(br);
      toks = line.split("\\t");
      int idx = Integer.parseInt(toks[0]);
      if (count!=idx) { throw new RuntimeException("format error"); }
      featureIndex.add(toks[1]);
      count++;
    }
    // only the size is echoed; the individual features are too many to print
    System.err.printf("DEBUG: featureIndex.size()=\t%d\n", featureIndex.size());

    // --- flags ---
    line = nextTextClassifierLine(br);
    if (!line.equals("<flags>")) { throw new RuntimeException("format error"); }
    Properties p = new Properties();
    line = nextTextClassifierLine(br);
    while(!line.equals("</flags>")) {
      // Split on the first '=' only, so that values which themselves contain '=' survive intact.
      String[] keyValue = line.split("=", 2);
      if (keyValue.length < 2) { throw new RuntimeException("format error"); }
      p.setProperty(keyValue[0], keyValue[1]);
      line = nextTextClassifierLine(br);
    }
    flags = new SeqClassifierFlags(p);
    System.err.println("DEBUG: <flags>");
    System.err.print(flags.toString());
    System.err.println("DEBUG: </flags>");

    // --- feature factory ---
    // <featureFactory> edu.stanford.nlp.wordseg.Gale2007ChineseSegmenterFeatureFactory </featureFactory>
    line = nextTextClassifierLine(br);
    String[] featureFactoryName = line.split(" ");
    if (!featureFactoryName[0].equals("<featureFactory>")
        || !featureFactoryName[2].equals("</featureFactory>")) {
      throw new RuntimeException("format error");
    }
    featureFactory = (edu.stanford.nlp.sequences.FeatureFactory)Class.forName(featureFactoryName[1]).newInstance();
    featureFactory.init(flags);
    reinit();

    // --- window size ---
    // <windowSize> 2 </windowSize>
    line = nextTextClassifierLine(br);
    String[] windowSizeName = line.split(" ");
    if (!windowSizeName[0].equals("<windowSize>")
        || !windowSizeName[2].equals("</windowSize>")) {
      throw new RuntimeException("format error");
    }
    windowSize = Integer.parseInt(windowSizeName[1]);

    // --- weights ---
    // weights.length= 2655170
    line = nextTextClassifierLine(br);
    toks = line.split("\\t");
    if (!toks[0].equals("weights.length=")) { throw new RuntimeException("format error"); }
    int weightsLength = Integer.parseInt(toks[1]);
    weights = new double[weightsLength][];
    count = 0;
    while(count < weightsLength) {
      // row: declared length \t space-separated values; the two must agree
      line = nextTextClassifierLine(br);
      toks = line.split("\\t");
      int weights2Length = Integer.parseInt(toks[0]);
      weights[count] = new double[weights2Length];
      String[] weightsValue = toks[1].split(" ");
      if (weights2Length != weightsValue.length)
      { throw new RuntimeException("weights format error"); }
      for(int i2 = 0; i2 < weights2Length; i2++) {
        weights[count][i2] = Double.parseDouble(weightsValue[i2]);
      }
      count++;
    }
    System.err.printf("DEBUG: double[%d][] weights loaded\n", weightsLength);

    // nothing may follow the weights
    line = br.readLine();
    if (line != null)
    { throw new RuntimeException("weights format error"); }
  } finally {
    // Closing the reader closes the whole stream chain; if the gzip wrapper could not be
    // created, close the raw file stream instead.
    try {
      if (br != null) {
        br.close();
      } else {
        fileIn.close();
      }
    } catch (IOException ignored) {
      // reading is finished (or has already failed); a close failure adds nothing
    }
  }
}

/** Reads the next line of a text classifier file, treating end of file as a format error. */
private static String nextTextClassifierLine(BufferedReader reader) throws IOException {
  String line = reader.readLine();
  if (line == null) {
    throw new RuntimeException("format error: unexpected end of file");
  }
  return line;
}
/**
 * Serialize the model to a human readable (gzipped text) format.
 * It's not yet complete. It should now work for Chinese segmenter though.
 * TODO: check things in serializeClassifier and add other necessary serialization back
 * <p>
 * Sections are written in this order: label indices, class index, feature index, flags,
 * feature factory class name, window size, weights. Failures are reported on stderr
 * ("Failed" plus stack trace) and do not propagate; the output file is closed in either case.
 *
 * @param serializePath File to write text format of classifier to.
 */
public void serializeTextClassifier(String serializePath) {
  System.err.print("Serializing Text classifier to " + serializePath + "...");
  FileOutputStream fileOut = null;
  PrintWriter pw = null;
  try {
    fileOut = new FileOutputStream(serializePath);
    pw = new PrintWriter(new GZIPOutputStream(fileOut));
    pw.printf("labelIndices.length=\t%d\n",labelIndices.length);
    for(int i = 0; i < labelIndices.length; i++) {
      pw.printf("labelIndices[%d].size()=\t%d\n", i, labelIndices[i].size());
      for(int j = 0; j < labelIndices[i].size(); j++) {
        int[] label = labelIndices[i].get(j).getLabel();
        List<Integer> list = new ArrayList<Integer>();
        for(int l : label) {
          list.add(l);
        }
        pw.printf("%d\t%s\n", j, StringUtils.join(list, " "));
      }
    }
    pw.printf("classIndex.size()=\t%d\n", classIndex.size());
    for(int i = 0; i < classIndex.size(); i++) {
      pw.printf("%d\t%s\n", i, classIndex.get(i));
    }
    pw.printf("featureIndex.size()=\t%d\n", featureIndex.size());
    for(int i = 0; i < featureIndex.size(); i++) {
      pw.printf("%d\t%s\n", i, featureIndex.get(i));
    }
    pw.println("<flags>");
    pw.print(flags.toString());
    pw.println("</flags>");
    pw.printf("<featureFactory> %s </featureFactory>\n",featureFactory.getClass().getName());
    pw.printf("<windowSize> %d </windowSize>\n", windowSize);
    pw.printf("weights.length=\t%d\n", weights.length);
    for (double[] ws : weights) {
      ArrayList<Double> list = new ArrayList<Double>();
      for (double w : ws) {
        list.add(w);
      }
      pw.printf("%d\t%s\n", ws.length, StringUtils.join(list, " "));
    }
    pw.close();
    // PrintWriter never throws on write errors; ask it explicitly so a failed write
    // is not reported as success.
    if (pw.checkError()) {
      throw new IOException("I/O error while writing " + serializePath);
    }
    System.err.println("done.");
  } catch (Exception e) {
    System.err.println("Failed");
    e.printStackTrace();
    // don't actually exit in case they're testing too
  } finally {
    // Release the file on the failure path too. PrintWriter.close() is a no-op when the
    // writer is already closed; the raw stream is only closed if wrapping it failed.
    if (pw != null) {
      pw.close();
    } else if (fileOut != null) {
      try {
        fileOut.close();
      } catch (IOException ignored) {
        // the failure has already been reported above
      }
    }
  }
}
/** {@inheritDoc}
 * <p>
 * Objects are written in this order: label indices, class index, feature index, flags,
 * feature factory, window size, weights, the true-casing known words (only when the
 * reader/writer is a TrueCasingDocumentReaderAndWriter) and the known lower-case words.
 * Failures are reported on stderr ("Failed" plus stack trace) and do not propagate; the
 * stream is closed in either case.
 */
@Override
public void serializeClassifier(String serializePath) {
  System.err.print("Serializing classifier to " + serializePath + "...");
  ObjectOutputStream oos = null;
  try {
    oos = IOUtils.writeStreamFromString(serializePath);
    oos.writeObject(labelIndices);
    oos.writeObject(classIndex);
    oos.writeObject(featureIndex);
    oos.writeObject(flags);
    oos.writeObject(featureFactory);
    oos.writeInt(windowSize);
    oos.writeObject(weights);
    if (readerAndWriter instanceof TrueCasingDocumentReaderAndWriter) {
      oos.writeObject(TrueCasingDocumentReaderAndWriter.knownWords);
    }
    oos.writeObject(knownLCWords);
    // Close on the success path inside the try so that a failed final flush is still
    // reported as "Failed"; clear the variable first so the finally block does not
    // close the stream a second time.
    ObjectOutputStream finished = oos;
    oos = null;
    finished.close();
    System.err.println("done.");
  } catch (Exception e) {
    System.err.println("Failed");
    e.printStackTrace();
    // don't actually exit in case they're testing too
  } finally {
    // Only reached with a non-null stream when a write failed part-way.
    if (oos != null) {
      try {
        oos.close();
      } catch (IOException ignored) {
        // the original failure has already been reported above
      }
    }
  }
}
/**
 * Loads a classifier from the specified InputStream.
 * This version works quietly (unless VERBOSE is true).
 * If props is non-null then any properties it specifies override
 * those in the serialized file. However, only some properties are
 * sensible to change (you shouldn't change how features are defined).
 * <p>
 * The objects are read in exactly the order in which serializeClassifier
 * writes them, so the statement order here must not be changed independently
 * of that method.
 * <p>
 * <i>Note:</i> This method does not close the ObjectInputStream. (But
 * earlier versions of the code used to, so beware....)
 *
 * @param ois the stream to read the serialized classifier from; left open
 * @param props properties to apply on top of the deserialized flags, or null for none
 * @throws ClassCastException if a deserialized object does not have the expected type
 * @throws IOException if reading from the stream fails
 * @throws ClassNotFoundException if a serialized class cannot be found
 */
@Override
@SuppressWarnings({"unchecked"}) // can't have right types in deserialization
public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException, ClassNotFoundException {
labelIndices = (Index<CRFLabel>[]) ois.readObject();
classIndex = (Index<String>) ois.readObject();
featureIndex = (Index<String>) ois.readObject();
flags = (SeqClassifierFlags) ois.readObject();
featureFactory = (edu.stanford.nlp.sequences.FeatureFactory) ois.readObject();
// Apply caller-supplied properties before reinit() so that it sees the final flags.
if (props != null) {
flags.setProperties(props, false);
}
reinit();
windowSize = ois.readInt();
weights = (double[][]) ois.readObject();
// The known-words set is only present in the stream when the writer also used a
// true-casing reader/writer; NOTE(review): presumably reinit() sets readerAndWriter
// from the flags - confirm.
if (readerAndWriter instanceof TrueCasingDocumentReaderAndWriter) {
TrueCasingDocumentReaderAndWriter.knownWords = (Set)ois.readObject();
}
//WordShapeClassifier.setKnownLowerCaseWords((Set) ois.readObject());
knownLCWords = (Set<String>) ois.readObject();
if (VERBOSE) {
System.err.println("windowSize=" + windowSize);
System.err.println("flags=\n" + flags);
}
}
/**
 * Loads the default supplied classifier, i.e. the jar resource named by
 * {@code DEFAULT_CLASSIFIER}, into this object without any property overrides.
 * THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
 * WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
 */
public void loadDefaultClassifier() {
  this.loadJarClassifier(DEFAULT_CLASSIFIER, null);
}
/**
 * Creates a new CRFClassifier and loads the default supplied classifier into it.
 * THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
 * WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
 *
 * @return The default CRFClassifier in the jar file (if there is one)
 */
public static CRFClassifier getDefaultClassifier() {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadDefaultClassifier();
  return classifier;
}
/**
 * Creates a new CRFClassifier and loads into it a classifier stored as a resource
 * inside a jar file.
 * THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
 * WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
 *
 * @param resourceName Name of classifier resource inside the jar file.
 * @param props Properties handed on to {@code loadJarClassifier}
 * @return A CRFClassifier stored in the jar file
 */
public static CRFClassifier getJarClassifier(String resourceName, Properties props) {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadJarClassifier(resourceName, props);
  return classifier;
}
/**
 * Creates a new CRF classifier and loads it from the given file.
 *
 * @param file File to load classifier from
 * @return The CRF classifier
 *
 * @throws IOException If there are problems accessing the input stream
 * @throws ClassCastException If there are problems interpreting the serialized data
 * @throws ClassNotFoundException If there are problems interpreting the serialized data
 */
public static CRFClassifier getClassifier(File file) throws IOException, ClassCastException, ClassNotFoundException {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadClassifier(file);
  return classifier;
}
/**
 * Creates a new CRF classifier and loads it from the given InputStream. The stream is
 * not buffered by this method, so you should have buffered it before calling it.
 *
 * @param in InputStream to load classifier from
 * @return The CRF classifier
 *
 * @throws IOException If there are problems accessing the input stream
 * @throws ClassCastException If there are problems interpreting the serialized data
 * @throws ClassNotFoundException If there are problems interpreting the serialized data
 */
public static CRFClassifier getClassifier(InputStream in) throws IOException, ClassCastException, ClassNotFoundException {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadClassifier(in);
  return classifier;
}
/**
 * Creates a new CRF classifier and loads it from the given path via
 * {@code loadClassifierNoExceptions}; unlike the {@code getClassifier} variants this
 * method declares no checked exceptions.
 *
 * @param loadPath location to load the classifier from
 * @return The CRF classifier
 */
public static CRFClassifier getClassifierNoExceptions(String loadPath) {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadClassifierNoExceptions(loadPath);
  return classifier;
}
/**
 * Creates a new CRF classifier and loads it from the given path.
 *
 * @param loadPath location to load the classifier from
 * @return The CRF classifier
 *
 * @throws IOException If there are problems accessing the input stream
 * @throws ClassCastException If there are problems interpreting the serialized data
 * @throws ClassNotFoundException If there are problems interpreting the serialized data
 */
public static CRFClassifier getClassifier(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
  final CRFClassifier classifier = new CRFClassifier();
  classifier.loadClassifier(loadPath);
  return classifier;
}
/**
 * The main method. See the class documentation.
 * <p>
 * The classifier is obtained from the first of these that applies: {@code loadClassifier},
 * {@code loadTextClassifier}, {@code loadJarClassifier}, training (when a train file or
 * train file list is given), or the default classifier. It is then optionally serialized
 * (binary and/or text), run on {@code testFile} in the mode selected by the flags, and run
 * on the plain text {@code textFile}.
 *
 * @param args command line arguments, converted to properties
 * @throws Exception if loading, training or classification fails
 */
public static void main(String[] args) throws Exception {
  StringUtils.printErrInvocationString("CRFClassifier", args);
  Properties props = StringUtils.argsToProperties(args);
  CRFClassifier crf = new CRFClassifier(props);
  // These are captured before loading, so they reflect the command line rather than
  // whatever flags a loaded classifier brings along.
  String testFile = crf.flags.testFile;
  String textFile = crf.flags.textFile;
  String loadPath = crf.flags.loadClassifier;
  String loadTextPath = crf.flags.loadTextClassifier;
  String serializeTo = crf.flags.serializeTo;
  String serializeToText = crf.flags.serializeToText;
  if (loadPath != null) {
    crf.loadClassifierNoExceptions(loadPath, props);
  } else if (loadTextPath != null) {
    System.err.println("Warning: this is now only tested for Chinese Segmenter");
    System.err.println("(Sun Dec 23 00:59:39 2007) (pichuan)");
    try {
      crf.loadTextClassifier(loadTextPath, props);
    } catch (Exception e) {
      e.printStackTrace();
      // keep the original failure attached as the cause
      throw new RuntimeException("error loading "+loadTextPath, e);
    }
  } else if (crf.flags.loadJarClassifier != null) {
    crf.loadJarClassifier(crf.flags.loadJarClassifier, props);
  } else if (crf.flags.trainFile != null || crf.flags.trainFileList != null) {
    crf.train();
  } else {
    crf.loadDefaultClassifier();
  }
  if (serializeTo != null) {
    crf.serializeClassifier(serializeTo);
  }
  if (serializeToText != null) {
    crf.serializeTextClassifier(serializeToText);
  }
  if (testFile != null) {
    // exactly one output mode is used, chosen in this order of precedence
    if (crf.flags.searchGraphPrefix != null) {
      crf.classifyAndWriteViterbiSearchGraph(testFile,crf.flags.searchGraphPrefix);
    } else if (crf.flags.printFirstOrderProbs) {
      crf.printFirstOrderProbs(testFile);
    } else if (crf.flags.printProbs) {
      crf.printProbs(testFile);
    } else if (crf.flags.useKBest) {
      int k = crf.flags.kBest;
      crf.classifyAndWriteAnswersKBest(testFile, k);
    } else if (crf.flags.printLabelValue) {
      crf.printLabelInformation(testFile);
    } else {
      crf.classifyAndWriteAnswers(testFile);
    }
  }
  if (textFile != null) {
    // Temporarily swap in a plain text reader/writer; restore the configured one even
    // if classification throws.
    DocumentReaderAndWriter oldRW = crf.readerAndWriter;
    crf.readerAndWriter = new PlainTextDocumentReaderAndWriter();
    try {
      crf.readerAndWriter.init(crf.flags);
      crf.classifyAndWriteAnswers(textFile);
    } finally {
      crf.readerAndWriter = oldRW;
    }
  }
} // end main
} // end class CRFClassifier