package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.concurrent.*;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Quadruple;

import java.util.*;

/**
 * CRF log conditional objective function augmented with a dropout-based
 * regularization prior, optionally trained on a mix of supervised and
 * unsupervised (unlabeled) documents.
 *
 * @author Mengqiu Wang
 */
public class CRFLogConditionalObjectiveFunctionWithDropout extends CRFLogConditionalObjectiveFunction {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunctionWithDropout.class);

  private final double delta;
  private final double dropoutScale;
  private double[][] dropoutPriorGradTotal;
  private final boolean dropoutApprox;
  private double[][] weightSquare;

  private final int[][][][] totalData;  // data[docIndex][tokenIndex][][]
  private int unsupDropoutStartIndex;
  private final double unsupDropoutScale;

  private List<List<Set<Integer>>> dataFeatureHash;
  private List<Map<Integer, List<Integer>>> condensedMap;
  private int[][] dataFeatureHashByDoc;
  private int edgeLabelIndexSize;
  private int nodeLabelIndexSize;
  private int[][] edgeLabels;
  private Map<Integer, List<Integer>> currPrevLabelsMap;
  private Map<Integer, List<Integer>> currNextLabelsMap;

  // The Boolean in the input pair flags an unsupervised doc; unsupervised docs skip
  // the likelihood (value) computation and only contribute to the dropout prior.
  private ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> dropoutPriorThreadProcessor =
      new ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>>() {
    @Override
    public Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> process(Pair<Integer, Boolean> docIndexUnsup) {
      return expectedCountsAndValueForADoc(docIndexUnsup.first(), false, docIndexUnsup.second());
    }
    @Override
    public ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> newInstance() {
      return this;
    }
  };

  // TODO(Mengqiu) Need to figure out what to do with dataDimension() in case of
  // mixed supervised+unsupervised data for SGD (AdaGrad)
  CRFLogConditionalObjectiveFunctionWithDropout(int[][][][] data, int[][] labels, int window, Index<String> classIndex,
      List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma,
      double[][][][] featureVal, double delta, double dropoutScale, int multiThreadGrad, boolean dropoutApprox,
      double unsupDropoutScale, int[][][][] unsupDropoutData) {
    super(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad);
    this.delta = delta;
    this.dropoutScale = dropoutScale;
    this.dropoutApprox = dropoutApprox;
    dropoutPriorGradTotal = empty2D();
    this.unsupDropoutStartIndex = data.length;
    this.unsupDropoutScale = unsupDropoutScale;
    if (unsupDropoutData != null) {
      // supervised docs occupy [0, unsupDropoutStartIndex); unsupervised docs follow
      this.totalData = new int[data.length + unsupDropoutData.length][][][];
      for (int i = 0; i < data.length; i++) {
        this.totalData[i] = data[i];
      }
      for (int i = 0; i < unsupDropoutData.length; i++) {
        this.totalData[i + unsupDropoutStartIndex] = unsupDropoutData[i];
      }
    } else {
      this.totalData = data;
    }
    initEdgeLabels();
    initializeDataFeatureHash();
  }
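  /**
   * Caches the sizes of the node and edge label indices and, for every class y,
   * which classes may precede it (currPrevLabelsMap) or follow it
   * (currNextLabelsMap) according to the edge label index, along with the raw
   * (prev, curr) pair behind each edge label in edgeLabels.
   */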
  private void initEdgeLabels() {
    if (labelIndices.size() < 2)
      return;
    Index<CRFLabel> edgeLabelIndex = labelIndices.get(1);
    edgeLabelIndexSize = edgeLabelIndex.size();
    Index<CRFLabel> nodeLabelIndex = labelIndices.get(0);
    nodeLabelIndexSize = nodeLabelIndex.size();
    currPrevLabelsMap = new HashMap<>();
    currNextLabelsMap = new HashMap<>();
    edgeLabels = new int[edgeLabelIndexSize][];
    for (int k = 0; k < edgeLabelIndexSize; k++) {
      int[] labelPair = edgeLabelIndex.get(k).getLabel();
      edgeLabels[k] = labelPair;
      int curr = labelPair[1];
      int prev = labelPair[0];
      if (!currPrevLabelsMap.containsKey(curr))
        currPrevLabelsMap.put(curr, new ArrayList<>(numClasses));
      currPrevLabelsMap.get(curr).add(prev);
      if (!currNextLabelsMap.containsKey(prev))
        currNextLabelsMap.put(prev, new ArrayList<>(numClasses));
      currNextLabelsMap.get(prev).add(curr);
    }
  }

  private Map<Integer, double[]> sparseE(Set<Integer> activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<>(activeFeatures.size());
    for (int f : activeFeatures) {
      // the array is sized by the label index of this feature's clique type: node or edge
      // System.err.printf("aMap.put(%d, new double[%d])\n", f, map[f]+1);
      aMap.put(f, new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }

  private Map<Integer, double[]> sparseE(int[] activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<>(activeFeatures.length);
    for (int f : activeFeatures) {
      // System.err.printf("aMap.put(%d, new double[%d])\n", f, map[f]+1);
      aMap.put(f, new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }
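  /**
   * Per-document workhorse, run on worker threads via dropoutPriorThreadProcessor:
   * computes the document's log conditional likelihood (unless skipValCalc, which
   * is set for unlabeled documents), the model-expected feature counts (unless
   * skipExpectedCountCalc), and, when the dropout prior is active, subtracts the
   * prior value and fills in its gradient. Results come back as
   * (docIndex, logProb, expectedCounts, dropoutPriorGrad) for recombination in
   * calculate().
   */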
  private Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> expectedCountsAndValueForADoc(
      int docIndex, boolean skipExpectedCountCalc, boolean skipValCalc) {

    int[] activeFeatures = dataFeatureHashByDoc[docIndex];
    List<Set<Integer>> docDataHash = dataFeatureHash.get(docIndex);
    Map<Integer, List<Integer>> condensedFeaturesMap = condensedMap.get(docIndex);

    double prob = 0;
    int[][][] docData = totalData[docIndex];
    int[] docLabels = null;
    if (docIndex < labels.length)
      docLabels = labels[docIndex];

    Timing timer = new Timing();

    double[][][] featureVal3DArr = null;
    if (featureVal != null)
      featureVal3DArr = featureVal[docIndex];

    // make a clique tree for this document
    CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex,
        backgroundSymbol, cliquePotentialFunc, featureVal3DArr);

    if (!skipValCalc) {
      if (TIMED) timer.start();
      // compute the log probability of the document given the model with the parameters x
      // docLabels is non-null here: unlabeled docs (docIndex >= labels.length) are
      // always submitted with skipValCalc == true
      int[] given = new int[window - 1];
      Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
      if (docLabels.length > docData.length) { // only true for self-training
        // fill the given array with the extra docLabels
        System.arraycopy(docLabels, 0, given, 0, given.length);
        // shift the docLabels array left
        int[] newDocLabels = new int[docData.length];
        System.arraycopy(docLabels, docLabels.length - newDocLabels.length, newDocLabels, 0, newDocLabels.length);
        docLabels = newDocLabels;
      }

      double startPosLogProb = cliqueTree.logProbStartPos();
      if (VERBOSE)
        System.err.printf("P_-1(Background) = % 5.3f\n", startPosLogProb);
      prob += startPosLogProb;

      // iterate over the positions in this document
      for (int i = 0; i < docData.length; i++) {
        int label = docLabels[i];
        double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
        if (VERBOSE) {
          log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + Math.exp(p));
        }
        prob += p;
        System.arraycopy(given, 1, given, 0, given.length - 1);
        given[given.length - 1] = label;
      }
      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Calculate objective took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map<Integer, double[]> EForADoc = sparseE(activeFeatures);
    List<Map<Integer, double[]>> EForADocPos = null;
    if (dropoutApprox) {
      EForADocPos = new ArrayList<>(docData.length);
    }

    if (!skipExpectedCountCalc) {
      if (TIMED) timer.start();
      // compute the expected counts for this document, which we will need to compute the derivative
      // iterate over the positions in this document
      double fVal = 1.0;
      for (int i = 0; i < docData.length; i++) {
        Set<Integer> docDataHashI = docDataHash.get(i);
        Map<Integer, double[]> EForADocPosAtI = null;
        if (dropoutApprox)
          EForADocPosAtI = sparseE(docDataHashI);

        for (int fIndex : docDataHashI) {
          int j = map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          // for each possible labeling for that clique
          for (int k = 0; k < labelIndex.size(); k++) {
            int[] label = labelIndex.get(k).getLabel();
            double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features
            if (dropoutApprox)
              increScore(EForADocPosAtI, fIndex, k, fVal * p);
            increScore(EForADoc, fIndex, k, fVal * p);
          }
        }

        if (dropoutApprox) {
          for (int fIndex : docDataHashI) {
            if (condensedFeaturesMap.containsKey(fIndex)) {
              List<Integer> aList = condensedFeaturesMap.get(fIndex);
              for (int toCopyInto : aList) {
                double[] arr = EForADocPosAtI.get(fIndex);
                double[] targetArr = new double[arr.length];
                for (int q = 0; q < arr.length; q++)
                  targetArr[q] = arr[q];
                EForADocPosAtI.put(toCopyInto, targetArr);
              }
            }
          }
          EForADocPos.add(EForADocPosAtI);
        }
      }

      // copy for condensedFeaturesMap
      for (Map.Entry<Integer, List<Integer>> entry : condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List<Integer> aList = entry.getValue();
        for (int toCopyInto : aList) {
          double[] arr = EForADoc.get(key);
          double[] targetArr = new double[arr.length];
          for (int i = 0; i < arr.length; i++)
            targetArr[i] = arr[i];
          EForADoc.put(toCopyInto, targetArr);
        }
      }

      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Expected count took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map<Integer, double[]> dropoutPriorGrad = null;
    if (prior == DROPOUT_PRIOR) {
      if (TIMED) timer.start();
      // we can optimize this, this is too large, don't need this big
      dropoutPriorGrad = sparseE(activeFeatures);

      // log.info("computing dropout prior for doc " + docIndex + " ... ");
      prob -= getDropoutPrior(cliqueTree, docData, EForADoc, docDataHash, activeFeatures, dropoutPriorGrad,
          condensedFeaturesMap, EForADocPos);
      // log.info(" done!");
      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Dropout took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    return new Quadruple<>(docIndex, prob, EForADoc, dropoutPriorGrad);
  }

  private void increScore(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    aMap.get(fIndex)[k] += val;
  }

  private void increScoreAllowNull(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    if (!aMap.containsKey(fIndex)) {
      aMap.put(fIndex, new double[map[fIndex] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    aMap.get(fIndex)[k] += val;
  }
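  /**
   * Precomputes, for every document (supervised and unsupervised), the set of
   * features active at each position and the array of all features active
   * anywhere in the document. When CONDENSE is on, node features that occur at
   * exactly one position are merged under a single representative feature per
   * position (condensedMap records representative -> merged features), shrinking
   * the per-document expectation maps; expectations and gradients are copied back
   * out to the merged features afterwards.
   */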
  private void initializeDataFeatureHash() {
    int macroActiveFeatureTotalCount = 0;
    int macroCondensedTotalCount = 0;
    int macroDocPosCount = 0;

    log.info("initializing data feature hash, sup-data size: " + data.length +
        ", unsup data size: " + (totalData.length - data.length));
    dataFeatureHash = new ArrayList<>(totalData.length);
    condensedMap = new ArrayList<>(totalData.length);
    dataFeatureHashByDoc = new int[totalData.length][];
    for (int m = 0; m < totalData.length; m++) {
      Map<Integer, Integer> occurPos = new HashMap<>();

      int[][][] aDoc = totalData[m];
      List<Set<Integer>> aList = new ArrayList<>(aDoc.length);
      Set<Integer> setOfFeatures = new HashSet<>();
      for (int i = 0; i < aDoc.length; i++) { // positions in docI
        Set<Integer> aSet = new HashSet<>();
        int[][] dataI = aDoc[i];
        for (int j = 0; j < dataI.length; j++) {
          int[] dataJ = dataI[j];
          for (int item : dataJ) {
            if (j == 0) {
              if (occurPos.containsKey(item))
                occurPos.put(item, -1);
              else
                occurPos.put(item, i);
            }
            aSet.add(item);
          }
        }
        aList.add(aSet);
        setOfFeatures.addAll(aSet);
      }
      macroDocPosCount += aDoc.length;
      macroActiveFeatureTotalCount += setOfFeatures.size();

      if (CONDENSE) {
        if (DEBUG3)
          log.info("Before condense, activeFeatures = " + setOfFeatures.size());
        // examine all singletons, merge ones in the same position
        Map<Integer, List<Integer>> condensedFeaturesMap = new HashMap<>();
        int[] representFeatures = new int[aDoc.length];
        Arrays.fill(representFeatures, -1);

        for (Map.Entry<Integer, Integer> entry : occurPos.entrySet()) {
          int key = entry.getKey();
          int pos = entry.getValue();
          if (pos != -1) {
            if (representFeatures[pos] == -1) { // use this as representFeatures
              representFeatures[pos] = key;
              condensedFeaturesMap.put(key, new ArrayList<>());
            } else { // condense this one
              int rep = representFeatures[pos];
              condensedFeaturesMap.get(rep).add(key);
              // remove key
              aList.get(pos).remove(key);
              setOfFeatures.remove(key);
            }
          }
        }
        int condensedCount = 0;
        for (Iterator<Map.Entry<Integer, List<Integer>>> it = condensedFeaturesMap.entrySet().iterator(); it.hasNext(); ) {
          Map.Entry<Integer, List<Integer>> entry = it.next();
          if (entry.getValue().size() == 0) {
            it.remove();
          } else if (DEBUG3) {
            condensedCount += entry.getValue().size();
            for (int cond : entry.getValue())
              log.info("condense " + cond + " to " + entry.getKey());
          }
        }
        if (DEBUG3)
          log.info("After condense, activeFeatures = " + setOfFeatures.size() + ", condensedCount = " + condensedCount);
        macroCondensedTotalCount += setOfFeatures.size();
        condensedMap.add(condensedFeaturesMap);
      }

      dataFeatureHash.add(aList);
      int[] arrOfIndex = new int[setOfFeatures.size()];
      int pos2 = 0;
      for (Integer ind : setOfFeatures)
        arrOfIndex[pos2++] = ind;
      dataFeatureHashByDoc[m] = arrOfIndex;
    }
    log.info("Avg. active features per position: " + (macroActiveFeatureTotalCount / (macroDocPosCount + 0.0)));
    log.info("Avg. condensed features per position: " + (macroCondensedTotalCount / (macroDocPosCount + 0.0)));
    log.info("initializing data feature hash done!");
  }
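  /**
   * Computes this document's dropout regularization prior and accumulates its
   * gradient into dropoutPriorGrad. For each position i and edge labeling (y, y'),
   * the prior adds VarU * PtYYp * (1 - PtYYp), where PtYYp = cliqueTree.prob(i, {y, y'})
   * and VarU = 0.5 * (delta / (1 - delta)) * sum of theta^2 over the parameters of
   * features firing at i. The delta / (1 - delta) factor reads as the variance ratio
   * of Bernoulli feature dropout at rate delta, though that is our interpretation,
   * not something the code states. When dropoutApprox is false, the FAlpha/FBeta
   * recursions below propagate exact conditional expected feature counts to the left
   * and right of each position for the gradient of PtYYp; otherwise only the local
   * expectation at i is used.
   */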
  private double getDropoutPrior(CRFCliqueTree cliqueTree, int[][][] docData,
      Map<Integer, double[]> EForADoc, List<Set<Integer>> docDataHash, int[] activeFeatures,
      Map<Integer, double[]> dropoutPriorGrad, Map<Integer, List<Integer>> condensedFeaturesMap,
      List<Map<Integer, double[]>> EForADocPos) {

    Map<Integer, double[]> dropoutPriorGradFirstHalf = sparseE(activeFeatures);

    if (TIMED)
      log.info("activeFeatures size: " + activeFeatures.length + ", dataLen: " + docData.length);

    Timing timer = new Timing();
    if (TIMED) timer.start();

    double priorValue = 0;
    long elapsedMs = 0;

    Pair<double[][][], double[][][]> condProbs = getCondProbs(cliqueTree, docData);
    if (TIMED) {
      elapsedMs = timer.stop();
      log.info("\t cond prob took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
    }

    // first index position is curr index, second index curr-class, third index prev-class
    // e.g. [1][2][3] means curr is at position 1 with class 2, prev is at position 0 with class 3
    double[][][] prevGivenCurr = condProbs.first();
    // first index position is curr index, second index curr-class, third index next-class
    // e.g. [0][2][3] means curr is at position 0 with class 2, next is at position 1 with class 3
    double[][][] nextGivenCurr = condProbs.second();

    // first dim is doc length (i)
    // second dim is numOfFeatures (fIndex)
    // third dim is numClasses (y)
    // fourth dim is labelIndexSize (matching the clique type of fIndex, for \theta)
    double[][][][] FAlpha = null;
    double[][][][] FBeta = null;
    if (!dropoutApprox) {
      FAlpha = new double[docData.length][][][];
      FBeta = new double[docData.length][][][];
      for (int i = 0; i < docData.length; i++) {
        FAlpha[i] = new double[activeFeatures.length][][];
        FBeta[i] = new double[activeFeatures.length][][];
      }
    }

    if (!dropoutApprox) {
      if (TIMED) {
        timer.start();
      }
      // computing FAlpha
      int fIndex = 0;
      double aa, bb, cc = 0;
      boolean prevFeaturePresent = false;
      for (int i = 1; i < docData.length; i++) {
        // for each possible clique at this position
        Set<Integer> docDataHashIMinusOne = docDataHash.get(i - 1);
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          prevFeaturePresent = docDataHashIMinusOne.contains(fIndex);
          int j = map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FAlpha[i - 1][fIndexPos] == null) {
            FAlpha[i - 1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FAlpha[i - 1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry<Integer, List<Integer>> entry : currPrevLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i-1
            double[] sum = new double[labelIndexSize];
            for (int yPrime : entry.getValue()) { // value at i-2
              for (int kk = 0; kk < labelIndexSize; kk++) {
                int[] prevLabel = labelIndex.get(kk).getLabel();
                aa = (prevGivenCurr[i - 1][y][yPrime]);
                bb = (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FAlpha[i - 1][fIndexPos][yPrime] != null)
                  cc = FAlpha[i - 1][fIndexPos][yPrime][kk];
                sum[kk] += aa * (bb + cc);
                // sum[kk] += (prevGivenCurr[i-1][y][yPrime]) * ((prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0) + FAlpha[i-1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("alpha[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f), prevLabel=%s\n", i, fIndex, y, kk,
                      (prevGivenCurr[i - 1][y][yPrime]),
                      (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0),
                      FAlpha[i - 1][fIndexPos][yPrime][kk], Arrays.toString(prevLabel));
              }
            }
            if (FAlpha[i][fIndexPos] == null) {
              FAlpha[i][fIndexPos] = new double[numClasses][];
            }
            FAlpha[i][fIndexPos][y] = sum;
            if (DEBUG2)
              log.info("FAlpha[" + i + "][" + fIndexPos + "][" + y + "] = " + Arrays.toString(sum));
          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        log.info("\t alpha took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
        timer.start();
      }
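      // FBeta mirrors the FAlpha recursion in the opposite direction: starting from the
      // end of the document, it accumulates the conditional expected count of each
      // (feature, parameter) pair over positions to the right of i, conditioning on the
      // label at i via nextGivenCurr instead of prevGivenCurr.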
      // computing FBeta
      int docDataLen = docData.length;
      for (int i = docDataLen - 2; i >= 0; i--) {
        Set<Integer> docDataHashIPlusOne = docDataHash.get(i + 1);
        // for each possible clique at this position
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          boolean nextFeaturePresent = docDataHashIPlusOne.contains(fIndex);
          int j = map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FBeta[i + 1][fIndexPos] == null) {
            FBeta[i + 1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FBeta[i + 1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry<Integer, List<Integer>> entry : currNextLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i
            double[] sum = new double[labelIndexSize];
            for (int yPrime : entry.getValue()) { // value at i+1
              for (int kk = 0; kk < labelIndexSize; kk++) {
                int[] nextLabel = labelIndex.get(kk).getLabel();
                // log.info("labelIndexSize:"+labelIndexSize+", nextGivenCurr:"+nextGivenCurr+", nextLabel:"+nextLabel+", FBeta["+(i+1)+"]["+fIndexPos+"]["+yPrime+"] :"+FBeta[i+1][fIndexPos][yPrime]);
                aa = (nextGivenCurr[i][y][yPrime]);
                bb = (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FBeta[i + 1][fIndexPos][yPrime] != null)
                  cc = FBeta[i + 1][fIndexPos][yPrime][kk];
                sum[kk] += aa * (bb + cc);
                // sum[kk] += (nextGivenCurr[i][y][yPrime]) * ((nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0) + FBeta[i+1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("beta[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f)\n", i, fIndex, y, kk,
                      (nextGivenCurr[i][y][yPrime]),
                      (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0),
                      FBeta[i + 1][fIndexPos][yPrime][kk]);
              }
            }
            if (FBeta[i][fIndexPos] == null) {
              FBeta[i][fIndexPos] = new double[numClasses][];
            }
            FBeta[i][fIndexPos][y] = sum;
            if (DEBUG2)
              log.info("FBeta[" + i + "][" + fIndexPos + "][" + y + "] = " + Arrays.toString(sum));
          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        log.info("\t beta took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }
    if (TIMED) {
      timer.start();
    }

    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * (1-PtYYp)'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * -PtYYp'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1 - 2 * PtYYp)
    double deltaDivByOneMinusDelta = delta / (1.0 - delta);

    Timing innerTimer = new Timing();
    long eTiming = 0;
    long dropoutTiming = 0;

    boolean containsFeature = false;
    // iterate over the positions in this document
    for (int i = 1; i < docData.length; i++) {
      Set<Integer> docDataHashI = docDataHash.get(i);
      Map<Integer, double[]> EForADocPosAtI = null;
      if (dropoutApprox)
        EForADocPosAtI = EForADocPos.get(i);

      // for each possible clique at this position
      for (int k = 0; k < edgeLabelIndexSize; k++) { // sum over (y, y')
        int[] label = edgeLabels[k];
        int y = label[0];
        int yP = label[1];

        if (TIMED) innerTimer.start();

        // important to use label as an int[] when calling cliqueTree.prob():
        // for a node clique with label index 2, passing just 2 instead of int[]{2}
        // would make cliqueTree treat it as an index into the edge clique labels,
        // convert it into int[]{0,2}, and return the edge marginal instead of the
        // node marginal
        double PtYYp = cliqueTree.prob(i, label);
        double PtYYpTimesOneMinusPtYYp = PtYYp * (1.0 - PtYYp);
        double oneMinus2PtYYp = (1.0 - 2 * PtYYp);
        double USum = 0;
        int fIndex;
        for (int jjj = 0; jjj < labelIndices.size(); jjj++) {
          for (int n = 0; n < docData[i][jjj].length; n++) {
            fIndex = docData[i][jjj][n];
            int valIndex;
            if (jjj == 1)
              valIndex = k;
            else
              valIndex = yP;
            double theta;
            try {
              theta = weights[fIndex][valIndex];
            } catch (Exception ex) {
              System.err.printf("weights[%d][%d], map[%d]=%d, labelIndices.get(map[%d]).size() = %d, weights.length=%d\n",
                  fIndex, valIndex, fIndex, map[fIndex], fIndex, labelIndices.get(map[fIndex]).size(), weights.length);
              throw new RuntimeException(ex);
            }
            USum += weightSquare[fIndex][valIndex];

            // first half of derivative: VarU' * PtYYp * (1-PtYYp)
            double VarUp = deltaDivByOneMinusDelta * theta;
            increScoreAllowNull(dropoutPriorGradFirstHalf, fIndex, valIndex, VarUp * PtYYpTimesOneMinusPtYYp);
          }
        }
        if (TIMED) {
          eTiming += innerTimer.stop();
          innerTimer.start();
        }
        double VarU = 0.5 * deltaDivByOneMinusDelta * USum;

        // update function objective
        priorValue += VarU * PtYYpTimesOneMinusPtYYp;

        double VarUTimesOneMinus2PtYYp = VarU * oneMinus2PtYYp;
        // second half of derivative: VarU * PtYYp' * (1 - 2 * PtYYp)
        // boolean prevFeaturePresent = false;
        // boolean nextFeaturePresent = false;
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          containsFeature = docDataHashI.contains(fIndex);
          // if (!containsFeature) continue;
          int jj = map[fIndex];
          Index<CRFLabel> fLabelIndex = labelIndices.get(jj);
          for (int kk = 0; kk < fLabelIndex.size(); kk++) { // for all parameter \theta
            int[] fLabel = fLabelIndex.get(kk).getLabel();
            // if (FAlpha[i] != null)
            //   log.info("fIndex: " + fIndex + ", FAlpha[i].size:" + FAlpha[i].length);
            double fCount = containsFeature && ((jj == 0 && fLabel[0] == yP) || (jj == 1 && k == kk)) ? 1 : 0;

            // zero-initialized: the dropoutApprox branch leaves alpha and beta untouched,
            // but the DEBUG2 printout below still reads them
            double alpha = 0;
            double beta = 0;
            double condE;
            double PtYYpPrime;
            if (!dropoutApprox) {
              alpha = ((FAlpha[i][fIndexPos] == null || FAlpha[i][fIndexPos][y] == null) ? 0 : FAlpha[i][fIndexPos][y][kk]);
              beta = ((FBeta[i][fIndexPos] == null || FBeta[i][fIndexPos][yP] == null) ? 0 : FBeta[i][fIndexPos][yP][kk]);
              condE = fCount + alpha + beta;
              if (DEBUG2)
                System.err.printf("fLabel=%s, yP = %d, fCount:%f = ((jj == 0 && fLabel[0] == yP)=%b || (jj == 1 && k == kk))=%b\n",
                    Arrays.toString(fLabel), yP, fCount, (jj == 0 && fLabel[0] == yP), (jj == 1 && k == kk));
              PtYYpPrime = PtYYp * (condE - EForADoc.get(fIndex)[kk]);
            } else {
              double E = 0;
              if (EForADocPosAtI.containsKey(fIndex))
                E = EForADocPosAtI.get(fIndex)[kk];
              condE = fCount;
              PtYYpPrime = PtYYp * (condE - E);
            }
            if (DEBUG2)
              System.err.printf("for i=%d, k=%d, y=%d, yP=%d, fIndex=%d, kk=%d, PtYYpPrime=% 5.3f, PtYYp=% 3.3f, (condE-E[fIndex][kk])=% 3.3f, condE=% 3.3f, E[fIndex][k]=% 3.3f, alpha=% 3.3f, beta=% 3.3f, fCount=% 3.3f\n",
                  i, k, y, yP, fIndex, kk, PtYYpPrime, PtYYp, (condE - EForADoc.get(fIndex)[kk]), condE,
                  EForADoc.get(fIndex)[kk], alpha, beta, fCount);
            increScore(dropoutPriorGrad, fIndex, kk, VarUTimesOneMinus2PtYYp * PtYYpPrime);
          }
          if (DEBUG2)
            log.info();
        }
        if (TIMED)
          dropoutTiming += innerTimer.stop();
      }
    }
    if (CONDENSE) {
      // copy for condensedFeaturesMap
      for (Map.Entry<Integer, List<Integer>> entry : condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List<Integer> aList = entry.getValue();
        for (int toCopyInto : aList) {
          double[] arr = dropoutPriorGrad.get(key);
          double[] targetArr = new double[arr.length];
          for (int i = 0; i < arr.length; i++)
            targetArr[i] = arr[i];
          dropoutPriorGrad.put(toCopyInto, targetArr);
        }
      }
    }

    if (DEBUG3) {
      log.info("dropoutPriorGradFirstHalf.keys:[");
      for (int key : dropoutPriorGradFirstHalf.keySet())
        log.info(" " + key);
      log.info("]");

      log.info("dropoutPriorGrad.keys:[");
      for (int key : dropoutPriorGrad.keySet())
        log.info(" " + key);
      log.info("]");
    }

    for (Map.Entry<Integer, double[]> entry : dropoutPriorGrad.entrySet()) {
      Integer key = entry.getKey();
      double[] target = entry.getValue();
      if (dropoutPriorGradFirstHalf.containsKey(key)) {
        double[] source = dropoutPriorGradFirstHalf.get(key);
        for (int i = 0; i < target.length; i++) {
          target[i] += source[i];
        }
      }
    }

    // for (int i = 0; i < dropoutPriorGrad.length; i++)
    //   for (int j = 0; j < dropoutPriorGrad[i].length; j++) {
    //     if (DEBUG3)
    //       System.err.printf("f=%d, k=%d, dropoutPriorGradFirstHalf[%d][%d]=% 5.3f, dropoutPriorGrad[%d][%d]=% 5.3f\n", i, j, i, j, dropoutPriorGradFirstHalf[i][j], i, j, dropoutPriorGrad[i][j]);
    //     dropoutPriorGrad[i][j] += dropoutPriorGradFirstHalf[i][j];
    //   }

    if (TIMED) {
      elapsedMs = timer.stop();
      log.info("\t grad took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      log.info("\t\t exp took: " + Timing.toMilliSecondsString(eTiming) + " ms");
      log.info("\t\t dropout took: " + Timing.toMilliSecondsString(dropoutTiming) + " ms");
    }

    return dropoutScale * priorValue;
  }

  @Override
  public void setWeights(double[][] weights) {
    super.setWeights(weights);
    if (weightSquare == null) {
      weightSquare = new double[weights.length][];
      for (int i = 0; i < weights.length; i++)
        weightSquare[i] = new double[weights[i].length];
    }
    for (int i = 0; i < weights.length; i++) {
      for (int j = 0; j < weights[i].length; j++) {
        double w = weights[i][j];
        weightSquare[i][j] = w * w;
      }
    }
  }
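  // Note: calculate() invokes setWeights(weights) on every evaluation, so the cached
  // weightSquare table above always mirrors the current point; getDropoutPrior reads
  // it (USum) rather than recomputing theta * theta in its inner loop.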
  /**
   * Calculates both the value and the partial derivatives at the point x, and saves
   * them internally.
   */
  @Override
  public void calculate(double[] x) {

    double prob = 0.0; // the log prob of the sequence given the model, which is the negation of value at this point
    // final double[][] weights = to2D(x);
    to2D(x, weights);
    setWeights(weights);

    // the expectations over counts
    // first index is feature index, second index is of possible labeling
    // double[][] E = empty2D();
    clear2D(E);
    clear2D(dropoutPriorGradTotal);

    MulticoreWrapper<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> wrapper =
        new MulticoreWrapper<>(multiThreadGrad, dropoutPriorThreadProcessor);
    // submit both supervised and unsupervised docs; docs at index >= unsupDropoutStartIndex are unsupervised
    for (int m = 0; m < totalData.length; m++) {
      boolean submitIsUnsup = (m >= unsupDropoutStartIndex);
      wrapper.put(new Pair<>(m, submitIsUnsup));
      // drain whatever results are already available while we keep submitting
      while (wrapper.peek()) {
        Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> result = wrapper.poll();
        int docIndex = result.first();
        boolean isUnsup = docIndex >= unsupDropoutStartIndex;
        if (isUnsup) {
          prob += unsupDropoutScale * result.second();
        } else {
          prob += result.second();
        }

        Map<Integer, double[]> partialDropout = result.fourth();
        if (partialDropout != null) {
          if (isUnsup) {
            combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
          } else {
            combine2DArr(dropoutPriorGradTotal, partialDropout);
          }
        }

        if (!isUnsup) {
          Map<Integer, double[]> partialE = result.third();
          if (partialE != null)
            combine2DArr(E, partialE);
        }
      }
    }
    wrapper.join();
    // drain the remaining results after all documents have been submitted
    while (wrapper.peek()) {
      Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> result = wrapper.poll();
      int docIndex = result.first();
      boolean isUnsup = docIndex >= unsupDropoutStartIndex;
      if (isUnsup) {
        prob += unsupDropoutScale * result.second();
      } else {
        prob += result.second();
      }

      Map<Integer, double[]> partialDropout = result.fourth();
      if (partialDropout != null) {
        if (isUnsup) {
          combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
        } else {
          combine2DArr(dropoutPriorGradTotal, partialDropout);
        }
      }

      if (!isUnsup) {
        Map<Integer, double[]> partialE = result.third();
        if (partialE != null)
          combine2DArr(E, partialE);
      }
    }

    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunctionWithDropout.calculate()" +
          " - this may well indicate numeric underflow due to overly long documents.");
    }

    // because we minimize -L(\theta)
    value = -prob;
    if (VERBOSE) {
      log.info("value is " + Math.exp(-value));
    }

    // compute the partial derivative for each feature by comparing expected counts to empirical counts
    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // because we minimize -L(\theta)
        derivative[index] = (E[i][j] - Ehat[i][j]);
        derivative[index] += dropoutScale * dropoutPriorGradTotal[i][j];
        if (VERBOSE) {
          log.info("deriv(" + i + ',' + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index]);
        }
        index++;
      }
    }
  }

}
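// A minimal usage sketch (hypothetical wiring; in practice this function is constructed
// and minimized inside the CRF training code, since the constructor is package-private).
// Assumes the standard edu.stanford.nlp.optimization API:
//
//   CRFLogConditionalObjectiveFunctionWithDropout func = ...; // built by the trainer
//   edu.stanford.nlp.optimization.QNMinimizer minimizer =
//       new edu.stanford.nlp.optimization.QNMinimizer();
//   double[] learnedFlatWeights = minimizer.minimize(func, 1e-4, func.initial());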