/*
* Copyright 1999-2002 Carnegie Mellon University.
* Portions Copyright 2002 Sun Microsystems, Inc.
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* Portions Copyright 2010 LIUM, University of Le Mans, France
*    -> Anthony Rousseau, Teva Merlin, Yannick Esteve
* All Rights Reserved. Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/
package edu.cmu.sphinx.linguist.language.ngram.large;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.Utilities;
import java.io.*;
import java.util.regex.Pattern;
import java.util.regex.Matcher;
/**
* Reads a binary NGram language model file ("DMP file") as generated by the SphinxBase
* sphinx_lm_convert tool.
* <p>
* Note that all probabilities in the grammar are stored in LogMath log base. Probabilities
* in the language model file itself are stored in log base 10; they are converted to the
* LogMath base when loaded.
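* <p>
* A minimal stand-alone usage sketch (hypothetical file name and format string; in the
* decoder this class is normally driven by LargeNGramModel, so the exact call sequence
* may differ):
* <pre>{@code
* BinaryLoader loader = new BinaryLoader(new File("model.DMP"), "DMP",
*         false, 1.0f, 1.0, 0.7f);
* int order = loader.getMaxDepth();          // e.g. 3 for a trigram model
* int unigrams = loader.getNumberUnigrams();
* String[] words = loader.getWords();
* loader.deallocate();
* }</pre>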
*/
public class BinaryLoader {
private static final String DARPA_TG_HEADER = "Darpa Trigram LM";
private static final String DARPA_QG_HEADER = "Darpa Quadrigram LM";
// For convenience, the NG header is a regular expression, so it contains 2 extra characters.
// Therefore, header.length() must be adjusted by -1 (and not +1),
// and Pattern.matches() is used to compare header names.
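// (e.g. an on-disk header such as "Darpa 5-gram LM" matches this pattern)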
private static final String DARPA_NG_HEADER = "Darpa \\d-gram LM";
private static final int LOG2_NGRAM_SEGMENT_SIZE = 9;
private static final float MIN_PROBABILITY = -99.0f;
private static final int MAX_PROB_TABLE_SIZE = Integer.MAX_VALUE;
private LogMath logMath;
private int maxNGram;
private float unigramWeight;
private float languageWeight;
private double wip;
private boolean bigEndian = true;
private boolean applyLanguageWeightAndWip;
private long bytesRead;
private UnigramProbability[] unigrams;
private String[] words;
private long[] NGramOffset;
private int[] numberNGrams;
private int logNGramSegmentSize;
private int startWordID;
private int endWordID;
private int[][] NGramSegmentTable;
private float[][] NGramProbTable;
private float[][] NGramBackoffTable;
private RandomAccessFile file;
// Bytes multiplier for LM (2 = 16 bits, 4 = 32 bits)
private int bytesPerField;
/**
* Initializes the binary loader
*
* @param location location of the model
* @param format file format
* @param applyLanguageWeightAndWip if true apply language weight and word insertion penalty
* @param languageWeight language weight
* @param wip word insertion probability
* @param unigramWeight unigram weight
* @throws IOException if an I/O error occurs
*/
public BinaryLoader(File location, String format,
boolean applyLanguageWeightAndWip,
float languageWeight, double wip, float unigramWeight)
throws IOException {
this(format, applyLanguageWeightAndWip, languageWeight, wip, unigramWeight);
loadModelLayout(new FileInputStream(location));
file = new RandomAccessFile(location, "r");
}
/**
* Initializes the binary loader
*
* @param format file format
* @param applyLanguageWeightAndWip if true apply language weight and word insertion penalty
* @param languageWeight language weight
* @param wip word insertion probability
* @param unigramWeight unigram weight
*/
public BinaryLoader(String format, boolean applyLanguageWeightAndWip,
float languageWeight, double wip, float unigramWeight) {
startWordID = -1;
endWordID = -1;
this.applyLanguageWeightAndWip = applyLanguageWeightAndWip;
logMath = LogMath.getLogMath();
this.languageWeight = languageWeight;
this.wip = wip;
this.unigramWeight = unigramWeight;
}
/**
* Closes the underlying language model file, if one was opened.
*
* @throws IOException if an I/O error occurs
*/
public void deallocate() throws IOException {
if (null != file)
file.close();
}
/**
* Returns the number of unigrams
*
* @return the number of unigrams
*/
public int getNumberUnigrams() {
return getNumberNGrams(1);
}
/**
* Returns the number of bigrams
*
* @return the number of bigrams
*/
public int getNumberBigrams() {
return getNumberNGrams(2);
}
/**
* Returns the number of trigrams
*
* @return the number of trigrams
*/
public int getNumberTrigrams() {
return getNumberNGrams(3);
}
/**
* Returns the number of NGrams of the specified order.
*
* @param n the desired order
* @return the number of NGrams
*/
public int getNumberNGrams(int n) {
// Make sure the requested order does not exceed the model's maximum order
assert (n <= maxNGram) && (n > 0);
return numberNGrams[n - 1];
}
/**
* Returns all the unigrams
*
* @return all the unigrams
*/
public UnigramProbability[] getUnigrams() {
return unigrams;
}
/**
* Returns all the bigram probabilities.
*
* @return all the bigram probabilities
*/
public float[] getBigramProbabilities() {
return getNGramProbabilities(2);
}
/**
* Returns all the trigram probabilities.
*
* @return all the trigram probabilities
*/
public float[] getTrigramProbabilities() {
return getNGramProbabilities(3);
}
/**
* Returns all the trigram backoff weights
*
* @return all the trigram backoff weights
*/
public float[] getTrigramBackoffWeights() {
return getNGramBackoffWeights(3);
}
/**
* Returns the trigram segment table.
*
* @return the trigram segment table
*/
public int[] getTrigramSegments() {
return getNGramSegments(3);
}
/**
* Returns the log of the bigram segment size
*
* @return the log of the bigram segment size
*/
public int getLogBigramSegmentSize() {
return logNGramSegmentSize;
}
/**
* Returns all the NGram probabilities of the specified order.
*
* @param n the desired order
* @return all the NGram probabilities
*/
public float[] getNGramProbabilities(int n) {
// Make sure the requested order does not exceed the model's maximum order
assert (n <= maxNGram) && (n > 1);
return NGramProbTable[n - 1];
}
/**
* Returns all the NGram backoff weights of the specified order.
*
* @param n the desired order
* @return all the NGram backoff weights
*/
public float[] getNGramBackoffWeights(int n) {
// Make sure the requested order does not exceed the model's maximum order
assert (n <= maxNGram) && (n > 2);
return NGramBackoffTable[n - 1];
}
/**
* Returns the NGram segment table of the specified order.
*
* @param n the desired order
* @return the NGram segment table
*/
public int[] getNGramSegments(int n) {
// Make sure the requested order does not exceed the model's maximum order
assert (n <= maxNGram) && (n > 2);
return NGramSegmentTable[n - 1];
}
/**
* Returns the log of the NGram segment size
*
* @return the log of the NGram segment size
*/
public int getLogNGramSegmentSize() {
return logNGramSegmentSize;
}
/**
* Returns all the words.
*
* @return all the words
*/
public String[] getWords() {
return words;
}
/**
* Returns the location (or offset) into the file where bigrams start.
*
* @return the location of the bigrams
*/
public long getBigramOffset() {
return getNGramOffset(2);
}
/**
* Returns the location (or offset) into the file where trigrams start.
*
* @return the location of the trigrams
*/
public long getTrigramOffset() {
return getNGramOffset(3);
}
/**
* Returns the location (or offset) into the file where the NGrams of the
* specified order start.
*
* @param n the desired order
* @return the location of the NGrams
*/
public long getNGramOffset(int n) {
// Make sure the requested order does not exceed the model's maximum order
assert (n <= maxNGram) && (n > 1);
return NGramOffset[n - 1];
}
/**
* Returns the maximum depth of the language model
*
* @return the maximum depth of the language model
*/
public int getMaxDepth() {
return maxNGram;
}
/**
* Returns true if the loaded file is in big-endian.
*
* @return true if the loaded file is big-endian
*/
public boolean getBigEndian() {
return bigEndian;
}
/**
* Returns the multiplier for the size of an NGram field
* (2 for 16-bit files, 4 for 32-bit files).
*
* @return the multiplier for the size of a NGram
*/
public int getBytesPerField() {
return bytesPerField;
}
/**
* Loads the contents of the model file, starting at the given position and for the given
* size, into a byte array.
*
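* A hypothetical example, given a loaded BinaryLoader instance {@code loader}
* (entry sizes follow the conventions used in skipNGrams()):
* <pre>{@code
* int entrySize = LargeNGramModel.BYTES_PER_NGRAM * loader.getBytesPerField();
* byte[] bigramBytes = loader.loadBuffer(loader.getBigramOffset(),
*         (loader.getNumberBigrams() + 1) * entrySize);
* }</pre>
*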
* @param position the starting position in the file
* @param size the number of bytes to load
* @return the loaded ByteBuffer
* @throws java.io.IOException if IO went wrong
*/
public byte[] loadBuffer(long position, int size) throws IOException {
// assert ((position + size) <= fileChannel.size());
file.seek(position);
byte[] bytes = new byte[size];
if (file.read(bytes) != size) {
throw new IOException("Incorrect number of bytes read. Size = " + size + ". Position =" + position + ".");
}
return bytes;
}
/**
* Loads the language model layout from the given input stream.
*
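* <p>
* The on-disk layout, as read below, is: the header and metadata, the unigram array,
* the packed NGram entries of every higher order (skipped here, with only their file
* offsets recorded), the quantized probability/backoff/segment tables for each order,
* and finally the NUL-separated word strings.
*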
* @param inputStream stream to read the language model data
* @throws java.io.IOException if IO went wrong
*/
protected void loadModelLayout(InputStream inputStream) throws IOException {
DataInputStream stream = new DataInputStream
(new BufferedInputStream(inputStream));
// read standard header string-size; set bigEndian flag
readHeader(stream);
// +1 is the sentinel unigram at the end
unigrams = readUnigrams(stream, numberNGrams[0] + 1, bigEndian);
skipNGrams(stream);
// Read the NGram prob & bow tables
for (int i = 1; i < maxNGram; i++) {
if (numberNGrams[i] > 0) {
if (i == 1) {
NGramProbTable[i] = readFloatTable(stream, bigEndian);
}
else {
NGramBackoffTable[i] = readFloatTable(stream, bigEndian);
NGramProbTable[i] = readFloatTable(stream, bigEndian);
int nMinus1gramSegmentSize = 1 << logNGramSegmentSize;
int NGramSegTableSize = ((numberNGrams[i - 1] + 1) / nMinus1gramSegmentSize) + 1;
NGramSegmentTable[i] = readIntTable(stream, bigEndian, NGramSegTableSize);
}
}
}
// read word string names
int wordsStringLength = readInt(stream, bigEndian);
if (wordsStringLength <= 0) {
throw new Error("Bad word string size: " + wordsStringLength);
}
// read the string of all words
this.words = readWords(stream, wordsStringLength, numberNGrams[0]);
// To be checked
if (startWordID > -1) {
UnigramProbability unigram = unigrams[startWordID];
unigram.setLogProbability(MIN_PROBABILITY);
}
if (endWordID > -1) {
UnigramProbability unigram = unigrams[endWordID];
unigram.setLogBackoff(MIN_PROBABILITY);
}
applyUnigramWeight();
if (applyLanguageWeightAndWip) {
// Unigrams are handled in applyUnigramWeight(); backoff tables only exist from index 2 upwards
for (int i = 1; i < maxNGram; i++) {
if (numberNGrams[i] > 0) {
applyLanguageWeight(NGramProbTable[i], languageWeight);
applyWip(NGramProbTable[i], wip);
if (i > 1) {
applyLanguageWeight(NGramBackoffTable[i], languageWeight);
}
}
}
}
stream.close();
}
/**
* Reads the LM file header
*
* @param stream the data stream of the LM file
* @throws java.io.IOException
*/
private void readHeader(DataInputStream stream) throws IOException {
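// The first 32-bit word is the header-string length and doubles as a magic number:
// if it matches no known header length when read as big-endian, it is byte-swapped
// and checked again, which is how little-endian dump files are detected.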
int headerLength = readInt(stream, bigEndian);
if ((headerLength != DARPA_TG_HEADER.length() + 1) && (headerLength != DARPA_QG_HEADER.length() + 1) && (headerLength != DARPA_NG_HEADER.length() - 1)) { // not big-endian
headerLength = Utilities.swapInteger(headerLength);
if (headerLength == (DARPA_TG_HEADER.length() + 1) || headerLength == (DARPA_QG_HEADER.length() + 1) || headerLength == (DARPA_NG_HEADER.length() - 1)) {
bigEndian = false;
// System.out.println("Little-endian");
} else {
throw new Error("Bad binary LM file magic number: "
+ headerLength + ", not an LM dumpfile?");
}
} else {
// System.out.println("Big-endian");
}
// read and verify standard header string
String header = readString(stream, headerLength - 1);
stream.readByte(); // read the '\0'
bytesRead++;
if (!header.equals(DARPA_TG_HEADER) && !header.equals(DARPA_QG_HEADER) && !Pattern.matches(DARPA_NG_HEADER, header)) {
throw new Error("Bad binary LM file header: " + header);
}
else {
if (header.equals(DARPA_TG_HEADER))
maxNGram = 3;
else if (header.equals(DARPA_QG_HEADER))
maxNGram = 4;
else {
Pattern p = Pattern.compile("\\d");
Matcher m = p.matcher(header);
// find() must succeed before group() may be called
if (m.find()) {
maxNGram = Integer.parseInt(m.group());
} else {
throw new Error("Bad binary LM file header: " + header);
}
}
}
// read LM filename string size and string
int fileNameLength = readInt(stream, bigEndian);
skipStreamBytes(stream, fileNameLength);
numberNGrams = new int[maxNGram];
NGramOffset = new long[maxNGram];
NGramProbTable = new float[maxNGram][];
NGramBackoffTable = new float[maxNGram][];
NGramSegmentTable = new int[maxNGram][];
numberNGrams[0] = 0;
logNGramSegmentSize = LOG2_NGRAM_SEGMENT_SIZE;
// read version number, if present. it must be <= 0.
int version = readInt(stream, bigEndian);
// System.out.println("Version: " + version);
bytesPerField = 2;
if (version <= 0) { // yes, it's the version number
readInt(stream, bigEndian); // read and skip timestamp
// A version <= -3 means the file uses 32-bit fields.
if (version <= -3)
bytesPerField = 4;
// read and skip format description
int formatLength;
for (;;) {
if ((formatLength = readInt(stream, bigEndian)) == 0) {
break;
}
bytesRead += stream.skipBytes(formatLength);
}
// read log NGram segment size if present
// only for 16 bits version 2 LM
if (version == -2) {
logNGramSegmentSize = readInt(stream, bigEndian);
if (logNGramSegmentSize < 1 || logNGramSegmentSize > 15) {
throw new Error("log2(bg_seg_sz) outside range 1..15");
}
}
numberNGrams[0] = readInt(stream, bigEndian);
} else {
numberNGrams[0] = version;
}
if (numberNGrams[0] <= 0) {
throw new Error("Bad number of unigrams: " + numberNGrams[0]
+ ", must be > 0.");
}
for (int i = 1; i < maxNGram; i++) {
if ((numberNGrams[i] = readInt(stream, bigEndian)) < 0) {
throw new Error("Bad number of " + (i + 1) + "-grams: " + numberNGrams[i]);
}
}
}
/**
* Skips the NGrams of the LM.
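* Only the starting file offset of each NGram order is recorded (in NGramOffset); the
* packed entries themselves are left on disk so they can be read later, for instance
* via loadBuffer().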
*
* @param stream
* the source of data
* @throws java.io.IOException
*/
private void skipNGrams(DataInputStream stream) throws IOException {
long bytesToSkip;
NGramOffset[1] = bytesRead;
bytesToSkip = (long) (numberNGrams[1] + 1) * (long) LargeNGramModel.BYTES_PER_NGRAM * getBytesPerField();
skipStreamBytes(stream, bytesToSkip);
for (int i = 2; i < maxNGram; i++) {
if (numberNGrams[i] > 0 && i < maxNGram - 1) {
NGramOffset[i] = bytesRead;
bytesToSkip = (long) (numberNGrams[i] + 1) * (long) LargeNGramModel.BYTES_PER_NGRAM * getBytesPerField();
skipStreamBytes(stream, bytesToSkip);
} else if (numberNGrams[i] > 0 && i == maxNGram - 1) {
NGramOffset[i] = bytesRead;
bytesToSkip = (long) (numberNGrams[i]) * (long) LargeNGramModel.BYTES_PER_NMAXGRAM * getBytesPerField();
skipStreamBytes(stream, bytesToSkip);
}
}
}
/**
* Reliable skip
*
* @param stream stream
* @param bytes number of bytes
*/
private void skipStreamBytes(DataInputStream stream, long bytes) throws IOException {
while (bytes > 0) {
long skipped = stream.skip(bytes);
bytesRead += skipped;
bytes -= skipped;
}
}
/** Apply the unigram weight to the set of unigrams */
private void applyUnigramWeight() {
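// Interpolate each unigram with the uniform distribution, working in the log domain:
// P'(w) = unigramWeight * P(w) + (1 - unigramWeight) / numberOfUnigrams
// The sentence-start word is skipped; its probability was forced to MIN_PROBABILITY
// during loading.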
float logUnigramWeight = logMath.linearToLog(unigramWeight);
float logNotUnigramWeight = logMath.linearToLog(1.0f - unigramWeight);
float logUniform = logMath.linearToLog(1.0f / (numberNGrams[0]));
float logWip = logMath.linearToLog(wip);
float p2 = logUniform + logNotUnigramWeight;
for (int i = 0; i < numberNGrams[0]; i++) {
UnigramProbability unigram = unigrams[i];
float p1 = unigram.getLogProbability();
if (i != startWordID) {
p1 += logUnigramWeight;
p1 = logMath.addAsLinear(p1, p2);
}
if (applyLanguageWeightAndWip) {
p1 = p1 * languageWeight + logWip;
unigram.setLogBackoff(unigram.getLogBackoff() * languageWeight);
}
unigram.setLogProbability(p1);
}
}
/** Apply the language weight to the given array of probabilities.
*/
private void applyLanguageWeight(float[] logProbabilities,
float languageWeight) {
for (int i = 0; i < logProbabilities.length; i++) {
logProbabilities[i] = logProbabilities[i] * languageWeight;
}
}
/** Apply the WIP to the given array of probabilities.
*/
private void applyWip(float[] logProbabilities, double wip) {
float logWip = logMath.linearToLog(wip);
for (int i = 0; i < logProbabilities.length; i++) {
logProbabilities[i] = logProbabilities[i] + logWip;
}
}
/**
* Reads the probability table from the given DataInputStream.
*
* @param stream the DataInputStream from which to read the table
* @param bigEndian true if the given stream is bigEndian, false otherwise
* @throws java.io.IOException
*/
private float[] readFloatTable(DataInputStream stream, boolean bigEndian)
throws IOException {
int numProbs = readInt(stream, bigEndian);
if (numProbs <= 0 || numProbs > MAX_PROB_TABLE_SIZE) {
throw new Error("Bad probabilities table size: " + numProbs);
}
float[] probTable = new float[numProbs];
for (int i = 0; i < numProbs; i++) {
//probTable[i] = readFloat(stream, bigEndian);
probTable[i] = logMath.log10ToLog(readFloat(stream, bigEndian));
}
return probTable;
}
/**
* Reads a table of integers from the given DataInputStream.
*
* @param stream the DataInputStream from which to read the table
* @param bigEndian true if the given stream is bigEndian, false otherwise
* @param tableSize the size of the NGram segment table
* @return the NGram segment table, which is an array of integers
* @throws java.io.IOException
*/
private int[] readIntTable(DataInputStream stream, boolean bigEndian,
int tableSize) throws IOException {
int numSegments = readInt(stream, bigEndian);
if (numSegments != tableSize) {
throw new Error("Bad NGram seg table size: " + numSegments);
}
int[] segmentTable = new int[numSegments];
for (int i = 0; i < numSegments; i++) {
segmentTable[i] = readInt(stream, bigEndian);
}
return segmentTable;
}
/**
* Read in the unigrams in the given DataInputStream.
*
* @param stream the DataInputStream to read from
* @param numberUnigrams the number of unigrams to read
* @param bigEndian true if the DataInputStream is big-endian, false otherwise
* @return an array of UnigramProbability indexed by the unigram ID
* @throws java.io.IOException
*/
private UnigramProbability[] readUnigrams(DataInputStream stream,
int numberUnigrams, boolean bigEndian) throws IOException {
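// Each on-disk unigram record holds four 32-bit fields: word ID, log10 probability,
// log10 backoff weight, and the index of the word's first bigram entry.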
UnigramProbability[] unigrams = new UnigramProbability[numberUnigrams];
for (int i = 0; i < numberUnigrams; i++) {
// read unigram ID, unigram probability, unigram backoff weight
int unigramID = readInt(stream, bigEndian);
/* Some tools that convert to DMP format don't store IDs in the unigrams */
if (unigramID < 1)
unigramID = i;
// if we're not reading the sentinel unigram at the end,
// make sure that the unigram IDs are consecutive
if (i != (numberUnigrams - 1)) {
assert (unigramID == i);
}
float unigramProbability = readFloat(stream, bigEndian);
float unigramBackoff = readFloat(stream, bigEndian);
int firstBigramEntry = readInt(stream, bigEndian);
float logProbability = logMath.log10ToLog(unigramProbability);
float logBackoff = logMath.log10ToLog(unigramBackoff);
unigrams[i] = new UnigramProbability(unigramID, logProbability,
logBackoff, firstBigramEntry);
}
return unigrams;
}
/**
* Reads an integer from the given DataInputStream.
*
* @param stream the DataInputStream to read from
* @param bigEndian true if the DataInputStream is in bigEndian, false otherwise
* @return the integer read
* @throws java.io.IOException
*/
private int readInt(DataInputStream stream, boolean bigEndian)
throws IOException {
bytesRead += 4;
if (bigEndian) {
return stream.readInt();
} else {
return Utilities.readLittleEndianInt(stream);
}
}
/**
* Reads a float from the given DataInputStream.
*
* @param stream the DataInputStream to read from
* @param bigEndian true if the DataInputStream is in bigEndian, false otherwise
* @return the float read
* @throws java.io.IOException
*/
private float readFloat(DataInputStream stream, boolean bigEndian)
throws IOException {
bytesRead += 4;
if (bigEndian) {
return stream.readFloat();
} else {
return Utilities.readLittleEndianFloat(stream);
}
}
/**
* Reads a string of the given length from the given DataInputStream. It is assumed that the DataInputStream
* contains 8-bit chars.
*
* @param stream the DataInputStream to read from
* @param length the number of characters in the returned string
* @return a string of the given length from the given DataInputStream
* @throws java.io.IOException
*/
private String readString(DataInputStream stream, int length)
throws IOException {
StringBuilder builder = new StringBuilder();
byte[] bytes = new byte[length];
stream.readFully(bytes);
bytesRead += length;
for (int i = 0; i < length; i++) {
builder.append((char) bytes[i]);
}
return builder.toString();
}
/**
* Reads a series of consecutive Strings from the given stream.
*
* @param stream the DataInputStream to read from
* @param length the total length in bytes of all the Strings
* @param numberUnigrams the number of Strings to read
* @return an array of the Strings read
* @throws java.io.IOException
*/
private String[] readWords(DataInputStream stream, int length,
int numberUnigrams) throws IOException {
String[] words = new String[numberUnigrams];
byte[] bytes = new byte[length];
stream.readFully(bytes);
bytesRead += length;
int s = 0;
int wordStart = 0;
for (int i = 0; i < length; i++) {
char c = (char) (bytes[i] & 0xFF);
if (c == '\0') {
// if it's the end of a string, add it to the 'words' array
words[s] = new String(bytes, wordStart, i - wordStart);
wordStart = i + 1;
if (words[s].equals(Dictionary.SENTENCE_START_SPELLING)) {
startWordID = s;
} else if (words[s].equals(Dictionary.SENTENCE_END_SPELLING)) {
endWordID = s;
}
s++;
}
}
assert (s == numberUnigrams);
return words;
}
}