package ir.ac.iust.nlp.dependencyparser.hybrid; import edu.stanford.nlp.parser.ensemble.utils.Eisner; import edu.stanford.nlp.parser.ensemble.utils.Scorer; import ir.ac.iust.nlp.dependencyparser.BasePanel; import ir.ac.iust.nlp.dependencyparser.parsing.RunnableParse; import ir.ac.iust.nlp.dependencyparser.training.RunnableTrain; import ir.ac.iust.nlp.dependencyparser.utility.enumeration.Flowchart; import ir.ac.iust.nlp.dependencyparser.utility.enumeration.HybridType; import ir.ac.iust.nlp.dependencyparser.utility.enumeration.ParserType; import ir.ac.iust.nlp.dependencyparser.utility.enumeration.ReparseType; import ir.ac.iust.nlp.dependencyparser.utility.parsing.MSTStackSettings; import ir.ac.iust.nlp.dependencyparser.utility.parsing.MaltSettings; import ir.ac.iust.nlp.dependencyparser.utility.parsing.MaltStackSettings; import ir.ac.iust.nlp.dependencyparser.utility.parsing.Settings; import java.io.File; import java.io.IOException; import java.io.PrintStream; import java.util.Calendar; import java.util.List; import optimizer.ValidationGenerator; /** * * @author Mojtaba Khallash */ public class RunnableHybrid implements Runnable { private PrintStream out = System.out; private HybridType type; BasePanel target; // Stacking setting Settings settings; private ParserType parser; // Ensemble settings String goldFile; List<String> sysFiles; String outFile; ReparseType reparseAlgorithm; public RunnableHybrid(BasePanel target, PrintStream out, ParserType parser, Settings settings) { this.target = target; if (out != null) { this.out = out; } this.type = HybridType.Stacking; this.parser = parser; this.settings = settings; } public RunnableHybrid(BasePanel target, PrintStream out, String goldFile, List<String> sysFiles, String outFile, ReparseType reparseAlgorithm) { this.target = target; if (out != null) { this.out = out; } this.type = HybridType.Ensemble; this.goldFile = goldFile; this.sysFiles = sysFiles; this.outFile = outFile; this.reparseAlgorithm = reparseAlgorithm; } @Override public void run() { try { switch (type) { case Ensemble: runEnsemble(); break; case Stacking: switch (parser) { case MSTParser: runMST(); break; case MaltParser: runMaltLevel0(); break; } break; } } finally { if (target != null) { target.threadFinished(); } } } private void runEnsemble() { String evalName = "eval07.pl"; boolean exist = true; try { exist = new File(evalName).exists(); if (reparseAlgorithm == ReparseType.chu_liu_edmond) { if (!exist) { java.io.BufferedWriter bwValidateFormat; try { bwValidateFormat = new java.io.BufferedWriter(new java.io.FileWriter(evalName)); bwValidateFormat.write(ValidationGenerator.generateEval07()); bwValidateFormat.close(); } catch (java.io.IOException e) { } } } else { File outPath = new File(outFile); if (!outPath.exists() && outPath.getParentFile() != null) { outPath.getParentFile().mkdirs(); } } Eisner.ensemble(goldFile, sysFiles, outFile, reparseAlgorithm.toString()); Scorer.Score s = Scorer.evaluate(outFile, goldFile); out.println(String.format("ensemble LAS: %.2f %d/%d", s.las, s.lcorrect, s.total)); out.println(String.format("ensemble UAS: %.2f %d/%d", s.uas, s.ucorrect, s.total)); } catch (IOException ex) {} finally { if (exist == false) { new File(evalName).delete(); } } } // <editor-fold defaultstate="collapsed" desc="Run Malt Level0"> private void runMaltLevel0() { MaltStackSettings params = new MaltStackSettings((MaltStackSettings)settings); if (params.Level != 0) { return; } // Create temp folder for run stacking String tmpFolder = String.valueOf( Calendar.getInstance().getTimeInMillis()) + File.separator; String currentDir = "tmp" + File.separator + tmpFolder; (new File(currentDir)).mkdirs(); params.WorkingDirectory = currentDir; try { //==============// // Train Level0 // //==============// out.println("Train Level0"); out.println("---------------------------------------------"); //==== Make Prediction of Train [train_pred.conll] ====// out.println("\nAugmenting training data with output predictions...\n"); // Make N gold file from "train.conll" file params.Chart = Flowchart.Train; params.preProcess(); MaltSettings maltSt; for (int i = 0; i < params.AugmentNParts; i++) { // Run train malt out.println("\nTraining classifier for partition " + i); maltSt = new MaltSettings((MaltSettings)params); maltSt.Model = currentDir + "modelname_level0_part" + i + ".mco"; maltSt.Input = currentDir + "_train" + i + ".conll"; RunnableTrain train = new RunnableTrain(null, ParserType.MaltParser, out, maltSt); train.run(); // Run test malt out.println("Making predictions for partition " + i); maltSt = new MaltSettings((MaltSettings)params); maltSt.Model = currentDir + "modelname_level0_part" + i + ".mco"; maltSt.Input = currentDir + "_test" + i + ".conll"; maltSt.Output = currentDir + "_parse" + i + ".conll"; RunnableParse parse = new RunnableParse(null, ParserType.MaltParser, out, maltSt); parse.run(); } // Merge N predicted part in "train_pred.conll" file params.postProcess(); //==== Train on the whole of "train.conll" ====// out.println("\nTraining the base classifier in the whole corpus..."); maltSt = new MaltSettings((MaltSettings)params); maltSt.Chart = Flowchart.Train; maltSt.Input = params.Input; maltSt.Model = currentDir + "modelname_level0.mco"; RunnableTrain runTrain = new RunnableTrain(null, ParserType.MaltParser, out, maltSt); runTrain.run(); //==============// // Parse Level0 // //==============// out.println("\nParse Level0"); out.println("---------------------------------------------"); params.Chart = Flowchart.Parse; params.preProcess(); maltSt = new MaltSettings((MaltSettings)params); maltSt.Model = currentDir + "modelname_level0.mco"; maltSt.Input = params.Gold; maltSt.Output = currentDir + "parse.conll"; RunnableParse parse = new RunnableParse(null, ParserType.MaltParser, out, maltSt); parse.run(); params.postProcess(); //=============// // Eval Level0 // //=============// out.println("\nEval Level0"); out.println("---------------------------------------------"); mstparser.DependencyParser.out = out; params.Chart = Flowchart.Eval; params.Output = maltSt.Output; mstparser.DependencyParser.main(params.getParameters()); } catch (Exception ex) { out.println("Error: " + ex.getMessage()); } } // </editor-fold> // <editor-fold defaultstate="collapsed" desc="Run MST"> private void runMST() { MSTStackSettings params = new MSTStackSettings((MSTStackSettings)settings); try { mstparser.DependencyParser.out = out; // Train Level out.println("Train Level" + params.Level); out.println("---------------------------------------------"); params.Chart = Flowchart.Train; params.preProcess(); mstparser.DependencyParser.main(params.getParameters()); params.postProcess(); // Parse out.println("\nParse Level" + params.Level); out.println("---------------------------------------------"); params.Chart = Flowchart.Parse; params.preProcess(); mstparser.DependencyParser.main(params.getParameters()); params.postProcess(); // Eval out.println("\nEval Level" + params.Level); out.println("---------------------------------------------"); params.Chart = Flowchart.Eval; mstparser.DependencyParser.main(params.getParameters()); } catch (Exception ex) { out.println("Error: " + ex.getMessage()); } } // </editor-fold> }