package is2.parserR2; import decoder.ParallelDecoder; import decoder.ParallelRearrangeNBest; import decoder.ParallelRearrangeNBest2; import extractors.Extractor; import is2.data.*; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.concurrent.ExecutorService; /** * @author Bernd Bohnet, 01.09.2009 * * This methods do the actual work and they build the dependency trees. */ final public class Decoder { public static final boolean TRAINING = true; public static long timeDecotder; public static long timeRearrange; public static final boolean LAS = true; /** * Threshold for rearrange edges non-projective */ public static float NON_PROJECTIVITY_THRESHOLD = 0.3F; public static ExecutorService executerService = java.util.concurrent.Executors.newFixedThreadPool(Parser.THREADS); // do not initialize private Decoder() { } /** * Build a dependency tree based on the data * * @param pos part-of-speech tags * @param x the data * @param projective projective or non-projective * @param edges the edges * @return a parse tree * @throws InterruptedException */ public static List<ParseNBest> decode(short[] pos, DataF x, boolean projective, Extractor extractor) throws InterruptedException { long ts = System.nanoTime(); if (executerService.isShutdown()) { executerService = java.util.concurrent.Executors.newCachedThreadPool(); } final int n = pos.length; final Open O[][][][] = new Open[n][n][2][]; final Closed C[][][][] = new Closed[n][n][2][]; ArrayList<ParallelDecoder> pe = new ArrayList<>(); for (int i = 0; i < Parser.THREADS; i++) { pe.add(new ParallelDecoder(pos, x, O, C, n)); } for (short k = 1; k < n; k++) { // provide the threads the data for (short s = 0; s < n; s++) { short t = (short) (s + k); if (t >= n) { break; } ParallelDecoder.add(s, t); } executerService.invokeAll(pe); } double bestSpanScore = (-1.0F / 0.0F); Closed bestSpan = null; for (int m = 1; m < n; m++) { if (C[0][n - 1][1][m].p > bestSpanScore) { bestSpanScore = C[0][n - 1][1][m].p; bestSpan = C[0][n - 1][1][m]; } } // build the dependency tree from the chart ParseNBest out = new ParseNBest(pos.length); bestSpan.create(out); out.heads[0] = -1; out.labels[0] = 0; bestProj = out; timeDecotder += (System.nanoTime() - ts); // DB.println(""+out); ts = System.nanoTime(); List<ParseNBest> parses; if (!projective) { // if (training) // rearrange(pos, out.heads, out.types,x,training); //else { // DB.println("bestSpan score "+(float)bestSpan.p+" comp score "+Extractor.encode3(pos, out.heads, out.types, x)); // System.out.println(); // Parse best = new Parse(out.heads,out.types,Extractor.encode3(pos, out.heads, out.types, x)); parses = rearrangeNBest(pos, out.heads, out.labels, x, extractor); // DB.println("1best "+parses.get(0).f1); // DB.println(""+parses.get(0).toString()); // for(ParseNBest p :parses) if (p.heads==null) p.signature2parse(p.signature()); /// if (parses.get(0).f1>(best.f1+NON_PROJECTIVITY_THRESHOLD)) out = parses.get(0); // else out =best; // } } else { parses = new ArrayList<>(); parses.add(out); } timeRearrange += (System.nanoTime() - ts); return parses; } static Parse bestProj = null; /** * This is the parallel non-projective edge re-arranger * * @param pos part-of-speech tags * @param heads parent child relation * @param labs edge labels * @param x the data * @param edges the existing edges defined by part-of-speech tags * @throws InterruptedException */ public static List<ParseNBest> rearrangeNBestP(short[] pos, short[] heads, short[] labs, DataF x, Extractor extractor) throws InterruptedException { ArrayList<ParallelRearrangeNBest2> pe = new ArrayList<>(); int round = 0; ArrayList<ParseNBest> parses = new ArrayList<>(); ParseNBest px = new ParseNBest(); px.signature(heads, labs); //Object extractor; px.f1 = extractor.encode3(pos, heads, labs, x); parses.add(px); float lastNBest = Float.NEGATIVE_INFINITY; HashSet<Parse> done = new HashSet<>(); gnu.trove.THashSet<CharSequence> contained = new gnu.trove.THashSet<>(); while (true) { pe.clear(); // used the first three parses int ic = 0, considered = 0; while (true) { if (parses.size() <= ic || considered > 11) { break; } ParseNBest parse = parses.get(ic); ic++; // parse already extended if (done.contains(parse)) { continue; } considered++; parse.signature2parse(parse.signature()); done.add(parse); boolean[][] isChild = new boolean[heads.length][heads.length]; for (int i = 1, l1 = 1; i < heads.length; i++, l1 = i) { while ((l1 = heads[l1]) != -1) { isChild[l1][i] = true; } } // check the list of new possible parents and children for a better combination for (short ch = 1; ch < heads.length; ch++) { for (short pa = 0; pa < heads.length; pa++) { if (ch == pa || pa == heads[ch] || isChild[ch][pa]) { continue; } ParallelRearrangeNBest2.add(parse.clone(), ch, pa); } } } for (int t = 0; t < Parser.THREADS; t++) { pe.add(new ParallelRearrangeNBest2(pos, x, lastNBest, extractor, NON_PROJECTIVITY_THRESHOLD)); } executerService.invokeAll(pe); // avoid to add parses several times for (ParallelRearrangeNBest2 rp : pe) { for (int k = rp.parses.size() - 1; k >= 0; k--) { if (lastNBest > rp.parses.get(k).f1) { continue; } CharSequence sig = rp.parses.get(k).signature(); if (!contained.contains(sig)) { parses.add(rp.parses.get(k)); contained.add(sig); } } } Collections.sort(parses); if (round >= 2) { break; } round++; // do not use to much memory if (parses.size() > Parser.NBest) { // if (parses.get(Parser.NBest).f1>lastNBest) lastNBest = (float)parses.get(Parser.NBest).f1; parses.subList(Parser.NBest, parses.size() - 1).clear(); } } return parses; } /** * This is the parallel non-projective edge re-arranger * * @param pos part-of-speech tags * @param heads parent child relation * @param labs edge labels * @param x the data * @param edges the existing edges defined by part-of-speech tags * @throws InterruptedException */ public static List<ParseNBest> rearrangeNBest(short[] pos, short[] heads, short[] labs, DataF x, Extractor extractor) throws InterruptedException { ArrayList<ParallelRearrangeNBest> pe = new ArrayList<>(); int round = 0; ArrayList<ParseNBest> parses = new ArrayList<>(); ParseNBest px = new ParseNBest(); px.signature(heads, labs); //Object extractor; px.f1 = extractor.encode3(pos, heads, labs, x); parses.add(px); float lastNBest = Float.NEGATIVE_INFINITY; HashSet<Parse> done = new HashSet<>(); gnu.trove.THashSet<CharSequence> contained = new gnu.trove.THashSet<>(); while (true) { pe.clear(); // used the first three parses int i = 0; while (true) { if (parses.size() <= i || pe.size() > 12) { break; } ParseNBest parse = parses.get(i); i++; // parse already extended if (done.contains(parse)) { continue; } // DB.println("err "+parse.heads); parse.signature2parse(parse.signature()); done.add(parse); pe.add(new ParallelRearrangeNBest(pos, x, parse, lastNBest, extractor, (float) parse.f1, NON_PROJECTIVITY_THRESHOLD)); } executerService.invokeAll(pe); // avoid to add parses several times for (ParallelRearrangeNBest rp : pe) { for (int k = rp.parses.size() - 1; k >= 0; k--) { if (lastNBest > rp.parses.get(k).f1) { continue; } CharSequence sig = rp.parses.get(k).signature(); if (!contained.contains(sig)) { parses.add(rp.parses.get(k)); contained.add(sig); } } } Collections.sort(parses); if (round >= 2) { break; } round++; // do not use to much memory if (parses.size() > Parser.NBest) { if (parses.get(Parser.NBest).f1 > lastNBest) { lastNBest = (float) parses.get(Parser.NBest).f1; } parses.subList(Parser.NBest, parses.size() - 1).clear(); } } return parses; } public static String getInfo() { return "Decoder non-projectivity threshold: " + NON_PROJECTIVITY_THRESHOLD; } /** * @param parses * @param is * @param i * @return */ public static int getGoldRank(List<ParseNBest> parses, Instances is, int i, boolean las) { for (int p = 0; p < parses.size(); p++) { if (parses.get(p).heads == null) { parses.get(p).signature2parse(parses.get(p).signature()); } boolean eq = true; for (int w = 1; w < is.length(0); w++) { if (is.heads[i][w] != parses.get(p).heads[w] || (is.labels[i][w] != parses.get(p).labels[w] && las)) { eq = false; break; } } if (eq) { return p; } } return -1; } public static int getSmallestError(List<ParseNBest> parses, Instances is, int i, boolean las) { int smallest = -1; for (int p = 0; p < parses.size(); p++) { int err = 0; for (int w = 1; w < is.length(0); w++) { if (is.heads[i][w] != parses.get(p).heads[w] || (is.labels[i][w] != parses.get(p).labels[w] && las)) { err++; } } if (smallest == -1 || smallest > err) { smallest = err; } if (smallest == 0) { return 0; } } return smallest; } public static int getError(ParseNBest parse, Instances is, int i, boolean las) { int err = 0; for (int w = 1; w < is.length(i); w++) { if (is.heads[i][w] != parse.heads[w] || (is.labels[i][w] != parse.labels[w] && las)) { err++; } } return err; } }