package edu.berkeley.nlp.lm.io;
import java.io.File;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel.LmContextInfo;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.collections.Counter;
import edu.berkeley.nlp.lm.map.HashNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap.Entry;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.util.StrUtils;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer.KneserNeyCounts;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
/**
* Class for producing a Kneser-Ney language model in ARPA format from raw text.
*
 * Confusingly, this class is both a {@link LmReaderCallback} (called from
 * {@link TextReader}, which reads plain text), and a {@link LmReader}, which
 * "reads" counts and produces Kneser-Ney probabilities and backoffs and passes
 * them on to an {@link ArpaLmReaderCallback}.
*
* @author adampauls
*
* @param <W>
*/
public class KneserNeyLmReaderCallback<W> implements NgramOrderedLmReaderCallback<LongRef>, LmReader<ProbBackoffPair, ArpaLmReaderCallback<ProbBackoffPair>>,
    ArrayEncodedNgramLanguageModel<W>, Serializable
{

    // Notation below follows SRILM's description of interpolated Kneser-Ney, from
    // http://www-speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
    //
    // p(a_z) = g(a_z) + bow(a_) p(_z) ; Eqn.4
    //
    // Let Z1 be the set {z: c(a_z) > 0}. For highest order N-grams we have:
    //
    // g(a_z) = max(0, c(a_z) - D) / c(a_)
    // bow(a_) = 1 - Sum_Z1 g(a_z)
    //         = 1 - Sum_Z1 c(a_z) / c(a_) + Sum_Z1 D / c(a_)
    //         = D n(a_*) / c(a_)
    //
    // Let Z2 be the set {z: n(*_z) > 0}. For lower order N-grams we have:
    //
    // g(_z) = max(0, n(*_z) - D) / n(*_*)
    // bow(_) = 1 - Sum_Z2 g(_z)
    //        = 1 - Sum_Z2 n(*_z) / n(*_*) + Sum_Z2 D / n(*_*)
    //        = D n(_*) / n(*_*)
    //

    // NOTE(review): serialVersionUID is conventionally private; left protected to avoid
    // breaking any subclass that references it.
    protected static final long serialVersionUID = 1L;

    /** Discount used at every order by {@link #defaultDiscounts()}. */
    protected static final float DEFAULT_DISCOUNT = 0.75f;

    /** Highest n-gram order collected and estimated. */
    protected final int lmOrder;

    /** Bidirectional mapping between words of type W and the int ids used internally. */
    protected final WordIndexer<W> wordIndexer;

    /** Raw Kneser-Ney count statistics for every n-gram up to order lmOrder. */
    protected final HashNgramMap<KneserNeyCounts> ngrams;

    /** Options; kneserNeyMinCounts and kneserNeyDiscounts are read here. */
    protected final ConfigOptions opts;

    /** Word id of the sentence-start symbol; n-grams beginning with it are special-cased. */
    protected final int startIndex;

    /**
     * Creates a callback with default {@link ConfigOptions}.
     *
     * @param wordIndexer indexer mapping words to ints (unseen words are added on the fly)
     * @param maxOrder highest n-gram order to collect and estimate
     */
    public KneserNeyLmReaderCallback(final WordIndexer<W> wordIndexer, final int maxOrder) {
        this(wordIndexer, maxOrder, new ConfigOptions());
    }

    /**
     * @param wordIndexer indexer mapping words to ints (unseen words are added on the fly)
     * @param maxOrder highest n-gram order to collect and estimate
     * @param opts configuration; opts.kneserNeyMinCounts must be non-decreasing by order
     * @throws IllegalArgumentException if opts.kneserNeyMinCounts is not monotonic
     */
    public KneserNeyLmReaderCallback(final WordIndexer<W> wordIndexer, final int maxOrder, final ConfigOptions opts) {
        this.lmOrder = maxOrder;
        this.startIndex = wordIndexer.getIndexPossiblyUnk(wordIndexer.getStartSymbol());
        this.opts = opts;
        // Sanity-check that per-order minimum counts never decrease with order.
        double last = Double.NEGATIVE_INFINITY;
        for (final double c : opts.kneserNeyMinCounts) {
            if (c < last)
                throw new IllegalArgumentException("Please ensure that ConfigOptions.kneserNeyMinCounts is monotonic (value was "
                    + Arrays.toString(opts.kneserNeyMinCounts) + ")");
            last = c;
        }
        this.wordIndexer = wordIndexer;
        final KneserNeyCountValueContainer values = new KneserNeyCountValueContainer(lmOrder, startIndex);
        ngrams = HashNgramMap.createExplicitWordHashNgramMap(values, opts, lmOrder, false);
    }

    /**
     * Adds a word-typed n-gram with the given count, indexing any unseen words first.
     */
    public void call(final W[] ngram, final LongRef value) {
        final int[] ints = new int[ngram.length];
        for (int i = 0; i < ngram.length; ++i)
            ints[i] = wordIndexer.getOrAddIndex(ngram[i]);
        call(ints, 0, ints.length, value, "");
    }

    /**
     * Like {@link #call(Object[], LongRef)}, but records counts only for sub-ngrams that end
     * at the final position (a null value is passed for all others; see addNgram).
     *
     * @param scratch caller-provided workspace of shape [lmOrder][&gt;= ngram.length]
     */
    public void callJustLast(final W[] ngram, final LongRef value, final long[][] scratch) {
        final int[] ints = new int[ngram.length];
        for (int i = 0; i < ngram.length; ++i)
            ints[i] = wordIndexer.getOrAddIndex(ngram[i]);
        addNgram(ints, 0, ints.length, value, "", true, scratch);
    }

    /**
     * Records the n-gram and all of its sub-ngrams, allocating fresh scratch space for the
     * per-order map offsets.
     */
    @Override
    public void call(final int[] ngram, final int startPos, final int endPos, final LongRef value, final String words) {
        // One row of cached map offsets per order, one column per start position.
        final long[][] prevOffsets = new long[lmOrder][endPos - startPos];
        addNgram(ngram, startPos, endPos, value, words, false, prevOffsets);
    }

    /**
     * Inserts every sub-ngram of ngram[startPos, endPos) of order up to lmOrder into the
     * count map, shortest orders first, so each insertion can reuse the map offsets of its
     * (n-1)-gram prefix and suffix cached in scratch.
     *
     * @param ngram word ids
     * @param startPos inclusive start of the span to add
     * @param endPos exclusive end of the span to add
     * @param value count attached to each added sub-ngram (tokenCounts)
     * @param words unused (kept for interface compatibility)
     * @param justLastWord if true, counts are recorded only for sub-ngrams ending at endPos;
     *            other sub-ngrams still get map entries but receive a null value
     * @param scratch workspace [lmOrder][&gt;= endPos - startPos] of previous-order offsets
     */
    public void addNgram(final int[] ngram, final int startPos, final int endPos, final LongRef value, @SuppressWarnings("unused") final String words, final boolean justLastWord,
        final long[][] scratch) {
        final KneserNeyCounts scratchCounts = new KneserNeyCounts();
        ngrams.rehashIfNecessary(endPos - startPos);
        for (int ngramOrder = 0; ngramOrder < lmOrder; ++ngramOrder) {
            for (int i = startPos; i < endPos; ++i) {
                int j = i + ngramOrder + 1; // exclusive end of the sub-ngram [i, j)
                if (j > endPos) continue;
                scratchCounts.tokenCounts = value.value;
                // Offsets of the (n-1)-gram prefix [i, j-1) and suffix [i+1, j).
                // NOTE(review): reads use scratch[...][i] while writes use [i - startPos];
                // these agree only when startPos == 0, which holds for the call() entry
                // point — confirm before invoking directly with startPos > 0.
                final long prevOffset = ngramOrder == 0 ? 0 : scratch[ngramOrder - 1][i];
                final long suffixOffset = ngramOrder == 0 ? 0 : scratch[ngramOrder - 1][i + 1];
                assert prevOffset >= 0;
                scratch[ngramOrder][i - startPos] = ngrams.putWithOffsetAndSuffix(ngram, i, j, prevOffset, suffixOffset, !justLastWord || j == endPos
                    ? scratchCounts : null);
            }
        }
    }

    /**
     * Fully interpolated probability for ngram[startPos, endPos): sums the discounted
     * probabilities of successively shorter suffixes, each weighted by the backoff weights
     * of the longer contexts (Eqn. 4 in the class header, unrolled over all orders).
     */
    protected float interpolateProb(final int[] ngram, final int startPos, final int endPos) {
        if (startPos == endPos) return 0.0f; // empty n-gram: nothing left to add
        final float backoff = getLowerOrderBackoff(ngram, startPos, endPos - 1);
        final float prob = getLowerOrderProb(ngram, startPos, endPos);
        return prob + backoff * interpolateProb(ngram, startPos + 1, endPos);
    }

    /**
     * Discounted maximum-likelihood probability for a highest-order n-gram:
     * g(a_z) = max(0, c(a_z) - D) / c(a_) (see class header). Returns 0 if the context
     * was never seen.
     */
    protected float getHighestOrderProb(final int[] ngram, final int startPos, final int endPos) {
        final KneserNeyCounts counts = getCounts(ngram, startPos, endPos, false);
        final KneserNeyCounts rightDotCounts = getCounts(ngram, startPos, endPos - 1, true); // context counts c(a_)
        final int ngramOrder = endPos - startPos - 1;
        final float D = getDiscountForOrder(ngramOrder);
        final float prob = rightDotCounts.tokenCounts == 0 ? 0.0f : Math.max(0.0f, (counts.tokenCounts - D) / rightDotCounts.tokenCounts);
        return prob;
    }

    /**
     * Discounted continuation probability for a lower-order n-gram:
     * g(_z) = max(0, n(*_z) - D) / n(*_*) (see class header). Unigrams are not discounted.
     */
    protected float getLowerOrderProb(final int[] ngram, final int startPos, final int endPos) {
        if (startPos == endPos) return 1.0f;
        final KneserNeyCounts counts = getCounts(ngram, startPos, endPos, false);
        final KneserNeyCounts prefixCounts = getCounts(ngram, startPos, endPos - 1, true);
        // No discount at the unigram level.
        final float probDiscount = (endPos - startPos == 1) ? 0.0f : getDiscountForOrder(endPos - startPos - 1);
        final float prob = prefixCounts.dotdotTypeCounts == 0 ? 0.0f : Math.max(0.0f, counts.leftDotTypeCounts - probDiscount) / prefixCounts.dotdotTypeCounts;
        return prob;
    }

    /**
     * Backoff weight bow(a_) = D * n(a_*) / denom (see class header). The denominator is
     * the raw token count c(a_) for highest-order contexts and contexts starting with
     * the start symbol, and the continuation type count n(*_*) otherwise.
     */
    protected float getLowerOrderBackoff(final int[] ngram, final int startPos, final int endPos) {
        if (startPos == endPos) return 1.0f; // empty context
        final KneserNeyCounts counts = getCounts(ngram, startPos, endPos, true);
        final long backoffDenom = (endPos - startPos == lmOrder - 1 || ngram[startPos] == startIndex) ? counts.tokenCounts : counts.dotdotTypeCounts;
        assert backoffDenom >= 0;
        final float backoffDiscount = getDiscountForOrder(endPos - startPos);
        // Unseen context: no mass was discounted away, so the full weight 1.0 backs off.
        final float backoff = backoffDenom == 0.0f ? 1.0f : backoffDiscount * counts.rightDotTypeCounts / backoffDenom;
        return backoff;
    }

    /**
     * Returns the absolute discount D for the given n-gram order (0-based).
     *
     * If per-order discounts were supplied via ConfigOptions.kneserNeyDiscounts, those are
     * used directly. Otherwise this follows original Kneser-Ney discounting (SRILM's
     * -ukndiscount): one constant per order, estimated as
     *
     * D = n1 / (n1 + 2*n2)
     *
     * where n1 and n2 are the total numbers of n-grams of this order with exactly one and
     * exactly two counts, respectively.
     */
    protected float getDiscountForOrder(int ngramOrder) {
        if (opts.kneserNeyDiscounts != null) return (float) opts.kneserNeyDiscounts[ngramOrder];
        final int numOneCounters = ((KneserNeyCountValueContainer) ngrams.getValues()).getNumOneCountNgrams(ngramOrder);
        final int numTwoCounters = ((KneserNeyCountValueContainer) ngrams.getValues()).getNumTwoCountNgrams(ngramOrder);
        final float denom = (numOneCounters + 2 * (float) numTwoCounters);
        // Degenerate case (no singletons or doubletons): fall back to a tiny positive discount.
        return denom == 0.0f ? 1e-5f : numOneCounters / denom;
    }

    /** No-op: all output happens in {@link #parse(ArpaLmReaderCallback)}. */
    @Override
    public void cleanup() {
    }

    /**
     * Looks up the stored counts for key[startPos, endPos), adjusting the raw statistics
     * for n-grams that begin with the start symbol or end with the end symbol.
     *
     * @param key word ids
     * @param startPos inclusive start
     * @param endPos exclusive end
     * @param isBackoff whether the caller needs the counts for a backoff-weight
     *            denominator (affects the start-symbol token-count substitution below)
     * @return a zeroed KneserNeyCounts if the n-gram is not in the map
     */
    private KneserNeyCounts getCounts(final int[] key, final int startPos, final int endPos, final boolean isBackoff) {
        final KneserNeyCounts value = new KneserNeyCounts();
        if (startPos == endPos) {
            // Empty n-gram: only used to request the total number of bigram types n(*_*).
            value.dotdotTypeCounts = ((KneserNeyCountValueContainer) ngrams.getValues()).getBigramTypeCounts();
            return value;
        }
        final long offset = ngrams.getOffsetForNgramInModel(key, startPos, endPos);
        if (offset < 0) return value; // unseen n-gram: all counts stay zero
        ngrams.getValues().getFromOffset(offset, endPos - startPos - 1, value);
        final boolean startsWithStartSym = key[startPos] == startIndex;
        final boolean endsWithEndSym = key[endPos - 1] == wordIndexer.getIndexPossiblyUnk(wordIndexer.getEndSymbol());
        // NOTE(review): the substitutions below stand in for continuation counts that are
        // undefined at the sentence boundaries (<s> is never a continuation, </s> is never
        // continued) — verify against KneserNeyCountValueContainer's accounting.
        if (startsWithStartSym) {
            value.dotdotTypeCounts = value.rightDotTypeCounts;
            if (endPos - startPos < lmOrder - 1 || (endPos - startPos == lmOrder - 1 && !isBackoff)) value.tokenCounts = value.leftDotTypeCounts;
        }
        if (endsWithEndSym) {
            value.rightDotTypeCounts = 1;
            value.dotdotTypeCounts = value.leftDotTypeCounts;
        }
        return value;
    }

    /** Default per-order discounts: {@link #DEFAULT_DISCOUNT} at every order. */
    public static double[] defaultDiscounts() {
        return constantArray(defaultMinCounts().length, DEFAULT_DISCOUNT);
    }

    /** Default per-order minimum counts (same values as SRILM). */
    public static double[] defaultMinCounts() {
        return new double[] { 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2 };
    }

    /** Returns a length-n array with every element set to f. */
    private static double[] constantArray(final int n, final double f) {
        final double[] ret = new double[n];
        Arrays.fill(ret, f);
        return ret;
    }

    /**
     * "Reads" the accumulated counts: converts them to Kneser-Ney probabilities and
     * backoffs and streams them, order by order, to the callback. A first pass counts the
     * surviving n-grams per order so the callback can be given the ARPA section lengths up
     * front; a second pass computes and emits each entry.
     */
    @Override
    public void parse(ArpaLmReaderCallback<ProbBackoffPair> callback) {
        Logger.startTrack("Writing Kneser-Ney probabilities");
        List<Long> lengths = new ArrayList<Long>();
        for (int ngramOrder = 0; ngramOrder < lmOrder; ++ngramOrder) {
            Logger.startTrack("Counting counts for order " + ngramOrder);
            long numNgrams = 0;
            for (final Entry<KneserNeyCounts> entry : ngrams.getNgramsForOrder(ngramOrder)) {
                final long relevantCount = entry.value.tokenCounts;
                // Min-count pruning applies only to the two highest orders.
                // NOTE(review): presumably lower orders are kept because their entries
                // are needed as contexts/continuation statistics — confirm intent.
                if (ngramOrder >= lmOrder - 2 && ngramOrder < opts.kneserNeyMinCounts.length && relevantCount < opts.kneserNeyMinCounts[ngramOrder]) continue;
                numNgrams++;
            }
            lengths.add(numNgrams);
            Logger.endTrack();
        }
        callback.initWithLengths(lengths);
        for (int ngramOrder = 0; ngramOrder < lmOrder; ++ngramOrder) {
            callback.handleNgramOrderStarted(ngramOrder + 1);
            Logger.logss("On order " + (ngramOrder + 1));
            int linenum = 0;
            for (final Entry<KneserNeyCounts> entry : ngrams.getNgramsForOrder(ngramOrder)) {
                if (linenum++ % 10000 == 0) Logger.logs("Writing line " + linenum);
                final long relevantCount = entry.value.tokenCounts;
                // Must mirror the pruning test used in the counting pass above.
                if (ngramOrder >= lmOrder - 2 && ngramOrder < opts.kneserNeyMinCounts.length && relevantCount < opts.kneserNeyMinCounts[ngramOrder]) continue;
                final int[] ngram = entry.key;
                final int endPos = ngram.length;
                final int startPos = 0;
                ProbBackoffPair value = getProbBackoff(ngram, startPos, endPos);
                callback.call(ngram, startPos, endPos, value, "");
            }
            callback.handleNgramOrderFinished(ngramOrder + 1);
        }
        callback.cleanup();
        Logger.endTrack();
    }

    /**
     * Computes the ARPA-style (log10 probability, log10 backoff) pair for one n-gram.
     *
     * @param ngram word ids
     * @param startPos inclusive start
     * @param endPos exclusive end
     * @return probability and backoff, both in log10 space (backoff 0 for highest order)
     */
    private ProbBackoffPair getProbBackoff(final int[] ngram, final int startPos, final int endPos) {
        final int ngramOrder = endPos - startPos - 1;
        final boolean isHighestOrder = ngramOrder == lmOrder - 1;
        // N-grams starting with the start symbol use token-count probabilities rather than
        // continuation counts even at lower orders (mirrors the special-casing in getCounts).
        final float val = isHighestOrder || ngram[startPos] == startIndex ? getHighestOrderProb(ngram, startPos, endPos) : getLowerOrderProb(ngram, startPos,
            endPos);
        // Skip past any run of start symbols before interpolating with shorter suffixes.
        int nextNonStart = startPos + 1;
        while (nextNonStart < endPos && ngram[nextNonStart] == startIndex) {
            nextNonStart++;
        }
        final float prob = val + getLowerOrderBackoff(ngram, startPos, endPos - 1) * interpolateProb(ngram, nextNonStart, endPos);
        // The lone start-symbol unigram gets the conventional ARPA placeholder -99.
        final boolean isStartEndSym = endPos - startPos == 1 && ngram[startPos] == startIndex;
        final float logProb = isStartEndSym ? -99 : ((float) (Math.log10(prob)));
        final float backoff = isHighestOrder ? 0.0f : (float) Math.log10(getLowerOrderBackoff(ngram, startPos, endPos));
        final ProbBackoffPair ret = new ProbBackoffPair(logProb, backoff);
        return ret;
    }

    public WordIndexer<W> getWordIndexer() {
        return wordIndexer;
    }

    /** No-op: no per-order state is kept while reading input. */
    @Override
    public void handleNgramOrderFinished(int order) {
    }

    /** No-op: no per-order state is kept while reading input. */
    @Override
    public void handleNgramOrderStarted(int order) {
    }

    @Override
    public int getLmOrder() {
        return lmOrder;
    }

    @Override
    public float scoreSentence(List<W> sentence) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.scoreSentence(sentence, this);
    }

    @Override
    public float getLogProb(List<W> ngram) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
    }

    /**
     * Log10 probability of the n-gram. Note that the probability is recomputed from the raw
     * counts on every call via {@link #getProbBackoff(int[], int, int)}.
     */
    @Override
    public float getLogProb(int[] ngram, int startPos, int endPos) {
        ProbBackoffPair probBackoff = getProbBackoff(ngram, startPos, endPos);
        return probBackoff.prob;
    }

    @Override
    public float getLogProb(int[] ngram) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
    }

    /** Delegates to the underlying n-gram map's getTotalSize(). */
    public long getTotalSize() {
        return ngrams.getTotalSize();
    }

    /** Not supported: OOV probability cannot be overridden on this count-based model. */
    @Override
    public void setOovWordLogProb(float logProb) {
        throw new UnsupportedOperationException("Method not yet implemented");
    }
}