package edu.berkeley.cs.nlp.ocular.lm; import java.util.Arrays; /** * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu) */ public class NgramCounts { public final NgramWrapper ngram; public final CountDbBig[] counts; public final long[] tokenCounts; public final long[] tokenNormalizers; public final long[] typeCounts; public final long[] typeNormalizers; public final long[] historyTypeCounts; public static final double UNK_LOG_PROB = -10; public static final double DISCOUNT = 0.75; public NgramCounts(NgramWrapper ngram, CountDbBig[] counts) { // Logger.logss("ORIGINAL NGRAM: " + Ngram.toString(ngram, charIndexer)); this.ngram = ngram; this.counts = counts; int ngramOrder = ngram.getOrder(); this.tokenCounts = new long[ngramOrder]; this.tokenNormalizers = new long[ngramOrder]; this.historyTypeCounts = new long[ngramOrder]; int numberOfTypeCounts = Math.min(ngramOrder, counts.length-1); this.typeCounts = new long[numberOfTypeCounts]; this.typeNormalizers = new long[numberOfTypeCounts]; for (int i = 0; i < ngramOrder; i++) { int order = i + 1; // Version of the current n-gram, truncated to the appropriate length NgramWrapper tempNgramWrapper = ngram.getLowerOrder(order); long[] tempNgram = tempNgramWrapper.getLongerRep(); long[] tempHistory = tempNgramWrapper.getHistory().getLongerRep(); this.tokenCounts[i] = counts[i].getCount(tempNgram, CountType.TOKEN_INDEX); if (i > 0) { this.tokenNormalizers[i] = counts[i-1].getCount(tempHistory, CountType.TOKEN_INDEX); this.historyTypeCounts[i] = counts[i-1].getCount(tempHistory, CountType.HISTORY_TYPE_INDEX); } else { this.tokenNormalizers[i] = counts[i].getNumTokens(); this.historyTypeCounts[i] = 0; // don't need these at the lowest order } if (i < numberOfTypeCounts) { this.typeCounts[i] = counts[i].getCount(tempNgram, CountType.LOWER_ORDER_TYPE_INDEX); if (i > 0) { this.typeNormalizers[i] = counts[i-1].getCount(tempHistory, CountType.LOWER_ORDER_TYPE_NORMALIZER); } else { this.typeNormalizers[i] = counts[0].getNumBigramTypes(); } } } } public int getNgramOrder() { return ngram.getOrder(); } /** * @return The highest order for which we have nonzero history counts */ public int getHighestUsableOrder() { for (int i = getNgramOrder() - 1; i >= 0; i--) { if (tokenCounts[i] > 0) { //if (tokenNormalizers[i] <= 0) throw new RuntimeException("Bad counts: " + this); } if (tokenNormalizers[i] > 0) { return i+1; } } throw new RuntimeException("getHighestUsableOrder() failed. getNgramOrder()="+getNgramOrder()); } public double getTokenMle() { return getTokenMle(getHighestUsableOrder() - 1); } public double getTokenMle(int orderIndex) { return ((double)tokenCounts[orderIndex])/((double)tokenNormalizers[orderIndex]); } public double getTokenMleOrEpsilon(int orderIndex) { if (tokenCounts[orderIndex] == 0) { return Math.exp(UNK_LOG_PROB); } else { return ((double)tokenCounts[orderIndex])/((double)tokenNormalizers[orderIndex]); } } public double getTypeMle(int orderIndex) { return ((double)typeCounts[orderIndex])/((double)typeNormalizers[orderIndex]); } public double getAbsoluteDiscounting() { return adHelper(getHighestUsableOrder()); } private double adHelper(int order) { // Logger.logss("AD ORDER: " + order); int orderIndex = order - 1; if (order == 1) { return getTokenMleOrEpsilon(orderIndex); } else { return (Math.max(0.0, ((double)tokenCounts[orderIndex]) - DISCOUNT))/((double)tokenNormalizers[orderIndex]) + ((double)historyTypeCounts[orderIndex]) * DISCOUNT/((double)tokenNormalizers[orderIndex]) * adHelper(order - 1); } } public double getKneserNey() { int highestOrder = getHighestUsableOrder(); int highestOrderIndex = highestOrder - 1; if (highestOrder == 1) { return getTokenMleOrEpsilon(highestOrderIndex); } else if (highestOrder == getNgramOrder()) { double alpha = (Math.max(0.0, ((double)tokenCounts[highestOrderIndex]) - DISCOUNT))/((double)tokenNormalizers[highestOrderIndex]); double bow = ((double)historyTypeCounts[highestOrderIndex]) * DISCOUNT/((double)tokenNormalizers[highestOrderIndex]); // Logger.logss("KNTOP: " + alpha + " " + bow); return alpha + bow * knHelper(highestOrder - 1); } else { return knHelper(highestOrder); } } private double knHelper(int order) { int orderIndex = order - 1; if (order == 1) { if (typeCounts[0] == 0) { return Math.exp(UNK_LOG_PROB); } else { return ((double)typeCounts[0])/((double)typeNormalizers[0]); } } else { double alpha = (Math.max(0.0, ((double)typeCounts[orderIndex]) - DISCOUNT))/((double)typeNormalizers[orderIndex]); double bow = ((double)historyTypeCounts[orderIndex]) * DISCOUNT/((double)typeNormalizers[orderIndex]); // Logger.logss("KN: " + alpha + " " + bow); return alpha + bow * knHelper(order - 1); } } public String toString() { String string = ""; string += "Ngram: " + ngram + "; order: " + getNgramOrder() + "\n"; string += "Tok: " + Arrays.toString(tokenCounts) + "\n"; string += "TokNorm: " + Arrays.toString(tokenNormalizers) + "\n"; string += "Typ: " + Arrays.toString(typeCounts) + "\n"; string += "TypNorm: " + Arrays.toString(typeNormalizers) + "\n"; string += "HistTyp: " + Arrays.toString(historyTypeCounts) + "\n"; return string; } }