package edu.berkeley.nlp.lm.io;
import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.regex.Pattern;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.collections.Iterators;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
/**
 * Reads n-gram count collections laid out like the Google n-grams Web1T
 * corpus.
 * <p>
 * The root directory is expected to contain one sub-directory per n-gram
 * order whose name ends in {@code gms} (e.g. {@code 1gms}, {@code 2gms}).
 * Directories are ordered by their leading number, and the i-th directory
 * (0-based) is treated as holding n-grams of order i+1:
 * <ul>
 * <li>the first directory must contain exactly one file named
 * {@code vocab_cs.gz}; it seeds the word indexer and is then also read as
 * the unigram counts;</li>
 * <li>every later directory is scanned for files named
 * {@code <k>gm-<digits>} with an optional {@code .gz} suffix.</li>
 * </ul>
 * Each non-blank line has the form {@code w1 w2 ... wk<TAB>count}.
 *
 * @author adampauls
 *
 * @param <W> the word type of the {@link WordIndexer}
 */
public class GoogleLmReader<W> implements LmReader<LongRef, NgramOrderedLmReaderCallback<LongRef>>
{

    private static final String START_SYMBOL = "<S>";

    private static final String END_SYMBOL = "</S>";

    private static final String UNK_SYMBOL = "<UNK>";

    /** Name of the count-sorted vocabulary file expected in the unigram directory. */
    private static final String SORTED_VOCAB_FILE = "vocab_cs.gz";

    /** A progress message is logged every this many lines of an n-gram file. */
    private static final int LOG_INTERVAL = 10000;

    /** Per-order n-gram directories, sorted so that index i holds n-grams of order i+1. */
    private final File[] ngramDirectories;

    /** Number of n-gram directories found, i.e. the highest order this reader produces. */
    private final int lmOrder;

    /** Receives every word encountered in the vocab and n-gram files. */
    private final WordIndexer<W> wordIndexer;

    /**
     * Locates the per-order n-gram directories under {@code rootDir}.
     *
     * @param rootDir directory containing the {@code *gms} sub-directories
     * @param wordIndexer indexer that words are added to while parsing
     * @param opts currently unused
     * @throws IllegalArgumentException if {@code rootDir} cannot be listed
     *         (it does not exist, is not a directory, or an I/O error occurred)
     */
    public GoogleLmReader(final String rootDir, final WordIndexer<W> wordIndexer, @SuppressWarnings("unused") final ConfigOptions opts) {
        this.wordIndexer = wordIndexer;
        final File[] dirs = new File(rootDir).listFiles(new FilenameFilter()
        {
            @Override
            public boolean accept(final File dir, final String name) {
                return name.endsWith("gms") && new File(dir, name).isDirectory();
            }
        });
        // File.listFiles signals failure with null rather than an exception.
        if (dirs == null) {
            throw new IllegalArgumentException("Could not list n-gram directories in " + rootDir);
        }
        // Sort numerically so that e.g. "10gms" follows "9gms"; a plain lexicographic
        // sort would place it right after "1gms" and misalign orders with directories.
        Arrays.sort(dirs, new Comparator<File>()
        {
            @Override
            public int compare(final File a, final File b) {
                final int orderA = leadingNumber(a.getName());
                final int orderB = leadingNumber(b.getName());
                if (orderA != orderB) return orderA < orderB ? -1 : 1;
                return a.compareTo(b);
            }
        });
        ngramDirectories = dirs;
        lmOrder = ngramDirectories.length;
    }

    /**
     * @return the number of n-gram directories found, i.e. the highest n-gram
     *         order this reader will produce
     */
    public int getLmOrder() {
        return lmOrder;
    }

