package edu.berkeley.nlp.lm.io;
import java.io.File;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel.LmContextInfo;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.map.HashNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap.Entry;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.util.StrUtils;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer.KneserNeyCounts;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
/**
 * Callback that writes a Kneser-Ney language model, produced from raw text, to ARPA format.
 * <p>
 * Output is emitted in the order the callback methods are invoked: {@link #initWithLengths(List)} writes the
 * {@code \data\} header with the n-gram counts, each n-gram order is framed by
 * {@link #handleNgramOrderStarted(int)} / {@link #handleNgramOrderFinished(int)}, {@link #call} writes one
 * n-gram line, and {@link #cleanup()} writes {@code \end\} and closes the writer.
 *
 * @author adampauls
 *
 * @param <W>
 *            word type of the {@link WordIndexer} used to turn n-gram word indices back into words
 */
public class KneserNeyFileWritingLmReaderCallback<W> implements ArpaLmReaderCallback<ProbBackoffPair>
{
    /** Destination of the ARPA text; closed by {@link #cleanup()}. */
    private final PrintWriter out;

    /** Maps word indices of incoming n-grams back to words. */
    private final WordIndexer<W> wordIndexer;

    /** True while the logging track opened by {@link #initWithLengths(List)} has not yet been closed. */
    private boolean trackStarted = false;

    /**
     * Creates a callback that writes to {@code outputFile}, opened via {@code IOUtils.openOutHard}.
     *
     * @param outputFile
     *            file the ARPA model is written to
     * @param wordIndexer
     *            indexer used to map n-gram word indices back to words
     */
    public KneserNeyFileWritingLmReaderCallback(final File outputFile, final WordIndexer<W> wordIndexer) {
        this(IOUtils.openOutHard(outputFile), wordIndexer);
    }

    /**
     * Creates a callback that writes to an already-open writer. The writer is closed by {@link #cleanup()}.
     *
     * @param out
     *            writer the ARPA model is written to
     * @param wordIndexer
     *            indexer used to map n-gram word indices back to words
     */
    public KneserNeyFileWritingLmReaderCallback(final PrintWriter out, final WordIndexer<W> wordIndexer) {
        this.wordIndexer = wordIndexer;
        this.out = out;
    }

    /** Ends an n-gram section with a blank line. */
    @Override
    public void handleNgramOrderFinished(int order) {
        out.println();
    }

    /**
     * Starts an n-gram section by writing its {@code \N-grams:} header.
     *
     * @param order
     *            written verbatim as N (NOTE(review): presumably 1-based, to match the counts written by
     *            {@link #initWithLengths(List)} — confirm against the caller)
     */
    @Override
    public void handleNgramOrderStarted(int order) {
        out.println("\\" + (order) + "-grams:");
    }

    /**
     * Writes one ARPA n-gram line: {@code prob<TAB>words}, followed by {@code <TAB>backoff} unless the n-gram
     * ends with the end symbol or its backoff is exactly {@code 0.0f}. Numbers use {@code %f} in
     * {@link Locale#US}, so the decimal separator is always '.'.
     *
     * @param ngram
     *            array holding word indices; only the slice {@code [startPos, endPos)} is this n-gram
     * @param startPos
     *            index of the first word of the n-gram (inclusive)
     * @param endPos
     *            index one past the last word of the n-gram (exclusive)
     * @param value
     *            probability and backoff weight to write
     * @param words
     *            not used by this implementation
     */
    @Override
    public void call(int[] ngram, int startPos, int endPos, ProbBackoffPair value, String words) {
        final String line = StrUtils.join(WordIndexer.StaticMethods.toList(wordIndexer, ngram, startPos, endPos));
        final int endSymbolIndex = wordIndexer.getIndexPossiblyUnk(wordIndexer.getEndSymbol());
        // Test the last word of the slice, not of the backing array: the array may be longer than the n-gram.
        final boolean endsWithEndSym = endPos > startPos && ngram[endPos - 1] == endSymbolIndex;
        if (endsWithEndSym || value.backoff == 0.0f) {
            out.printf(Locale.US, "%f\t%s\n", value.prob, line);
        } else {
            out.printf(Locale.US, "%f\t%s\t%f\n", value.prob, line, value.backoff);
        }
    }

    /**
     * Writes the {@code \end\} marker, closes the writer and closes the logging track opened by
     * {@link #initWithLengths(List)} (if any).
     *
     * @throws IllegalStateException
     *             if the writer reported an error at any point while writing or closing
     */
    @Override
    public void cleanup() {
        try {
            out.println("\\end\\");
            out.close();
            // PrintWriter never throws IOException; it only records failures in an internal flag.
            // Surface it so a truncated ARPA file (e.g. disk full) is not silently treated as success.
            if (out.checkError()) {
                throw new IllegalStateException("Error while writing ARPA file");
            }
        } finally {
            if (trackStarted) {
                trackStarted = false;
                Logger.endTrack();
            }
        }
    }

    /**
     * Opens a "Writing ARPA" logging track and writes the {@code \data\} header: one
     * {@code ngram k=count} line per order, with k starting at 1, surrounded by blank lines.
     *
     * @param numNGrams
     *            number of n-grams of each order; element i holds the count for order i + 1
     */
    @Override
    public void initWithLengths(List<Long> numNGrams) {
        Logger.startTrack("Writing ARPA");
        trackStarted = true;
        out.println();
        out.println("\\data\\");
        int order = 1;
        for (final long count : numNGrams) {
            out.println("ngram " + order + "=" + count);
            ++order;
        }
        out.println();
    }
}