package is2.parserR2; import extractors.Extractor; import extractors.ExtractorFactory; import is2.data.*; import is2.io.CONLLReader09; import is2.io.CONLLWriter09; import is2.tools.Tool; import is2.util.DB; import is2.util.OptionsSuper; import is2.util.ParserEvaluator; import java.io.*; import java.text.SimpleDateFormat; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map.Entry; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; public class Parser implements Tool { // output evaluation info private static final boolean MAX_INFO = true; public static int THREADS = 4; Long2IntInterface l2i; ParametersFloat params; Pipe pipe; OptionsSuper options; HashMap<Integer, Integer> rank = new HashMap<>(); int amongxbest = 0, amongxbest_ula = 0, nbest = 0, bestProj = 0, smallestErrorSum = 0, countAllNodes = 0; static int NBest = 1000; ExtractorFactory extractorFactory = new ExtractorFactory(ExtractorFactory.StackedClusteredR2); /** * Initialize the parser * * @param options */ public Parser(OptionsSuper options) { this.options = options; pipe = new Pipe(options); params = new ParametersFloat(0); // load the model try { readModel(options, pipe, params); } catch (Exception e) { e.printStackTrace(); } } /** * @param modelFileName The file name of the parsing model */ public Parser(String modelFileName) { this(new Options(new String[]{"-model", modelFileName})); } /** * */ public Parser() { } public static void main(String[] args) throws Exception { long start = System.currentTimeMillis(); OptionsSuper options = new Options(args); NBest = options.best; DB.println("n-best" + NBest); Runtime runtime = Runtime.getRuntime(); THREADS = runtime.availableProcessors(); if (options.cores < THREADS && options.cores > 0) { THREADS = options.cores; } DB.println("Found " + runtime.availableProcessors() + " cores use " + THREADS); if (options.train) { Parser p = new Parser(); p.options = options; p.l2i = new Long2Int(options.hsize); p.pipe = new Pipe(options); Instances is = new Instances(); p.pipe.extractor = new Extractor[THREADS]; for (int t = 0; t < THREADS; t++) { p.pipe.extractor[t] = p.extractorFactory.getExtractor(p.l2i); } p.params = new ParametersFloat(p.l2i.size()); if (options.useMapping != null) { String model = options.modelName; options.modelName = options.useMapping; DB.println("Using mapping of model " + options.modelName); ZipInputStream zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(options.modelName))); zis.getNextEntry(); try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zis))) { p.pipe.mf.read(dis); DB.println("read\n" + p.pipe.mf.toString()); ParametersFloat params = new ParametersFloat(0); params.read(dis); Edges.read(dis); } DB.println("end read model"); options.modelName = model; } p.pipe.createInstances(options.trainfile, is); p.train(options, p.pipe, p.params, is, p.pipe.cl); p.writeModell(options, p.params, null, p.pipe.cl); } if (options.test) { Parser p = new Parser(); p.options = options; p.pipe = new Pipe(options); p.params = new ParametersFloat(0); // total should be zero and the parameters are later read // load the model p.readModel(options, p.pipe, p.params); DB.println("test on " + options.testfile); is2.parser.Parser.out.println("" + p.pipe.mf.toString()); p.outputParses(options, p.pipe, p.params, !MAX_INFO); } is2.parser.Parser.out.println(); if (options.eval) { is2.parser.Parser.out.println("\nEVALUATION PERFORMANCE:"); ParserEvaluator.evaluate(options.goldfile, options.outfile); } long end = System.currentTimeMillis(); is2.parser.Parser.out.println("used time " + ((float) ((end - start) / 100) / 10)); Decoder.executerService.shutdown(); Pipe.executerService.shutdown(); is2.parser.Parser.out.println("end."); } /** * Read the models and mapping * * @param options * @param pipe * @param params * @throws IOException */ public void readModel(OptionsSuper options, Pipe pipe, Parameters params) throws IOException { DB.println("Reading data started"); // prepare zipped reader ZipInputStream zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(options.modelName))); zis.getNextEntry(); try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zis))) { pipe.mf.read(dis); pipe.cl = new Cluster(dis); params.read(dis); this.l2i = new Long2Int(params.size()); DB.println("parsing -- li size " + l2i.size()); pipe.extractor = new Extractor[THREADS]; for (int t = 0; t < THREADS; t++) { pipe.extractor[t] = this.extractorFactory.getExtractor(l2i); } Edges.read(dis); options.decodeProjective = dis.readBoolean(); int maxForm = dis.readInt(); for (int t = 0; t < THREADS; t++) { pipe.extractor[t].setMaxForm(maxForm); pipe.extractor[t].initStat(); pipe.extractor[t].init(); } boolean foundInfo = false; try { String info; int icnt = dis.readInt(); for (int i = 0; i < icnt; i++) { info = dis.readUTF(); is2.parser.Parser.out.println(info); } } catch (Exception e) { if (!foundInfo) { is2.parser.Parser.out.println("no info about training"); } } } DB.println("Reading data finnished"); Decoder.NON_PROJECTIVITY_THRESHOLD = (float) options.decodeTH; for (int t = 0; t < THREADS; t++) { pipe.extractor[t].initStat(); pipe.extractor[t].init(); } } /** * Do the training * * @param instanceLengths * @param options * @param pipe * @param params * @param is * @param cluster * @throws IOException * @throws InterruptedException * @throws ClassNotFoundException */ public void train(OptionsSuper options, Pipe pipe, ParametersFloat params, Instances is, Cluster cluster) throws IOException, InterruptedException, ClassNotFoundException { DB.println("\nTraining Information "); DB.println("-------------------- "); Decoder.NON_PROJECTIVITY_THRESHOLD = (float) options.decodeTH; if (options.decodeProjective) { is2.parser.Parser.out.println("Decoding: " + (options.decodeProjective ? "projective" : "non-projective")); } else { is2.parser.Parser.out.println("" + Decoder.getInfo()); } int numInstances = is.size(); int maxLenInstances = 0; for (int i = 0; i < numInstances; i++) { if (maxLenInstances < is.length(i)) { maxLenInstances = is.length(i); } } DataF data = new DataF(maxLenInstances, pipe.mf.getFeatureCounter().get(PipeGen.REL).shortValue()); int iter = 0; int del = 0; float error; float f1; FV pred = new FV(); FV act = new FV(); double upd = (double) (numInstances * options.numIters) + 1; for (; iter < options.numIters; iter++) { is2.parser.Parser.out.print("Iteration " + iter + ": "); long start = System.currentTimeMillis(); long last = System.currentTimeMillis(); error = 0; f1 = 0; for (int n = 0; n < numInstances; n++) { upd--; if (is.labels[n].length > options.maxLen) { continue; } String info = " td " + ((Decoder.timeDecotder) / 1000000F) + " tr " + ((Decoder.timeRearrange) / 1000000F) + " te " + ((Pipe.timeExtract) / 1000000F); if ((n + 1) % 500 == 0) { PipeGen.outValueErr(n + 1, Math.round(error * 1000) / 1000, f1 / n, last, upd, info); } short pos[] = is.pposs[n]; data = pipe.fillVector((F2SF) params.getFV(), is, n, data, cluster, THREADS, l2i); List<ParseNBest> parses = Decoder.decode(pos, data, options.decodeProjective, pipe.extractor[0]); Parse d = parses.get(0); double e = pipe.errors(is, n, d); if (d.f1 > 0) { f1 += (d.labels.length - 1 - e) / (d.labels.length - 1); } if (e <= 0) { continue; } // get predicted feature vector pred.clear(); pipe.extractor[0].encodeCat(is, n, pos, is.forms[n], is.plemmas[n], d.heads, d.labels, is.feats[n], pipe.cl, pred); error += e; act.clear(); pipe.extractor[0].encodeCat(is, n, pos, is.forms[n], is.plemmas[n], is.heads[n], is.labels[n], is.feats[n], pipe.cl, act); params.update(act, pred, is, n, d, upd, e); } String info = " td " + ((Decoder.timeDecotder) / 1000000F) + " tr " + ((Decoder.timeRearrange) / 1000000F) + " te " + ((Pipe.timeExtract) / 1000000F) + " nz " + params.countNZ(); PipeGen.outValueErr(numInstances, Math.round(error * 1000) / 1000, f1 / numInstances, last, upd, info); del = 0; long end = System.currentTimeMillis(); is2.parser.Parser.out.println(" time:" + (end - start)); ParametersFloat pf = params.average2((iter + 1) * is.size()); try { if (options.testfile != null) { outputParses(options, pipe, pf, !MAX_INFO); ParserEvaluator.evaluate(options.goldfile, options.outfile); // writeModell(options, pf, ""+(iter+1),pipe.cl); } } catch (Exception e) { e.printStackTrace(); } Decoder.timeDecotder = 0; Decoder.timeRearrange = 0; Pipe.timeExtract = 0; } params.average(iter * is.size()); } /** * Do the parsing * * @param options * @param pipe * @param params * @throws IOException */ private void outputParses(OptionsSuper options, Pipe pipe, ParametersFloat params, boolean maxInfo) throws Exception { long start = System.currentTimeMillis(); CONLLReader09 depReader = new CONLLReader09(options.testfile, options.formatTask); CONLLWriter09 depWriter = new CONLLWriter09(options.outfile, options.formatTask); // ExtractorClusterStacked.initFeatures(); int cnt = 0; int del = 0; long last = System.currentTimeMillis(); if (maxInfo) { is2.parser.Parser.out.println("\nParsing Information "); } if (maxInfo) { is2.parser.Parser.out.println("------------------- "); } if (maxInfo && !options.decodeProjective) { is2.parser.Parser.out.println("" + Decoder.getInfo()); } // if (!maxInfo) Parser.out.println(); String[] types = new String[pipe.mf.getFeatureCounter().get(PipeGen.REL)]; for (Entry<String, Integer> e : MFB.getFeatureSet().get(PipeGen.REL).entrySet()) { types[e.getValue()] = e.getKey(); } is2.parser.Parser.out.print("Processing Sentence: "); while (true) { Instances is = new Instances(); is.init(1, new MFB(), options.formatTask); SentenceData09 instance = pipe.nextInstance(is, depReader); if (instance == null) { break; } cnt++; SentenceData09 i09 = this.parse(instance, params); // } depWriter.write(i09); del = PipeGen.outValue(cnt, del, last); // DB.println("xbest "+amongxbest+" cnt "+cnt+" "+((float)((float)amongxbest/cnt))+" nbest "+((float)nbest/cnt)+ // " 1best "+((float)(rank.get(0)==null?0:rank.get(0))/cnt)+" best-proj "+((float)bestProj/cnt)); } //pipe.close(); depWriter.finishWriting(); long end = System.currentTimeMillis(); DB.println("rank\n" + rank + "\n"); DB.println("x-best-las " + amongxbest + " x-best-ula " + amongxbest_ula + " cnt " + cnt + " x-best-las " + ((float) ((float) amongxbest / cnt)) + " x-best-ula " + ((float) ((float) amongxbest_ula / cnt)) + " nbest " + ((float) nbest / cnt) + " 1best " + ((float) (rank.get(0) == null ? 0 : rank.get(0)) / cnt) + " best-proj " + ((float) bestProj / cnt) + " Sum LAS " + ((float) this.smallestErrorSum / countAllNodes)); // DB.println("errors "+error); rank.clear(); amongxbest = 0; amongxbest_ula = 0; nbest = 0; bestProj = 0; if (maxInfo) { is2.parser.Parser.out.println("Used time " + (end - start)); } if (maxInfo) { is2.parser.Parser.out.println("forms count " + Instances.m_count + " unkown " + Instances.m_unkown); } } /** * Do the parsing * * @param options * @param pipe * @param params * @throws IOException */ private void getNBest(OptionsSuper options, Pipe pipe, ParametersFloat params, boolean maxInfo) throws Exception { CONLLReader09 depReader = new CONLLReader09(options.testfile, options.formatTask); // ExtractorClusterStacked.initFeatures(); int cnt = 0; String[] types; types = new String[pipe.mf.getFeatureCounter().get(PipeGen.REL)]; for (Entry<String, Integer> e : MFB.getFeatureSet().get(PipeGen.REL).entrySet()) { types[e.getValue()] = e.getKey(); } // Parser.out.print("Processing Sentence: "); while (true) { Instances is = new Instances(); is.init(1, new MFB(), options.formatTask); SentenceData09 instance = pipe.nextInstance(is, depReader); if (instance == null) { break; } cnt++; this.parseNBest(instance); } //pipe.close(); // depWriter.finishWriting(); // long end = System.currentTimeMillis(); // DB.println("rank\n"+rank+"\n"); // DB.println("x-best-las "+amongxbest+" x-best-ula "+amongxbest_ula+" cnt "+cnt+" x-best-las " // +((float)((float)amongxbest/cnt))+ // " x-best-ula "+((float)((float)amongxbest_ula/cnt))+ // " nbest "+((float)nbest/cnt)+ // " 1best "+((float)(rank.get(0)==null?0:rank.get(0))/cnt)+ // " best-proj "+((float)bestProj/cnt)); // DB.println("errors "+error); } public SentenceData09 parse(SentenceData09 instance, ParametersFloat params) throws IOException { String[] types = new String[pipe.mf.getFeatureCounter().get(PipeGen.REL)]; for (Entry<String, Integer> e : MFB.getFeatureSet().get(PipeGen.REL).entrySet()) { types[e.getValue()] = e.getKey(); } Instances is = new Instances(); is.init(1, new MFB(), options.formatTask); new CONLLReader09().insert(is, instance); String[] forms = instance.forms; // use for the training ppos DataF d2; try { d2 = pipe.fillVector(params.getFV(), is, 0, null, pipe.cl, THREADS, l2i);//cnt-1 } catch (Exception e) { e.printStackTrace(); return null; } short[] pos = is.pposs[0]; List<ParseNBest> parses = null; Parse d = null; try { parses = Decoder.decode(pos, d2, options.decodeProjective, pipe.extractor[0]); //cnt-1 d = parses.get(0); } catch (Exception e) { e.printStackTrace(); } if (parses.size() > NBest) { parses = parses.subList(0, NBest); } int g_las = Decoder.getGoldRank(parses, is, 0, Decoder.LAS); int g_ula = Decoder.getGoldRank(parses, is, 0, !Decoder.LAS); int smallest = Decoder.getSmallestError(parses, is, 0, !Decoder.LAS); smallestErrorSum += is.length(0) - smallest; countAllNodes += is.length(0); if (g_las >= 0) { amongxbest++; } if (g_ula >= 0) { amongxbest_ula++; } nbest += parses.size(); Integer r = rank.get(g_las); if (r == null) { rank.put(g_las, 1); } else { rank.put(g_las, r + 1); } float err = (float) this.pipe.errors(is, 0, d); float errBestProj = (float) this.pipe.errors(is, 0, Decoder.bestProj); if (errBestProj == 0) { bestProj++; } SentenceData09 i09 = new SentenceData09(instance); i09.createSemantic(instance); for (int j = 0; j < forms.length - 1; j++) { i09.plabels[j] = types[d.labels[j + 1]]; i09.pheads[j] = d.heads[j + 1]; } return i09; } public List<ParseNBest> parseNBest(SentenceData09 instance) throws IOException { Instances is = new Instances(); is.init(1, new MFB(), options.formatTask); new CONLLReader09().insert(is, instance); // use for the training ppos DataF d2; try { d2 = pipe.fillVector(params.getFV(), is, 0, null, pipe.cl, THREADS, l2i);//cnt-1 } catch (Exception e) { e.printStackTrace(); return null; } short[] pos = is.pposs[0]; List<ParseNBest> parses = null; try { parses = Decoder.decode(pos, d2, options.decodeProjective, pipe.extractor[0]); //cnt-1 } catch (Exception e) { e.printStackTrace(); } if (parses.size() > NBest) { parses = parses.subList(0, NBest); } return parses; } /* * (non-Javadoc) @see is2.tools.Tool#apply(is2.data.SentenceData09) */ @Override public SentenceData09 apply(SentenceData09 snt09) { try { parse(snt09, this.params); } catch (Exception e) { e.printStackTrace(); } Decoder.executerService.shutdown(); Pipe.executerService.shutdown(); return snt09; } /** * Write the parsing model * * @param options * @param params * @param extension * @throws FileNotFoundException * @throws IOException */ private void writeModell(OptionsSuper options, ParametersFloat params, String extension, Cluster cs) throws FileNotFoundException, IOException { String name = extension == null ? options.modelName : options.modelName + extension; // Parser.out.println("Writting model: "+name); ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(name))); zos.putNextEntry(new ZipEntry("data")); try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(zos))) { MFB.writeData(dos); cs.write(dos); params.write(dos); Edges.write(dos); dos.writeBoolean(options.decodeProjective); dos.writeInt(pipe.extractor[0].getMaxForm()); dos.writeInt(5); // Info count dos.writeUTF("Used parser " + Parser.class.toString()); dos.writeUTF("Creation date " + (new SimpleDateFormat("yyyy.MM.dd HH:mm:ss")).format(new Date())); dos.writeUTF("Training data " + options.trainfile); dos.writeUTF("Iterations " + options.numIters + " Used sentences " + options.count); dos.writeUTF("Cluster " + options.clusterFile); dos.flush(); } } }