    /**
     * Reads every n-gram file, order by order, and feeds each n-gram with its
     * count to {@code callback}. Before the unigrams are read, the vocab file is
     * loaded into the word indexer via {@link #addToIndexer}. The callback's
     * {@code handleNgramOrderFinished} is invoked after each order with the
     * (1-based) order just completed, and {@code cleanup()} once at the end.
     *
     * @param callback receiver of the parsed n-grams
     * @throws RuntimeException if a directory cannot be listed, the vocab file is
     *         missing, a file cannot be read, or a line is malformed
     */
    @Override
    public void parse(final NgramOrderedLmReaderCallback<LongRef> callback) {
        for (int i = 0; i < ngramDirectories.length; ++i) {
            // Captured by the anonymous filter below, so these must be final.
            final int ngramOrder = i;
            final File ngramDir = ngramDirectories[i];
            final String regex = (ngramOrder + 1) + "gm-\\d+(\\.gz)?";
            final Pattern ngramFilePattern = Pattern.compile(regex);
            final File[] ngramFiles = ngramDir.listFiles(new FilenameFilter()
            {
                @Override
                public boolean accept(final File dir, final String name) {
                    return ngramOrder == 0 ? name.equals(SORTED_VOCAB_FILE) : ngramFilePattern.matcher(name).matches();
                }
            });
            if (ngramFiles == null) {
                throw new RuntimeException("Could not list files in n-gram directory " + ngramDir);
            }
            if (ngramOrder == 0) {
                if (ngramFiles.length != 1) throw new RuntimeException("Could not find expected vocab file " + SORTED_VOCAB_FILE);
                // The vocab file defines the word indexing; it is then also parsed below as the unigram counts.
                addToIndexer(wordIndexer, ngramFiles[0].getPath());
            } else if (ngramFiles.length == 0) {
                Logger.warn("Did not find any files matching expected regex " + regex);
            }
            Arrays.sort(ngramFiles);
            Logger.startTrack("Reading ngrams of order " + (ngramOrder + 1));
            for (final File ngramFile : ngramFiles) {
                readNgramFile(ngramFile, ngramOrder, callback);
            }
            Logger.endTrack();
            callback.handleNgramOrderFinished(ngramOrder + 1);
        }
        callback.cleanup();
    }

    /**
     * Parses every non-blank line of one n-gram file as an n-gram with
     * {@code ngramOrder + 1} words and passes it to {@code callback}.
     *
     * @throws RuntimeException wrapping the underlying failure, with the
     *         1-based line number and file name, if a line cannot be parsed
     *         or the file cannot be read
     */
    private void readNgramFile(final File ngramFile, final int ngramOrder, final NgramOrderedLmReaderCallback<LongRef> callback) {
        Logger.startTrack("Reading ngrams from file " + ngramFile);
        try {
            int lineNumber = 0;
            for (final String rawLine : Iterators.able(IOUtils.lineIterator(ngramFile.getPath()))) {
                if (lineNumber % LOG_INTERVAL == 0) Logger.logs("Line " + lineNumber);
                lineNumber++;
                final String line = rawLine.trim();
                // Blank lines carry no n-gram; skip them rather than abort the whole parse.
                if (line.isEmpty()) continue;
                try {
                    parseLine(line, ngramOrder, callback);
                } catch (final RuntimeException e) {
                    throw new RuntimeException("Could not parse line " + lineNumber + " '" + line + "' from file " + ngramFile + "\n", e);
                }
            }
        } catch (final IOException e) {
            throw new RuntimeException("Could not read file " + ngramFile + "\n", e);
        }
        Logger.endTrack();
    }

