package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.HasFeatureGrouping;
import edu.stanford.nlp.util.concurrent.*;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * @author Jenny Finkel
 *         Mengqiu Wang
 */
public class CRFLogConditionalObjectiveFunction extends AbstractStochasticCachingDiffUpdateFunction
    implements HasCliquePotentialFunction, HasFeatureGrouping {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunction.class);

  public static final int NO_PRIOR = 0;
  public static final int QUADRATIC_PRIOR = 1;
  /* Use a Huber robust regression penalty (L1 except very near 0) not L2 */
  public static final int HUBER_PRIOR = 2;
  public static final int QUARTIC_PRIOR = 3;
  public static final int DROPOUT_PRIOR = 4;

  // public static final boolean DEBUG2 = true;
  public static final boolean DEBUG2 = false;
  public static final boolean DEBUG3 = false;
  public static final boolean TIMED = false;
  // public static final boolean TIMED = true;
  public static final boolean CONDENSE = true;
  // public static final boolean CONDENSE = false;
  public static boolean VERBOSE = false;

  protected final int prior;
  protected final double sigma;
  protected final double epsilon = 0.1; // You can't actually set this at present

  /** label indices - for all possible label sequences - for each feature */
  protected final List<Index<CRFLabel>> labelIndices;
  protected final Index<String> classIndex;  // didn't have <String> before. Added since that's what is assumed everywhere.
  protected final double[][] Ehat; // empirical counts of all the features [feature][class]
  protected final double[][] E;
  protected double[][][] parallelE;
  protected double[][][] parallelEhat;

  protected final int window;
  protected final int numClasses;
  // public static Index<String> featureIndex;  // no idea why this was here [cdm 2013]
  protected final int[] map;
  protected int[][][][] data;  // data[docIndex][tokenIndex][][]
  protected double[][][][] featureVal;  // featureVal[docIndex][tokenIndex][][]
  protected int[][] labels;  // labels[docIndex][tokenIndex]
  protected final int domainDimension;
  // protected double[][] eHat4Update, e4Update;

  protected int[][] weightIndices;
  protected final String backgroundSymbol;

  protected int[][] featureGrouping = null;

  protected static final double smallConst = 1e-6;
  // protected static final double largeConst = 5;

  protected Random rand = new Random(2147483647L);

  protected final int multiThreadGrad;
  // need to ensure the following two objects are only read during multi-threading
  // to ensure thread-safety. It should only be modified in calculate() via setWeights()
  protected double[][] weights;
  protected CliquePotentialFunction cliquePotentialFunc;

  @Override
  public double[] initial() {
    return initial(rand);
  }

  public double[] initial(boolean useRandomSeed) {
    Random randToUse = useRandomSeed ? new Random() : rand;
    return initial(randToUse);
  }

  public double[] initial(Random randGen) {
    double[] initial = new double[domainDimension()];
    for (int i = 0; i < initial.length; i++) {
      initial[i] = randGen.nextDouble() + smallConst;
      // initial[i] = generator.nextDouble() * largeConst;
      // initial[i] = -1+2*(i);
      // initial[i] = (i == 0 ? 1 : 0);
    }
    return initial;
  }

  public static int getPriorType(String priorTypeStr) {
    if (priorTypeStr == null) return QUADRATIC_PRIOR;  // default
    if ("QUADRATIC".equalsIgnoreCase(priorTypeStr)) {
      return QUADRATIC_PRIOR;
    } else if ("HUBER".equalsIgnoreCase(priorTypeStr)) {
      return HUBER_PRIOR;
    } else if ("QUARTIC".equalsIgnoreCase(priorTypeStr)) {
      return QUARTIC_PRIOR;
    } else if ("DROPOUT".equalsIgnoreCase(priorTypeStr)) {
      return DROPOUT_PRIOR;
    } else if ("NONE".equalsIgnoreCase(priorTypeStr)) {
      return NO_PRIOR;
    } else if (priorTypeStr.equalsIgnoreCase("lasso") ||
               priorTypeStr.equalsIgnoreCase("ridge") ||
               priorTypeStr.equalsIgnoreCase("gaussian") ||
               priorTypeStr.equalsIgnoreCase("ae-lasso") ||
               priorTypeStr.equalsIgnoreCase("sg-lasso") ||
               priorTypeStr.equalsIgnoreCase("g-lasso")) {
      return NO_PRIOR;
    } else {
      throw new IllegalArgumentException("Unknown prior type: " + priorTypeStr);
    }
  }
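  // Illustrative note: getPriorType maps configuration strings to the constants
  // above, e.g. getPriorType("huber") == HUBER_PRIOR and getPriorType(null) ==
  // QUADRATIC_PRIOR (the default). Regularizers such as "lasso" or "gaussian"
  // are handled by other machinery, so they deliberately map to NO_PRIOR here.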
  CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex,
                                     List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol,
                                     double sigma, double[][][][] featureVal, int multiThreadGrad) {
    this(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad, true);
  }

  CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex,
                                     List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol,
                                     double sigma, double[][][][] featureVal, int multiThreadGrad, boolean calcEmpirical) {
    this.window = window;
    this.classIndex = classIndex;
    this.numClasses = classIndex.size();
    this.labelIndices = labelIndices;
    this.map = map;
    this.data = data;
    this.featureVal = featureVal;
    this.labels = labels;
    this.prior = getPriorType(priorType);
    this.backgroundSymbol = backgroundSymbol;
    this.sigma = sigma;
    this.multiThreadGrad = multiThreadGrad;
    // takes docIndex, returns Triple<prob, E, dropoutGrad>
    Ehat = empty2D();
    E = empty2D();
    weights = empty2D();
    if (calcEmpirical) {
      empiricalCounts(Ehat);
    }
    int myDomainDimension = 0;
    for (int dim : map) {
      myDomainDimension += labelIndices.get(dim).size();
    }
    domainDimension = myDomainDimension;
  }

  protected void empiricalCounts(double[][] eHat) {
    for (int m = 0; m < data.length; m++) {
      empiricalCountsForADoc(eHat, m);
    }
  }

  protected void empiricalCountsForADoc(double[][] eHat, int docIndex) {
    int[][][] docData = data[docIndex];
    int[] docLabels = labels[docIndex];
    int[] windowLabels = new int[window];
    Arrays.fill(windowLabels, classIndex.indexOf(backgroundSymbol));
    double[][][] featureValArr = null;
    if (featureVal != null) {
      featureValArr = featureVal[docIndex];
    }

    if (docLabels.length > docData.length) { // only true for self-training
      // fill the windowLabel array with the extra docLabels
      System.arraycopy(docLabels, 0, windowLabels, 0, windowLabels.length);
      // shift the docLabels array left
      int[] newDocLabels = new int[docData.length];
      System.arraycopy(docLabels, docLabels.length - newDocLabels.length, newDocLabels, 0, newDocLabels.length);
      docLabels = newDocLabels;
    }
    for (int i = 0; i < docData.length; i++) {
      System.arraycopy(windowLabels, 1, windowLabels, 0, window - 1);
      windowLabels[window - 1] = docLabels[i];
      for (int j = 0; j < docData[i].length; j++) {
        int[] cliqueLabel = new int[j + 1];
        System.arraycopy(windowLabels, window - 1 - j, cliqueLabel, 0, j + 1);
        CRFLabel crfLabel = new CRFLabel(cliqueLabel);
        int labelIndex = labelIndices.get(j).indexOf(crfLabel);
        // log.info(crfLabel + " " + labelIndex);
        for (int n = 0; n < docData[i][j].length; n++) {
          double fVal = 1.0;
          if (featureValArr != null && j == 0) { // j == 0 because only node features get feature values
            fVal = featureValArr[i][j][n];
          }
          eHat[docData[i][j][n]][labelIndex] += fVal;
        }
      }
    }
  }
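  // Illustrative sketch of the empirical counts: at each document position i
  // with gold clique labeling indexed by labelIndex, every active feature f in
  // docData[i][j] contributes its value (1.0 for indicator features) to
  // Ehat[f][labelIndex]. For instance, if feature 7 fires at three positions
  // whose gold node label has index 2, then Ehat[7][2] == 3.0 afterwards.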
  @Override
  public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
    to2D(x, weights);
    return new LinearCliquePotentialFunction(weights);
  }

  protected double expectedAndEmpiricalCountsAndValueForADoc(double[][] E, double[][] Ehat, int docIndex) {
    empiricalCountsForADoc(Ehat, docIndex);
    return expectedCountsAndValueForADoc(E, docIndex);
  }

  public double valueForADoc(int docIndex) {
    return expectedCountsAndValueForADoc(null, docIndex, false, true);
  }

  protected double expectedCountsAndValueForADoc(double[][] E, int docIndex) {
    return expectedCountsAndValueForADoc(E, docIndex, true, true);
  }

  protected double expectedCountsForADoc(double[][] E, int docIndex) {
    return expectedCountsAndValueForADoc(E, docIndex, true, false);
  }

  protected double expectedCountsAndValueForADoc(double[][] E, int docIndex, boolean doExpectedCountCalc, boolean doValueCalc) {
    int[][][] docData = data[docIndex];
    double[][][] featureVal3DArr = null;
    if (featureVal != null) {
      featureVal3DArr = featureVal[docIndex];
    }
    // make a clique tree for this document
    CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex,
        backgroundSymbol, cliquePotentialFunc, featureVal3DArr);

    double prob = 0.0;
    if (doValueCalc) {
      prob = documentLogProbability(docData, docIndex, cliqueTree);
    }
    if (doExpectedCountCalc) {
      documentExpectedCounts(E, docData, featureVal3DArr, cliqueTree);
    }
    return prob;
  }

  /** Compute the expected counts for this document, which we will need to compute the derivative. */
  protected void documentExpectedCounts(double[][] E, int[][][] docData, double[][][] featureVal3DArr, CRFCliqueTree cliqueTree) {
    // iterate over the positions in this document
    for (int i = 0; i < docData.length; i++) {
      // for each possible clique at this position
      for (int j = 0; j < docData[i].length; j++) {
        Index<CRFLabel> labelIndex = labelIndices.get(j);
        // for each possible labeling for that clique
        for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
          int[] label = labelIndex.get(k).getLabel();
          double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features
          for (int n = 0; n < docData[i][j].length; n++) {
            double fVal = 1.0;
            if (j == 0 && featureVal3DArr != null) { // j == 0 because only node features get feature values
              fVal = featureVal3DArr[i][j][n];
            }
            E[docData[i][j][n]][k] += p * fVal;
          }
        }
      }
    }
  }
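  // Sketch of the gradient this sets up: after documentExpectedCounts has run
  // over all documents, E[f][k] holds the model's expected count of feature f
  // under clique labeling k, while Ehat[f][k] holds the gold count. Since we
  // minimize the negative log-likelihood, the derivative for that weight is
  // E[f][k] - Ehat[f][k] (see calculate() below).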
  /** Compute the log probability of the document given the model with the parameters x. */
  private double documentLogProbability(int[][][] docData, int docIndex, CRFCliqueTree cliqueTree) {
    int[] docLabels = labels[docIndex];
    int[] given = new int[window - 1];
    Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
    if (docLabels.length > docData.length) { // only true for self-training
      // fill the given array with the extra docLabels
      System.arraycopy(docLabels, 0, given, 0, given.length);
      // shift the docLabels array left
      int[] newDocLabels = new int[docData.length];
      System.arraycopy(docLabels, docLabels.length - newDocLabels.length, newDocLabels, 0, newDocLabels.length);
      docLabels = newDocLabels;
    }

    double startPosLogProb = cliqueTree.logProbStartPos();
    if (VERBOSE) {
      System.err.printf("P_-1(Background) = % 5.3f%n", startPosLogProb);
    }
    double prob = startPosLogProb;

    // iterate over the positions in this document
    for (int i = 0; i < docData.length; i++) {
      int label = docLabels[i];
      double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
      if (VERBOSE) {
        log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
      }
      prob += p;
      System.arraycopy(given, 1, given, 0, given.length - 1);
      given[given.length - 1] = label;
    }
    return prob;
  }

  private ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> expectedThreadProcessor =
      new ExpectationThreadsafeProcessor();
  private ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> expectedAndEmpiricalThreadProcessor =
      new ExpectationThreadsafeProcessor(true);

  class ExpectationThreadsafeProcessor implements ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> {
    boolean calculateEmpirical = false;

    public ExpectationThreadsafeProcessor() { }

    public ExpectationThreadsafeProcessor(boolean calculateEmpirical) {
      this.calculateEmpirical = calculateEmpirical;
    }

    @Override
    public Pair<Integer, Double> process(Pair<Integer, List<Integer>> threadIDAndDocIndices) {
      int tID = threadIDAndDocIndices.first();
      if (tID < 0 || tID >= multiThreadGrad) {
        throw new IllegalArgumentException("threadID must be within range 0 <= tID < multiThreadGrad(=" + multiThreadGrad + ")");
      }
      List<Integer> docIDs = threadIDAndDocIndices.second();
      double[][] partE; // initialized below
      double[][] partEhat = null; // initialized below
      if (multiThreadGrad == 1) {
        partE = E;
        if (calculateEmpirical) {
          partEhat = Ehat;
        }
      } else {
        partE = parallelE[tID];
        // TODO: if we put this on the heap, this clearing will be unnecessary
        clear2D(partE);
        if (calculateEmpirical) {
          partEhat = parallelEhat[tID];
          clear2D(partEhat);
        }
      }
      double probSum = 0;
      for (int docIndex : docIDs) {
        if (calculateEmpirical) {
          probSum += expectedAndEmpiricalCountsAndValueForADoc(partE, partEhat, docIndex);
        } else {
          probSum += expectedCountsAndValueForADoc(partE, docIndex);
        }
      }
      return new Pair<>(tID, probSum);
    }

    @Override
    public ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> newInstance() {
      return this;
    }
  }

  public void setWeights(double[][] weights) {
    this.weights = weights;
    cliquePotentialFunc = new LinearCliquePotentialFunction(weights);
  }

  protected double regularGradientAndValue() {
    int totalLen = data.length;
    List<Integer> docIDs = new ArrayList<>(totalLen);
    for (int m = 0; m < totalLen; m++) {
      docIDs.add(m);
    }
    return multiThreadGradient(docIDs, false);
  }
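  // Illustrative partitioning in multiThreadGradient below: with 10 documents
  // and multiThreadGrad == 3, partLen == 10 / 3 == 3, so thread 0 gets docs
  // [0, 3), thread 1 gets [3, 6), and the last thread absorbs the remainder,
  // [6, 10). Each thread accumulates into its own parallelE slice, which is
  // then summed into E via combine2DArr.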
  protected double multiThreadGradient(List<Integer> docIDs, boolean calculateEmpirical) {
    double objective = 0.0;
    // TODO: This is a bunch of unnecessary heap traffic, should all be on the stack
    if (multiThreadGrad > 1) {
      if (parallelE == null) {
        parallelE = new double[multiThreadGrad][][];
        for (int i = 0; i < multiThreadGrad; i++) {
          parallelE[i] = empty2D();
        }
      }
      if (calculateEmpirical) {
        if (parallelEhat == null) {
          parallelEhat = new double[multiThreadGrad][][];
          for (int i = 0; i < multiThreadGrad; i++) {
            parallelEhat[i] = empty2D();
          }
        }
      }
    }

    // TODO: this is a huge amount of machinery for no discernible reason
    MulticoreWrapper<Pair<Integer, List<Integer>>, Pair<Integer, Double>> wrapper =
        new MulticoreWrapper<>(multiThreadGrad, (calculateEmpirical ? expectedAndEmpiricalThreadProcessor : expectedThreadProcessor));

    int totalLen = docIDs.size();
    int partLen = totalLen / multiThreadGrad;
    int currIndex = 0;
    for (int part = 0; part < multiThreadGrad; part++) {
      int endIndex = currIndex + partLen;
      if (part == multiThreadGrad - 1) {
        endIndex = totalLen;
      }
      // TODO: let's not construct a sub-list of DocIDs, unnecessary object creation, can calculate directly from ThreadID
      List<Integer> subList = docIDs.subList(currIndex, endIndex);
      wrapper.put(new Pair<>(part, subList));
      currIndex = endIndex;
    }
    wrapper.join();
    // This all seems fine. May want to start running this after the joins, in case we have different end-times
    while (wrapper.peek()) {
      Pair<Integer, Double> result = wrapper.poll();
      int tID = result.first();
      objective += result.second();
      if (multiThreadGrad > 1) {
        combine2DArr(E, parallelE[tID]);
        if (calculateEmpirical) {
          combine2DArr(Ehat, parallelEhat[tID]);
        }
      }
    }
    return objective;
  }

  /**
   * Calculates both value and partial derivatives at the point x, and saves them internally.
   */
  @Override
  public void calculate(double[] x) {
    // final double[][] weights = to2D(x);
    to2D(x, weights);
    setWeights(weights);
    // the expectations over counts
    // first index is feature index, second index is of possible labeling
    // double[][] E = empty2D();
    clear2D(E);

    double prob = regularGradientAndValue(); // the log prob of the sequence given the model, which is the negation of value at this point

    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()" +
          " - this may well indicate numeric underflow due to overly long documents.");
    }

    // because we minimize -L(\theta)
    value = -prob;
    if (VERBOSE) {
      log.info("value is " + Math.exp(-value));
    }

    // compute the partial derivative for each feature by comparing expected counts to empirical counts
    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // because we minimize -L(\theta)
        derivative[index] = (E[i][j] - Ehat[i][j]);
        if (VERBOSE) {
          log.info("deriv(" + i + "," + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index]);
        }
        index++;
      }
    }

    applyPrior(x, 1.0);
    // log.info("\nfuncVal: " + value);
  }

  @Override
  public int dataDimension() {
    return data.length;
  }
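  // Hypothetical usage sketch (not part of this class): a batch optimizer
  // drives this function roughly as follows, for func an instance of this class:
  //
  //   double[] w = func.initial(); // random point in (smallConst, 1 + smallConst)
  //   func.calculate(w);           // fills value (= -log L) and derivative (= E - Ehat, plus prior terms)
  //   // the minimizer then steps w against the derivative and repeats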
  @Override
  public void calculateStochastic(double[] x, double[] v, int[] batch) {
    to2D(x, weights);
    setWeights(weights);

    double batchScale = ((double) batch.length) / ((double) this.dataDimension());

    // the expectations over counts
    // first index is feature index, second index is of possible labeling
    // double[][] E = empty2D();

    // iterate over all the documents
    List<Integer> docIDs = new ArrayList<>(batch.length);
    for (int item : batch) {
      docIDs.add(item);
    }
    double prob = multiThreadGradient(docIDs, false); // the log prob of the sequence given the model, which is the negation of value at this point

    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
    }
    value = -prob;

    // compute the partial derivative for each feature by comparing expected counts to empirical counts
    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // real gradient should be empirical-expected;
        // but since we minimize -L(\theta), the gradient is -(empirical-expected)
        derivative[index++] = (E[i][j] - batchScale * Ehat[i][j]);
        if (VERBOSE) {
          log.info("deriv(" + i + "," + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index - 1]);
        }
      }
    }

    applyPrior(x, batchScale);
  }

  // re-initialization is faster than Arrays.fill(arr, 0)
  // private void clearUpdateEs() {
  //   for (int i = 0; i < eHat4Update.length; i++)
  //     eHat4Update[i] = new double[eHat4Update[i].length];
  //   for (int i = 0; i < e4Update.length; i++)
  //     e4Update[i] = new double[e4Update[i].length];
  // }

  /**
   * Performs stochastic update of weights x (scaled by xScale) based
   * on samples indexed by batch.
   * NOTE: This function does not do regularization (regularization is done by the minimizer).
   *
   * @param x - unscaled weights
   * @param xScale - how much to scale x by when performing calculations
   * @param batch - indices of which samples to compute function over
   * @param gScale - how much to scale adjustments to x
   * @return value of function at specified x (scaled by xScale) for samples
   */
  @Override
  public double calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gScale) {
    // int[][] wis = getWeightIndices();
    to2D(x, xScale, weights);
    setWeights(weights);

    // if (eHat4Update == null) {
    //   eHat4Update = empty2D();
    //   e4Update = new double[eHat4Update.length][];
    //   for (int i = 0; i < e4Update.length; i++)
    //     e4Update[i] = new double[eHat4Update[i].length];
    // } else {
    //   clearUpdateEs();
    // }

    // Adjust weight by -gScale*gradient
    // gradient is expected count - empirical count
    // so we adjust by + gScale(empirical count - expected count)

    // iterate over all the documents
    List<Integer> docIDs = new ArrayList<>(batch.length);
    for (int item : batch) {
      docIDs.add(item);
    }
    double prob = multiThreadGradient(docIDs, true); // the log prob of the sequence given the model, which is the negation of value at this point

    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
    }
    value = -prob;

    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        x[index++] += (Ehat[i][j] - E[i][j]) * gScale;
      }
    }

    return value;
  }

  /**
   * Performs stochastic gradient update based
   * on samples indexed by batch, but does not apply regularization.
   *
   * @param x - unscaled weights
   * @param batch - indices of which samples to compute function over
   */
  @Override
  public void calculateStochasticGradient(double[] x, int[] batch) {
    if (derivative == null) {
      derivative = new double[domainDimension()];
    }
    // int[][] wis = getWeightIndices();
    // was: double[][] weights = to2D(x, 1.0); // but 1.0 should be the same as omitting 2nd parameter....
    to2D(x, weights);
    setWeights(weights);

    // iterate over all the documents
    List<Integer> docIDs = new ArrayList<>(batch.length);
    for (int item : batch) {
      docIDs.add(item);
    }
    multiThreadGradient(docIDs, true);

    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // real gradient should be empirical-expected;
        // but since we minimize -L(\theta), the gradient is -(empirical-expected)
        derivative[index++] = (E[i][j] - Ehat[i][j]);
      }
    }
  }
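  // Note on the update direction in calculateStochasticUpdate above: the
  // gradient of -L(\theta) is (expected - empirical), so gradient descent
  // moves each weight by +gScale * (Ehat[i][j] - E[i][j]), i.e. by
  // -gScale * gradient, which is exactly the in-place adjustment applied to x.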
  /**
   * Computes value of function for specified value of x (scaled by xScale)
   * only over samples indexed by batch.
   * NOTE: This function does not do regularization (regularization is done by the minimizer).
   *
   * @param x - unscaled weights
   * @param xScale - how much to scale x by when performing calculations
   * @param batch - indices of which samples to compute function over
   * @return value of function at specified x (scaled by xScale) for samples
   */
  @Override
  public double valueAt(double[] x, double xScale, int[] batch) {
    double prob = 0.0; // the log prob of the sequence given the model, which is the negation of value at this point
    // int[][] wis = getWeightIndices();
    to2D(x, xScale, weights);
    setWeights(weights);

    // iterate over all the documents
    for (int ind : batch) {
      prob += valueForADoc(ind);
    }

    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
    }

    value = -prob;
    return value;
  }

  @Override
  public int[][] getFeatureGrouping() {
    if (featureGrouping != null) {
      return featureGrouping;
    } else {
      int[][] fg = new int[1][];
      fg[0] = ArrayMath.range(0, domainDimension());
      return fg;
    }
  }

  public void setFeatureGrouping(int[][] fg) {
    this.featureGrouping = fg;
  }

  protected void applyPrior(double[] x, double batchScale) {
    // incorporate priors
    if (prior == QUADRATIC_PRIOR) {
      double lambda = 1 / (sigma * sigma);
      for (int i = 0; i < x.length; i++) {
        double w = x[i];
        value += batchScale * w * w * lambda * 0.5;
        derivative[i] += batchScale * w * lambda;
      }
    } else if (prior == HUBER_PRIOR) {
      double sigmaSq = sigma * sigma;
      for (int i = 0; i < x.length; i++) {
        double w = x[i];
        double wabs = Math.abs(w);
        if (wabs < epsilon) {
          value += batchScale * w * w / 2.0 / epsilon / sigmaSq;
          derivative[i] += batchScale * w / epsilon / sigmaSq;
        } else {
          value += batchScale * (wabs - epsilon / 2) / sigmaSq;
          derivative[i] += batchScale * ((w < 0.0) ? -1.0 : 1.0) / sigmaSq;
        }
      }
    } else if (prior == QUARTIC_PRIOR) {
      double sigmaQu = sigma * sigma * sigma * sigma;
      double lambda = 1 / 2.0 / sigmaQu;
      for (int i = 0; i < x.length; i++) {
        double w = x[i];
        value += batchScale * w * w * w * w * lambda;
        derivative[i] += batchScale * 2 * w * w * w / sigmaQu; // d/dw of w^4 / (2 sigma^4)
      }
    }
  }
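  // Worked example for the quadratic prior: with sigma == 1 (so lambda == 1)
  // and a weight w == 2.0, applyPrior adds batchScale * 2.0 to the value
  // (w*w*lambda*0.5 == 4 * 0.5) and batchScale * 2.0 (w*lambda) to that
  // weight's derivative. The Huber prior is quadratic for |w| < epsilon
  // (fixed at 0.1 here) and linear beyond it, so it penalizes large weights
  // less aggressively than L2.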
  protected Pair<double[][][], double[][][]> getCondProbs(CRFCliqueTree cTree, int[][][] docData) {
    // first index position is curr index, second index curr-class, third index prev-class
    // e.g. [1][2][3] means curr is at position 1 with class 2, prev is at position 0 with class 3
    double[][][] prevGivenCurr = new double[docData.length][][];
    // first index position is curr index, second index curr-class, third index next-class
    // e.g. [0][2][3] means curr is at position 0 with class 2, next is at position 1 with class 3
    double[][][] nextGivenCurr = new double[docData.length][][];

    for (int i = 0; i < docData.length; i++) {
      prevGivenCurr[i] = new double[numClasses][];
      nextGivenCurr[i] = new double[numClasses][];
      for (int j = 0; j < numClasses; j++) {
        prevGivenCurr[i][j] = new double[numClasses];
        nextGivenCurr[i][j] = new double[numClasses];
      }
    }

    // computing prevGivenCurr and nextGivenCurr
    for (int i = 0; i < docData.length; i++) {
      int[] labelPair = new int[2];
      for (int l1 = 0; l1 < numClasses; l1++) {
        labelPair[0] = l1;
        for (int l2 = 0; l2 < numClasses; l2++) {
          labelPair[1] = l2;
          double prob = cTree.logProb(i, labelPair);
          // log.info(prob);
          if (i - 1 >= 0) {
            nextGivenCurr[i - 1][l1][l2] = prob;
          }
          prevGivenCurr[i][l2][l1] = prob;
        }
      }

      if (DEBUG2) {
        log.info("unnormalized conditionals:");
        if (i > 0) {
          log.info("nextGivenCurr[" + (i - 1) + "]:");
          for (int a = 0; a < nextGivenCurr[i - 1].length; a++) {
            for (int b = 0; b < nextGivenCurr[i - 1][a].length; b++) {
              log.info((nextGivenCurr[i - 1][a][b]) + "\t");
            }
            log.info();
          }
        }
        log.info("prevGivenCurr[" + (i) + "]:");
        for (int a = 0; a < prevGivenCurr[i].length; a++) {
          for (int b = 0; b < prevGivenCurr[i][a].length; b++) {
            log.info((prevGivenCurr[i][a][b]) + "\t");
          }
          log.info();
        }
      }

      for (int j = 0; j < numClasses; j++) {
        if (i - 1 >= 0) {
          // ArrayMath.normalize(nextGivenCurr[i-1][j]);
          ArrayMath.logNormalize(nextGivenCurr[i - 1][j]);
          for (int k = 0; k < nextGivenCurr[i - 1][j].length; k++) {
            nextGivenCurr[i - 1][j][k] = Math.exp(nextGivenCurr[i - 1][j][k]);
          }
        }
        // ArrayMath.normalize(prevGivenCurr[i][j]);
        ArrayMath.logNormalize(prevGivenCurr[i][j]);
        for (int k = 0; k < prevGivenCurr[i][j].length; k++) {
          prevGivenCurr[i][j][k] = Math.exp(prevGivenCurr[i][j][k]);
        }
      }

      if (DEBUG2) {
        log.info("normalized conditionals:");
        if (i > 0) {
          log.info("nextGivenCurr[" + (i - 1) + "]:");
          for (int a = 0; a < nextGivenCurr[i - 1].length; a++) {
            for (int b = 0; b < nextGivenCurr[i - 1][a].length; b++) {
              log.info((nextGivenCurr[i - 1][a][b]) + "\t");
            }
            log.info();
          }
        }
        log.info("prevGivenCurr[" + (i) + "]:");
        for (int a = 0; a < prevGivenCurr[i].length; a++) {
          for (int b = 0; b < prevGivenCurr[i][a].length; b++) {
            log.info((prevGivenCurr[i][a][b]) + "\t");
          }
          log.info();
        }
      }
    }
    return new Pair<>(prevGivenCurr, nextGivenCurr);
  }

  protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined, double scale) {
    for (int i = 0; i < toBeCombined.length; i++) {
      for (int j = 0; j < toBeCombined[i].length; j++) {
        combineInto[i][j] += toBeCombined[i][j] * scale;
      }
    }
  }

  protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined) {
    for (int i = 0; i < toBeCombined.length; i++) {
      for (int j = 0; j < toBeCombined[i].length; j++) {
        combineInto[i][j] += toBeCombined[i][j];
      }
    }
  }

  // TODO(mengqiu) add dimension checks
  protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined) {
    for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
      int key = entry.getKey();
      double[] source = entry.getValue();
      for (int i = 0; i < source.length; i++) {
        combineInto[key][i] += source[i];
      }
    }
  }

  protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined, double scale) {
    for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
      int key = entry.getKey();
      double[] source = entry.getValue();
      for (int i = 0; i < source.length; i++) {
        combineInto[key][i] += source[i] * scale;
      }
    }
  }
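  // Small illustration of combine2DArr: if combineInto[3] == {1.0, 2.0} and
  // toBeCombined[3] == {0.5, 0.5}, then after combine2DArr(combineInto,
  // toBeCombined) we have combineInto[3] == {1.5, 2.5}. multiThreadGradient
  // uses exactly this to merge each thread's parallelE slice into E.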
  // this used to be computed lazily, but that was clearly erroneous for multithreading!
  @Override
  public int domainDimension() {
    return domainDimension;
  }

  /**
   * Takes a double array of weights and creates a 2D array where:
   *
   * the first dimension is indexed by feature i, whose clique size (e.g., node = 0, edge = 1) is given by map[i];
   * the second dimension ranges over the possible labelings of that clique.
   *
   * @return a 2D weight array
   */
  public static double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
    double[][] newWeights = new double[map.length][];
    int index = 0;
    for (int i = 0; i < map.length; i++) {
      int labelSize = labelIndices.get(map[i]).size();
      newWeights[i] = new double[labelSize];
      try {
        System.arraycopy(weights, index, newWeights[i], 0, labelSize);
      } catch (Exception ex) {
        log.info("weights: " + Arrays.toString(weights));
        log.info("newWeights[" + i + "]: " + Arrays.toString(newWeights[i]));
        throw new RuntimeException(ex);
      }
      index += labelSize;
    }
    return newWeights;
  }

  public double[][] to2D(double[] weights) {
    return to2D(weights, this.labelIndices, this.map);
  }

  public static void to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map, double[][] newWeights) {
    int index = 0;
    for (int i = 0; i < map.length; i++) {
      int labelSize = labelIndices.get(map[i]).size();
      try {
        System.arraycopy(weights, index, newWeights[i], 0, labelSize);
      } catch (Exception ex) {
        log.info("weights: " + Arrays.toString(weights));
        log.info("newWeights[" + i + "]: " + Arrays.toString(newWeights[i]));
        throw new RuntimeException(ex);
      }
      index += labelSize;
    }
  }

  public void to2D(double[] weights1D, double[][] newWeights) {
    to2D(weights1D, this.labelIndices, this.map, newWeights);
  }

  /** Beware: this changes the input weights array in place. */
  public double[][] to2D(double[] weights1D, double wScale) {
    for (int i = 0; i < weights1D.length; i++) {
      weights1D[i] = weights1D[i] * wScale;
    }
    return to2D(weights1D, this.labelIndices, this.map);
  }

  /** Beware: this changes the input weights array in place. */
  public void to2D(double[] weights1D, double wScale, double[][] newWeights) {
    for (int i = 0; i < weights1D.length; i++) {
      weights1D[i] = weights1D[i] * wScale;
    }
    to2D(weights1D, this.labelIndices, this.map, newWeights);
  }

  public static void clear2D(double[][] arr2D) {
    for (int i = 0; i < arr2D.length; i++) {
      for (int j = 0; j < arr2D[i].length; j++) {
        arr2D[i][j] = 0.0;
      }
    }
  }

  public static void to1D(double[][] weights, double[] newWeights) {
    int index = 0;
    for (double[] weightVector : weights) {
      System.arraycopy(weightVector, 0, newWeights, index, weightVector.length);
      index += weightVector.length;
    }
  }

  public static double[] to1D(double[][] weights, int domainDimension) {
    double[] newWeights = new double[domainDimension];
    int index = 0;
    for (double[] weightVector : weights) {
      System.arraycopy(weightVector, 0, newWeights, index, weightVector.length);
      index += weightVector.length;
    }
    return newWeights;
  }

  public double[] to1D(double[][] weights) {
    return to1D(weights, domainDimension());
  }

  public int[][] getWeightIndices() {
    if (weightIndices == null) {
      weightIndices = new int[map.length][];
      int index = 0;
      for (int i = 0; i < map.length; i++) {
        weightIndices[i] = new int[labelIndices.get(map[i]).size()];
        for (int j = 0; j < labelIndices.get(map[i]).size(); j++) {
          weightIndices[i][j] = index;
          index++;
        }
      }
    }
    return weightIndices;
  }

  protected double[][] empty2D() {
    double[][] d = new double[map.length][];
    // int index = 0;
    for (int i = 0; i < map.length; i++) {
      d[i] = new double[labelIndices.get(map[i]).size()];
    }
    return d;
  }

  public int[][] getLabels() {
    return labels;
  }

}
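// Round-trip sketch for the weight layout above (illustrative): to2D unpacks a
// flat weight vector into one row per feature, where row i has
// labelIndices.get(map[i]).size() entries, and to1D concatenates the rows back
// in the same order, so to1D(to2D(w)) reproduces w whenever w.length equals
// domainDimension.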