package edu.berkeley.nlp.lm.io;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
/**
 * A parser for ARPA LM files.
 * <p>
 * The reader is bound to a single file at construction time. Calling
 * {@link #parse(ArpaLmReaderCallback)} opens that file, reads the header
 * (the <code>ngram N=count</code> lines), and then streams every n-gram entry
 * up to the configured maximum order to the supplied callback. Words are
 * converted to integer ids through the {@link WordIndexer} given to the
 * constructor.
 *
 * @author Alex Bouchard-Cote
 * @author Adam Pauls
 */
public class ArpaLmReader<W> implements LmReader<ProbBackoffPair, ArpaLmReaderCallback<ProbBackoffPair>>
{

    public static final String START_SYMBOL = "<s>";

    public static final String END_SYMBOL = "</s>";

    public static final String UNK_SYMBOL = "<unk>";

    private BufferedReader reader;

    /** Order of the n-gram section currently being read (1 for unigrams). */
    private int currentNGramLength = 1;

    /** Number of entries read so far in the current n-gram section. */
    int currentNGramCount = 0;

    /**
     * The 1-based number of the line most recently returned by
     * {@link #readLine()}; 0 before any line has been read.
     */
    private int lineNumber = 0;

    private final WordIndexer<W> wordIndexer;

    private final int maxOrder;

    private final String file;

    /**
     * Reads the next line from the underlying reader and advances the line
     * counter used in error messages.
     *
     * @return the next line, or <code>null</code> at end of stream
     * @throws IOException
     *             if the underlying reader fails
     */
    protected String readLine() throws IOException {
        final String line = reader.readLine();
        if (line != null) {
            lineNumber++;
        }
        return line;
    }

    /**
     * Creates a reader for the given ARPA file. The file is not opened until
     * {@link #parse(ArpaLmReaderCallback)} is called.
     *
     * @param file
     *            path of the ARPA file to read
     * @param wordIndexer
     *            indexer used to map each word string to an integer id
     * @param maxNgramOrder
     *            highest n-gram order to read; sections of higher order are
     *            not parsed
     */
    public ArpaLmReader(final String file, final WordIndexer<W> wordIndexer, final int maxNgramOrder) {
        this.file = file;
        this.wordIndexer = wordIndexer;
        this.maxOrder = maxNgramOrder;
    }

