package mstparser; import edu.umass.cs.mallet.base.classify.Classifier; import java.io.*; import java.util.Arrays; import java.util.Date; import mstparser.mallet.LabelClassifier; public class DependencyParser { public ParserOptions options; private DependencyPipe pipe; private DependencyDecoder decoder; private Parameters params; // Used in "MSTParserStacked" private Classifier classifier; private long startTime; public static PrintStream out = System.out; Parameters getParams() { return params; } public DependencyParser(DependencyPipe pipe, ParserOptions options) { this.pipe = pipe; this.options = options; // Set up arrays params = new Parameters(pipe.dataAlphabet.size()); decoder = options.secondOrder ? new DependencyDecoder2O(pipe) : new DependencyDecoder(pipe); } // afm 03-06-08 --- Count the real number of instances to be considered public int countActualInstances(int ignore[]) { int i; int numInstances = ignore.length; int numActualInstances = 0; for (i = 0; i < numInstances; i++) { if (ignore[i] == 0) // This sentence is not to be ignored { numActualInstances++; } } return numActualInstances; } public void augment(int[] instanceLengths, String trainfile, File train_forest, int numParts) throws IOException { //out.print("About to train. "); //out.print("Num Feats: " + pipe.dataAlphabet.size()); int i, j; int[] ignore = new int[instanceLengths.length]; //String trainpartfile; //createPartitions(instanceLengths, trainfile, numParts); //for(i = 0; i < numParts; i++) //{ // trainpartfile = trainfile + "." + i; //} int numInstances = instanceLengths.length; int numInstancesPerPart = numInstances / numParts; // The last partition becomes bigger pipe.initOutputFile(options.outfile); // Initialize the output file once for (j = 0; j < numParts; j++) { out.println("Training classifier for partition " + j); // Make partition for (i = 0; i < numInstances; i++) { if ((i >= j * numInstancesPerPart && i < (j + 1) * numInstancesPerPart) || (j == numParts - 1 && i >= numParts * numInstancesPerPart)) { ignore[i] = 1; // Mark to ignore this instance in training } else { ignore[i] = 0; } } // Train on one split params = new Parameters(pipe.dataAlphabet.size()); train(instanceLengths, ignore, trainfile, train_forest); // Test on the other split out.println("Making predictions for partition " + j); for (i = 0; i < numInstances; i++) { ignore[i] = 1 - ignore[i]; // Toggle ignore } outputParses(ignore); } pipe.close(); // Close the output file once } public void train(int[] instanceLengths, int[] ignore, String trainfile, File train_forest) throws IOException { //out.print("About to train. "); //out.print("Num Feats: " + pipe.dataAlphabet.size()); int i; int count = options.numIters; for (i = 0; i < count; i++) { out.print(" Iteration " + i); //out.println("========================"); //out.println("Iteration: " + i); //out.println("========================"); out.print("["); long start = System.currentTimeMillis(); trainingIter(instanceLengths, ignore, trainfile, train_forest, i + 1); long end = System.currentTimeMillis(); //out.println("Training iter took: " + (end-start)); out.println("|Time:" + (end - start) + "]"); } params.averageParams(i * countActualInstances(ignore)); // afm 06-04-08 if (options.separateLab) { LabelClassifier oc = new LabelClassifier(options, instanceLengths, ignore, trainfile, train_forest, this, pipe); try { classifier = oc.trainClassifier(100); } catch (Exception e) { e.printStackTrace(); } } } // Note: Change this to pass -1 for indices in instanceLengths[] that you // don't want to use on training (need to be careful because i is being used // in the for loop; need new index) private void trainingIter(int[] instanceLengths, int ignore[], String trainfile, File train_forest, int iter) throws IOException { int numUpd = 0; try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(train_forest))) { boolean evaluateI = true; int numInstances = instanceLengths.length; // afm -- Count the real number of instances to be considered int numActualInstances = countActualInstances(ignore); int j = 0; for (int i = 0; i < numInstances; i++) { if ((i + 1) % 500 == 0) { out.print((i + 1) + ","); //out.println(" "+(i+1)+" instances"); } int length = instanceLengths[i]; // Get production crap. FeatureVector[][][] fvs = new FeatureVector[length][length][2]; double[][][] probs = new double[length][length][2]; FeatureVector[][][][] nt_fvs = new FeatureVector[length][pipe.types.length][2][2]; double[][][][] nt_probs = new double[length][pipe.types.length][2][2]; FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length]; double[][][] probs_trips = new double[length][length][length]; FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2]; double[][][] probs_sibs = new double[length][length][2]; DependencyInstance inst; if (options.secondOrder) { inst = ((DependencyPipe2O) pipe).readInstance(in, length, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, params); } else { inst = pipe.readInstance(in, length, fvs, probs, nt_fvs, nt_probs, params); } // afm 03-06-08 if (ignore[i] != 0) // This sentence is to be ignored { continue; } double upd = (double) (options.numIters * numActualInstances - (numActualInstances * (iter - 1) + (j + 1)) + 1); int K = options.trainK; Object[][] d = null; if (options.decodeType.equals("proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeProjective(inst, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeProjective(inst, fvs, probs, nt_fvs, nt_probs, K); } } if (options.decodeType.equals("non-proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeNonProjective(inst, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeNonProjective(inst, fvs, probs, nt_fvs, nt_probs, K); } } params.updateParamsMIRA(inst, d, upd); j++; } //out.println(""); //out.println(" "+numInstances+" instances"); out.print(numActualInstances); } } /////////////////////////////////////////////////////// // Saving and loading models /////////////////////////////////////////////////////// public void saveModel(String file) throws IOException { try (ObjectOutputStream outStream = new ObjectOutputStream(new FileOutputStream(file))) { outStream.writeObject(params.parameters); outStream.writeObject(pipe.dataAlphabet); outStream.writeObject(pipe.typeAlphabet); // afm 06-04-08 if (options.separateLab) { outStream.writeObject(classifier); } } } public void loadModel(String file) throws Exception { try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) { params.parameters = (double[]) in.readObject(); pipe.dataAlphabet = (Alphabet) in.readObject(); pipe.typeAlphabet = (Alphabet) in.readObject(); // afm 06-04-08 if (options.separateLab) { classifier = (Classifier) in.readObject(); } } pipe.closeAlphabets(); } ////////////////////////////////////////////////////// // Get Best Parses /////////////////////////////////// ////////////////////////////////////////////////////// public void outputParses() throws IOException { String tFile = options.testfile; String file = options.outfile; ConfidenceEstimator confEstimator = null; if (options.confidenceEstimator != null) { confEstimator = ConfidenceEstimator.resolveByName(options.confidenceEstimator, this); out.println("Applying confidence estimation: " + options.confidenceEstimator); } long start = System.currentTimeMillis(); pipe.initInputFile(tFile); if (!options.train || !options.stackedLevel0) // afm 03-07-2008 --- If this is called for each partition, must have initialized output file before { pipe.initOutputFile(file); } out.print("Processing Sentence: "); DependencyInstance instance = pipe.nextInstance(); int cnt = 0; while (instance != null) { cnt++; out.print(cnt + " "); String[] forms = instance.forms; String[] formsNoRoot = new String[forms.length - 1]; String[] posNoRoot = new String[formsNoRoot.length]; String[] labels = new String[formsNoRoot.length]; int[] heads = new int[formsNoRoot.length]; decode(instance, options.testK, params, formsNoRoot, posNoRoot, labels, heads); if (confEstimator != null) { double[] confidenceScores = confEstimator.estimateConfidence(instance); pipe.outputInstance(new DependencyInstance(formsNoRoot, instance.lemmas, posNoRoot, instance.postags, instance.feats, labels, heads, null, confidenceScores, instance.numbers)); } else { pipe.outputInstance(new DependencyInstance(formsNoRoot, instance.lemmas, posNoRoot, instance.postags, instance.feats, labels, heads, instance.numbers)); } //String line1 = ""; String line2 = ""; String line3 = ""; String line4 = ""; //for(int j = 1; j < pos.length; j++) { // String[] trip = res[j-1].split("[\\|:]"); // line1+= sent[j] + "\t"; line2 += pos[j] + "\t"; // line4 += trip[0] + "\t"; line3 += pipe.types[Integer.parseInt(trip[2])] + "\t"; //} //pred.write(line1.trim() + "\n" + line2.trim() + "\n" // + (pipe.labeled ? line3.trim() + "\n" : "") // + line4.trim() + "\n\n"); instance = pipe.nextInstance(); } pipe.close(); long end = System.currentTimeMillis(); out.println("Took: " + (end - start)); } public void outputParses(int[] ignore) throws IOException { String tFile = options.testfile; String file = options.outfile; ConfidenceEstimator confEstimator = null; if (options.confidenceEstimator != null) { confEstimator = ConfidenceEstimator.resolveByName(options.confidenceEstimator, this); out.println("Applying confidence estimation: " + options.confidenceEstimator); } long start = System.currentTimeMillis(); pipe.initInputFile(tFile); //if (ignore == null) // afm 03-07-2008 --- If this is called for each partition, must have initialized output file before if (!options.train || !options.stackedLevel0) // afm 03-07-2008 --- If this is called for each partition, must have initialized output file before { pipe.initOutputFile(file); } out.print("Processing Sentence: "); DependencyInstance instance = pipe.nextInstance(); int cnt = 0; int i = 0; LabelClassifier oc = new LabelClassifier(options); while (instance != null) { cnt++; out.print(cnt + " "); String[] forms = instance.forms; int length = forms.length; // afm 03-07-08 --- If this instance is to be ignored, just go for the next one if (ignore != null && ignore[i] != 0) { instance = pipe.nextInstance(); i++; continue; } FeatureVector[][][] fvs = new FeatureVector[forms.length][forms.length][2]; double[][][] probs = new double[forms.length][forms.length][2]; FeatureVector[][][][] nt_fvs = new FeatureVector[forms.length][pipe.types.length][2][2]; double[][][][] nt_probs = new double[forms.length][pipe.types.length][2][2]; FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length]; double[][][] probs_trips = new double[length][length][length]; FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2]; double[][][] probs_sibs = new double[length][length][2]; if (options.secondOrder) { ((DependencyPipe2O) pipe).fillFeatureVectors(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, params); } else { pipe.fillFeatureVectors(instance, fvs, probs, nt_fvs, nt_probs, params); } int K = options.testK; Object[][] d = null; if (options.decodeType.equals("proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeProjective(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeProjective(instance, fvs, probs, nt_fvs, nt_probs, K); } } if (options.decodeType.equals("non-proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeNonProjective(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeNonProjective(instance, fvs, probs, nt_fvs, nt_probs, K); } } String[] res = ((String) d[0][1]).split(" "); String[] pos = instance.cpostags; String[] formsNoRoot = new String[forms.length - 1]; String[] posNoRoot = new String[formsNoRoot.length]; String[] labels = new String[formsNoRoot.length]; int[] heads = new int[formsNoRoot.length]; Arrays.toString(forms); Arrays.toString(res); for (int j = 0; j < formsNoRoot.length; j++) { formsNoRoot[j] = forms[j + 1]; posNoRoot[j] = pos[j + 1]; String[] trip = res[j].split("[\\|:]"); labels[j] = pipe.types[Integer.parseInt(trip[2])]; heads[j] = Integer.parseInt(trip[0]); } // afm 06-04-08 if (options.separateLab) { /* * ask whether instance contains level0 information */ /* * Note, forms and pos have the root. labels and heads do not */ if (options.stackedLevel1) { labels = oc.outputLabels(classifier, instance.forms, instance.cpostags, labels, heads, instance.deprels_pred, instance.heads_pred, instance); } else { labels = oc.outputLabels(classifier, instance.forms, instance.cpostags, labels, heads, null, null, instance); } } // afm 03-07-08 //if (ignore == null) if (options.stackedLevel0 == false) { pipe.outputInstance(new DependencyInstance(formsNoRoot, instance.lemmas, posNoRoot, instance.postags, instance.feats, labels, heads, instance.numbers)); } else { int[] headsNoRoot = new int[instance.heads.length - 1]; String[] labelsNoRoot = new String[instance.heads.length - 1]; for (int j = 0; j < headsNoRoot.length; j++) { headsNoRoot[j] = instance.heads[j + 1]; labelsNoRoot[j] = instance.deprels[j + 1]; } DependencyInstance out_inst; if (confEstimator != null) { double[] confidenceScores = confEstimator.estimateConfidence(instance); out_inst = new DependencyInstance(formsNoRoot, instance.lemmas, posNoRoot, instance.postags, instance.feats, labelsNoRoot, headsNoRoot, null, confidenceScores, instance.numbers); } else { out_inst = new DependencyInstance(formsNoRoot, instance.lemmas, posNoRoot, instance.postags, instance.feats, labelsNoRoot, headsNoRoot, instance.numbers); } out_inst.stacked = true; out_inst.heads_pred = heads; out_inst.deprels_pred = labels; pipe.outputInstance(out_inst); } //String line1 = ""; String line2 = ""; String line3 = ""; String line4 = ""; //for(int j = 1; j < pos.length; j++) { // String[] trip = res[j-1].split("[\\|:]"); // line1+= sent[j] + "\t"; line2 += pos[j] + "\t"; // line4 += trip[0] + "\t"; line3 += pipe.types[Integer.parseInt(trip[2])] + "\t"; //} //pred.write(line1.trim() + "\n" + line2.trim() + "\n" // + (pipe.labeled ? line3.trim() + "\n" : "") // + line4.trim() + "\n\n"); instance = pipe.nextInstance(); i++; } //if (ignore == null) // afm 03-07-2008 --- If this is called for each partition (ignore != null), must close pipe outside the loop if (!options.train || !options.stackedLevel0) // afm 03-07-2008 --- If this is called for each partition (ignore != null), must close pipe outside the loop { pipe.close(); } long end = System.currentTimeMillis(); out.println("Took: " + (end - start)); } ////////////////////////////////////////////////////// // Decode single instance ////////////////////////////////////////////////////// String[] decode(DependencyInstance instance, int K, Parameters params) { String[] forms = instance.forms; int length = forms.length; FeatureVector[][][] fvs = new FeatureVector[forms.length][forms.length][2]; double[][][] probs = new double[forms.length][forms.length][2]; FeatureVector[][][][] nt_fvs = new FeatureVector[forms.length][pipe.types.length][2][2]; double[][][][] nt_probs = new double[forms.length][pipe.types.length][2][2]; FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length]; double[][][] probs_trips = new double[length][length][length]; FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2]; double[][][] probs_sibs = new double[length][length][2]; if (options.secondOrder) { ((DependencyPipe2O) pipe).fillFeatureVectors(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, params); } else { pipe.fillFeatureVectors(instance, fvs, probs, nt_fvs, nt_probs, params); } Object[][] d = null; if (options.decodeType.equals("proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeProjective(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeProjective(instance, fvs, probs, nt_fvs, nt_probs, K); } } if (options.decodeType.equals("non-proj")) { if (options.secondOrder) { d = ((DependencyDecoder2O) decoder).decodeNonProjective(instance, fvs, probs, fvs_trips, probs_trips, fvs_sibs, probs_sibs, nt_fvs, nt_probs, K); } else { d = decoder.decodeNonProjective(instance, fvs, probs, nt_fvs, nt_probs, K); } } String[] res = ((String) d[0][1]).split(" "); return res; } public void decode(DependencyInstance instance, int K, Parameters params, String[] formsNoRoot, String[] posNoRoot, String[] labels, int[] heads) { String[] forms = instance.forms; String[] res = decode(instance, K, params); String[] pos = instance.cpostags; for (int j = 0; j < forms.length - 1; j++) { formsNoRoot[j] = forms[j + 1]; posNoRoot[j] = pos[j + 1]; String[] trip = res[j].split("[\\|:]"); labels[j] = pipe.types[Integer.parseInt(trip[2])]; heads[j] = Integer.parseInt(trip[0]); } } public void decode(DependencyInstance instance, int K, Parameters params, int[] heads) { String[] res = decode(instance, K, params); for (int j = 0; j < instance.forms.length - 1; j++) { String[] trip = res[j].split("[\\|:]"); heads[j] = Integer.parseInt(trip[0]); } } ///////////////////////////////////////////////////// // RUNNING THE PARSER //////////////////////////////////////////////////// public static void main(String[] args) throws FileNotFoundException, Exception { out.print("Started: " + new Date(System.currentTimeMillis()) +"\n"); System.setProperty("java.io.tmpdir", "./tmp/"); ParserOptions options = new ParserOptions(args); out.println("Default temp directory:" + System.getProperty("java.io.tmpdir")); out.println("Separate labeling: " + options.separateLab); if (options.train) { DependencyPipe pipe = options.secondOrder ? new DependencyPipe2O(options) : new DependencyPipe(options); int[] instanceLengths = pipe.createInstances(options.trainfile, options.trainforest); pipe.closeAlphabets(); DependencyParser dp = new DependencyParser(pipe, options); dp.startTime = System.currentTimeMillis(); int numFeats = pipe.dataAlphabet.size(); int numTypes = pipe.typeAlphabet.size(); out.print("Num Feats: " + numFeats); out.println(".\tNum Edge Labels: " + numTypes); if (options.stackedLevel0) // Augment training data with output predictions, for stacked learning (afm 03-03-08) { // Output data augmented with output predictions out.println("Augmenting training data with output predictions..."); options.testfile = options.trainfile; dp.augment(instanceLengths, options.trainfile, options.trainforest, options.augmentNumParts); // Now train the base classifier in the whole corpus, nothing being ignored out.println("Training the base classifier in the whole corpus..."); } // afm 03-06-08 --- To allow some instances to be ignored int ignore[] = new int[instanceLengths.length]; for (int i = 0; i < instanceLengths.length; i++) { ignore[i] = 0; } dp.params = new Parameters(pipe.dataAlphabet.size()); dp.train(instanceLengths, ignore, options.trainfile, options.trainforest); out.print("Saving model..."); dp.saveModel(options.modelName); out.print("done."); out.println("\nTraining Time: " + CalculateTime(dp.startTime)); } if (options.test) { DependencyPipe pipe = options.secondOrder ? new DependencyPipe2O(options) : new DependencyPipe(options); DependencyParser dp = new DependencyParser(pipe, options); dp.startTime = System.currentTimeMillis(); out.print("\tLoading model..."); dp.loadModel(options.modelName); out.println("done."); if (options.separateLab == true || options.stackedLevel0 == true || options.stackedLevel1 == true) { pipe.printModelStats(dp.params); } pipe.closeAlphabets(); dp.outputParses(null); out.println("Parsing Time: " + CalculateTime(dp.startTime)); } if (options.eval) { out.println("EVALUATION PERFORMANCE:"); DependencyEvaluator.evaluate(options.goldfile, options.outfile, options.format, (options.confidenceEstimator != null)); } if (options.rankEdgesByConfidence) { out.println("\nRank edges by confidence:"); EdgeRankerByConfidence edgeRanker = new EdgeRankerByConfidence(); edgeRanker.rankEdgesByConfidence(options.goldfile, options.outfile, options.format); } out.println("Finished: " + new Date(System.currentTimeMillis())); } public static String CalculateTime(long startTime) { int a = 1000000; long time = (System.currentTimeMillis()-startTime)/1000; int hour = (int)(time / 3600); String hh = String.valueOf(hour); if (hh.length() == 1) { hh = "0" + hh; } time = time % 3600; int min = (int)(time / 60); String mm = String.valueOf(min); if (mm.length() == 1) { mm = "0" + mm; } int second = (int)(time % 60); String ss = String.valueOf(second); if (ss.length() == 1) { ss = "0" + ss; } return String.format("%s:%s:%s", hh, mm, ss); } }