    /**
     * Parses one line of the form {@code w1 w2 ... wk<TAB>count}, indexes its
     * words, and passes the n-gram and its count to {@code callback}.
     *
     * @param line trimmed, non-blank input line
     * @param ngramOrder 0-based order; the line must contain exactly
     *        {@code ngramOrder + 1} space-separated words
     * @param callback receiver of the parsed n-gram
     * @throws IllegalArgumentException if the line has no tab or the wrong number of words
     * @throws NumberFormatException if the count is not a valid {@code long}
     */
    private void parseLine(final String line, final int ngramOrder, final NgramOrderedLmReaderCallback<LongRef> callback) {
        final int tabIndex = line.indexOf('\t');
        if (tabIndex < 0) {
            throw new IllegalArgumentException("Expected a tab between the n-gram and its count");
        }
        final String words = line.substring(0, tabIndex);
        final int[] ngram = new int[ngramOrder + 1];
        int numWords = 0;
        int wordStart = 0;
        while (true) {
            // Search the n-gram part only, so the end index can never run past words.length().
            int wordEnd = words.indexOf(' ', wordStart);
            if (wordEnd < 0) wordEnd = words.length();
            if (numWords == ngram.length) {
                throw new IllegalArgumentException("Expected " + ngram.length + " words but found more");
            }
            ngram[numWords++] = wordIndexer.getOrAddIndexFromString(words.substring(wordStart, wordEnd));
            if (wordEnd == words.length()) break;
            wordStart = wordEnd + 1;
        }
        // Too few words would otherwise leave trailing 0 entries that look like real word indices.
        if (numWords != ngram.length) {
            throw new IllegalArgumentException("Expected " + ngram.length + " words but found " + numWords);
        }
        final long count = Long.parseLong(line.substring(tabIndex + 1));
        callback.call(ngram, 0, ngram.length, new LongRef(count), words);
    }

    /**
     * Adds every word of a count-sorted vocabulary file to {@code wordIndexer}
     * in file order, then registers the start, end, and unknown symbols via
     * {@link #addSpecialSymbols}. Each non-blank line is expected to look like
     * {@code word<TAB>count}; only the text before the first tab (or the whole
     * line if it has no tab) is used.
     *
     * @param wordIndexer indexer to populate
     * @param sortedVocabPath path of the vocab file; a warning is logged if its
     *        name is not {@code vocab_cs.gz}
     * @throws RuntimeException if the file cannot be read
     */
    public static <W> void addToIndexer(final WordIndexer<W> wordIndexer, final String sortedVocabPath) {
        if (!new File(sortedVocabPath).getName().equals(SORTED_VOCAB_FILE)) {
            Logger.warn("You have specified that " + sortedVocabPath + " is the count-sorted vocab file for Google n-grams, but it is usually named "
                + SORTED_VOCAB_FILE);
        }
        try {
            for (final String line : Iterators.able(IOUtils.lineIterator(sortedVocabPath))) {
                // Blank lines would otherwise be indexed as a spurious (empty or whitespace) word.
                if (line.trim().isEmpty()) continue;
                final int tabIndex = line.indexOf('\t');
                final String word = tabIndex < 0 ? line : line.substring(0, tabIndex);
                wordIndexer.getOrAddIndexFromString(word);
            }
        } catch (final IOException e) {
            throw new RuntimeException("Could not read vocab file " + sortedVocabPath, e);
        }
        addSpecialSymbols(wordIndexer);
    }

    /**
     * Looks up (adding if absent) {@code <S>}, {@code </S>} and {@code <UNK>}
     * in the indexer and sets them as its start, end, and unknown symbols.
     *
     * @param wordIndexer indexer to update
     */
    public static <W> void addSpecialSymbols(final WordIndexer<W> wordIndexer) {
        wordIndexer.setStartSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(START_SYMBOL)));
        wordIndexer.setEndSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(END_SYMBOL)));
        wordIndexer.setUnkSymbol(wordIndexer.getWord(wordIndexer.getOrAddIndexFromString(UNK_SYMBOL)));
    }

    /**
     * Returns the non-negative integer spelled by the leading ASCII digits of
     * {@code name}, or {@link Integer#MAX_VALUE} if there are none (or too many
     * to be sure of fitting in an int), so such names sort after numbered ones.
     */
    private static int leadingNumber(final String name) {
        int end = 0;
        while (end < name.length() && name.charAt(end) >= '0' && name.charAt(end) <= '9') {
            ++end;
        }
        // Nine decimal digits always fit in an int; longer prefixes are treated as unnumbered.
        if (end == 0 || end > 9) return Integer.MAX_VALUE;
        return Integer.parseInt(name.substring(0, end));
    }
}