    /**
     * Parses the ARPA file and reports its contents to the callback.
     * <p>
     * The per-order counts from the header are passed to
     * <code>initWithLengths</code>, each n-gram entry is passed to
     * <code>call</code>, and <code>cleanup</code> is invoked once parsing is
     * done. Afterwards the start, end and unknown-word symbols are registered
     * with the word indexer. The file is closed before this method returns,
     * whether parsing succeeds or fails.
     *
     * @param callback
     *            receiver of the parsed header counts and n-gram entries
     */
    @Override
    public void parse(final ArpaLmReaderCallback<ProbBackoffPair> callback) {
        currentNGramLength = 1;
        currentNGramCount = 0;
        lineNumber = 0;
        this.reader = IOUtils.openInHard(file);
        Logger.startTrack("Parsing ARPA language model file");
        try {
            final List<Long> numNGrams = parseHeader();
            callback.initWithLengths(numNGrams);
            parseNGrams(callback);
        } finally {
            // parseNGrams already closes the reader; this also covers a failure
            // while reading the header. Closing twice is harmless.
            closeReader();
        }
        Logger.endTrack();
        callback.cleanup();
        wordIndexer.setStartSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(START_SYMBOL)));
        wordIndexer.setEndSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(END_SYMBOL)));
        wordIndexer.setUnkSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(UNK_SYMBOL)));
    }

    /**
     * Reads the header of the file, up to and including the line that
     * contains <code>\1-grams:</code>.
     *
     * @return the counts announced by the <code>ngram N=count</code> lines, in
     *         file order, truncated to at most the maximum order given to the
     *         constructor
     * @throws RuntimeException
     *             if a count line is malformed, the reader fails, or the file
     *             ends before the <code>\1-grams:</code> marker
     */
    protected List<Long> parseHeader() {
        final String ngramTotalPrefix = "ngram ";
        final List<Long> numEachNgrams = new ArrayList<Long>();
        try {
            String headerLine = null;
            while ((headerLine = readLine()) != null) {
                if (headerLine.startsWith(ngramTotalPrefix)) {
                    final int equalsIndex = headerLine.indexOf('=');
                    if (equalsIndex < 0) {
                        throw new RuntimeException("Bad ARPA header line " + headerLine + " (" + describePosition() + ")");
                    }
                    // trim so that "ngram 1= 42" or trailing blanks do not break the parse
                    final long currNumNGrams = Long.parseLong(headerLine.substring(equalsIndex + 1).trim());
                    if (numEachNgrams.size() < maxOrder) {
                        numEachNgrams.add(currNumNGrams);
                    }
                }
                if (headerLine.contains("\\1-grams:")) {
                    return numEachNgrams;
                }
            }
        } catch (final NumberFormatException e) {
            throw new RuntimeException("Bad n-gram count in ARPA header (" + describePosition() + ")", e);
        } catch (final IOException e) {
            throw new RuntimeException(e);
        }
        throw new RuntimeException("Reached end of ARPA file without finding the \\1-grams: section (" + describePosition() + ")");
    }

    /**
     * Reads the n-gram sections that follow the header and forwards every
     * entry to the callback.
     * <p>
     * Blank lines are skipped. A line starting with a backslash other than
     * <code>\end</code> is treated as the start of the next-higher order
     * section. Reading stops at end of file or as soon as a section beyond
     * the maximum order begins. The reader is closed on every exit path.
     *
     * @param callback
     *            receiver of the order start/finish notifications and of the
     *            parsed entries
     * @throws RuntimeException
     *             if the reader fails or an entry line is malformed
     */
    protected void parseNGrams(final ArpaLmReaderCallback<ProbBackoffPair> callback) {
        int currLine = 0;
        Logger.startTrack("Reading 1-grams");
        callback.handleNgramOrderStarted(currentNGramLength);
        try {
            String line = null;
            // scratch buffer reused for every entry of the current order
            int[] ngramScratch = new int[currentNGramLength];
            while ((line = readLine()) != null) {
                if (currLine % 100000 == 0) {
                    Logger.logs("Read " + currLine + " lines");
                }
                currLine++;
                if (line.length() == 0) {
                    // nothing to do (skip blank lines)
                } else if (line.charAt(0) == '\\') {
                    // a new block of n-gram is beginning
                    if (!line.startsWith("\\end")) {
                        Logger.logs(currentNGramCount + " " + currentNGramLength + "-gram read.");
                        Logger.endTrack();
                        callback.handleNgramOrderFinished(currentNGramLength);
                        currentNGramLength++;
                        if (currentNGramLength > maxOrder) {
                            // higher orders were not requested; the finally block closes the reader
                            return;
                        }
                        ngramScratch = new int[currentNGramLength];
                        currentNGramCount = 0;
                        callback.handleNgramOrderStarted(currentNGramLength);
                        Logger.startTrack("Reading " + currentNGramLength + "-grams");
                    }
                } else {
                    parseLine(callback, line, ngramScratch);
                }
            }
        } catch (final IOException e) {
            throw new RuntimeException(e);
        } finally {
            closeReader();
        }
        Logger.endTrack();
        callback.handleNgramOrderFinished(currentNGramLength);
    }

    /**
     * Parses one tab-separated entry line of the form
     * <code>logProb TAB w1 w2 ... [TAB backoff]</code> and hands it to the
     * callback. A missing backoff column is reported as 0.
     *
     * @param callback
     *            receiver of the parsed entry
     * @param line
     *            the raw entry line
     * @param ngram
     *            scratch array, of the current order's length, that receives
     *            the word ids; it is reused for subsequent lines
     * @throws RuntimeException
     *             if the line has no tab, a numeric column cannot be parsed,
     *             the log probability is positive, or the number of words
     *             does not match the current order
     */
    private void parseLine(final ArpaLmReaderCallback<ProbBackoffPair> callback, final String line, final int[] ngram) {
        // this is a 2 or 3 columns n-gram entry
        final int firstTab = line.indexOf('\t');
        if (firstTab < 0) {
            throw badLine(line, null);
        }
        final int secondTab = line.indexOf('\t', firstTab + 1);
        final boolean hasBackOff = (secondTab >= 0);
        final int length = line.length();
        float logProbability;
        float backoff = 0.0f;
        try {
            // the first column contains the log pr
            logProbability = Float.parseFloat(line.substring(0, firstTab));
            // and its backoff, if specified
            if (hasBackOff) {
                backoff = Float.parseFloat(line.substring(secondTab + 1, length));
            }
        } catch (final NumberFormatException e) {
            throw badLine(line, e);
        }
        if (logProbability > 0.0) {
            throw badLine(line, null);
        }
        // validate the numeric columns first so that a rejected line does not
        // add its words to the indexer
        parseNGram(line, firstTab + 1, hasBackOff ? secondTab : length, ngram);
        // add the new n-gram
        callback.call(ngram, 0, ngram.length, new ProbBackoffPair(logProbability, backoff), line);
        currentNGramCount++;
    }

    /**
     * Splits the word column of an entry line on single spaces and stores the
     * id of each word in <code>retVal</code>.
     *
     * @param string
     *            the full entry line
     * @param start
     *            index of the first character of the word column
     * @param stringLength
     *            index just past the last character of the word column
     * @param retVal
     *            array to fill; its length is the expected number of words
     * @throws RuntimeException
     *             if the column does not contain exactly
     *             <code>retVal.length</code> words
     */
    private void parseNGram(final String string, int start, int stringLength, final int[] retVal) {
        int k = 0;
        int wordStart = start;
        while (true) {
            int wordEnd = string.indexOf(' ', wordStart);
            if (wordEnd < 0 || wordEnd > stringLength) {
                // no further separator inside the word column
                wordEnd = stringLength;
            }
            if (k >= retVal.length) {
                throw badLine(string, null);
            }
            retVal[k++] = wordIndexer.getOrAddIndexFromString(string.substring(wordStart, wordEnd));
            if (wordEnd >= stringLength) {
                break;
            }
            wordStart = wordEnd + 1;
        }
        if (k != retVal.length) {
            // too few words: the tail of the array would still hold ids from a previous line
            throw badLine(string, null);
        }
    }

    /**
     * Builds the exception thrown for a malformed entry line.
     *
     * @param line
     *            the offending line
     * @param cause
     *            the underlying failure, or <code>null</code> if there is none
     * @return the exception to throw
     */
    private RuntimeException badLine(final String line, final Throwable cause) {
        return new RuntimeException("Bad ARPA line " + line + " (" + describePosition() + ")", cause);
    }

    /**
     * Describes the current read position for use in error messages.
     *
     * @return the current line number and file name
     */
    private String describePosition() {
        return "line " + lineNumber + " of " + file;
    }

    /**
     * Closes the underlying reader if one is open, converting an I/O failure
     * into a {@link RuntimeException} like the rest of this class.
     */
    private void closeReader() {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (final IOException e) {
            throw new RuntimeException(e);
        }
    }
}