package edu.berkeley.nlp.lm.values;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.collections.LongHashSet;
import edu.berkeley.nlp.lm.map.HashNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
/**
* Stored type and token counts necessary for estimating a Kneser-Ney language
* model
*
* @author adampauls
*
*/
public final class KneserNeyCountValueContainer implements ValueContainer<KneserNeyCountValueContainer.KneserNeyCounts>
{
/**
* Warning: type counts are stored internally as 32-bit ints.
*
*
*
* @author adampauls
*
*/
public static class KneserNeyCounts
{
public long tokenCounts = 0; // only stored for the highest- and second-highest-order n-grams
public long leftDotTypeCounts = 0; // N_{1+}(\cdot w) as in Chen and Goodman (1998), not stored for highest-order
public long rightDotTypeCounts = 0; // N_{1+}(w \cdot) as in Chen and Goodman (1998), not stored for highest-order
public long dotdotTypeCounts = 0; // N_{1+}(\dot w \dot) as in Chen and Goodman (1998), not stored for highest-order
// these two are used to compute the Kneser-Ney discount
public boolean isOneCount = false;
public boolean isTwoCount = false;
boolean isInternal = false;
}
private static final long serialVersionUID = 964277160049236607L;
@PrintMemoryCount
private LongArray tokenCounts; // for highest-order ngrams
private LongArray prefixTokenCounts;// for second-highest order n-grams
@PrintMemoryCount
private final LongArray[] rightDotTypeCounts;
@PrintMemoryCount
private final LongArray[] dotdotTypeCounts;
@PrintMemoryCount
private final LongArray[] leftDotTypeCounts;// secretly, only token counts are stored for n-grams starting with the start symbol
// @PrintMemoryCount
// private final LongArray[] lowestOrderTokenCounts;
@PrintMemoryCount
private final LongHashSet[] oneCountOffsets;
@PrintMemoryCount
private final LongHashSet[] twoCountOffsets;
private long bigramTypeCounts = 0;
private HashNgramMap<KneserNeyCounts> map;
private final int startIndex;
public KneserNeyCountValueContainer(final int maxNgramOrder, final int startIndex) {
this.startIndex = startIndex;
this.tokenCounts = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, Integer.MAX_VALUE);
this.prefixTokenCounts = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, Integer.MAX_VALUE);
this.oneCountOffsets = new LongHashSet[maxNgramOrder];
this.twoCountOffsets = new LongHashSet[maxNgramOrder];
rightDotTypeCounts = new LongArray[maxNgramOrder - 1];
leftDotTypeCounts = new LongArray[maxNgramOrder - 1];
dotdotTypeCounts = new LongArray[maxNgramOrder - 2];
for (int i = 0; i < maxNgramOrder; ++i) {
oneCountOffsets[i] = new LongHashSet();
twoCountOffsets[i] = new LongHashSet();
if (i < maxNgramOrder - 1) {
rightDotTypeCounts[i] = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, Integer.MAX_VALUE);
leftDotTypeCounts[i] = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, Integer.MAX_VALUE);
if (i < maxNgramOrder - 2) dotdotTypeCounts[i] = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, Integer.MAX_VALUE);
}
}
}
@Override
public KneserNeyCountValueContainer createFreshValues(long[] numNgramsForEachOrder) {
final KneserNeyCountValueContainer kneseryNeyCountValueContainer = new KneserNeyCountValueContainer(rightDotTypeCounts.length + 1, startIndex);
kneseryNeyCountValueContainer.bigramTypeCounts = this.bigramTypeCounts;
return kneseryNeyCountValueContainer;
}
@Override
public void getFromOffset(final long offset, final int ngramOrder, @OutputParameter final KneserNeyCounts outputVal) {
final boolean isHighestOrder = isHighestOrder(ngramOrder);
final boolean isSecondHighestOrder = isSecondHighestOrder(ngramOrder);
outputVal.tokenCounts = isHighestOrder ? tokenCounts.get(offset) : (isSecondHighestOrder ? getSafe(offset, prefixTokenCounts) : -1);
outputVal.rightDotTypeCounts = (int) ((isHighestOrder || (offset >= rightDotTypeCounts[ngramOrder].size())) ? -1 : rightDotTypeCounts[ngramOrder]
.get(offset));
outputVal.leftDotTypeCounts = (int) ((isHighestOrder || (offset >= leftDotTypeCounts[ngramOrder].size())) ? -1 : leftDotTypeCounts[ngramOrder]
.get(offset));
outputVal.dotdotTypeCounts = (int) ((isHighestOrder || isSecondHighestOrder || (offset >= dotdotTypeCounts[ngramOrder].size())) ? -1
: dotdotTypeCounts[ngramOrder].get(offset));
outputVal.isOneCount = oneCountOffsets[ngramOrder].containsKey(offset);
outputVal.isTwoCount = twoCountOffsets[ngramOrder].containsKey(offset);
outputVal.isInternal = true;
}
private static long getSafe(final long offset, final LongArray array) {
return offset >= array.size() ? 0 : array.get(offset);
}
@Override
public void trimAfterNgram(final int ngramOrder, final long size) {
}
@Override
public KneserNeyCounts getScratchValue() {
return new KneserNeyCounts();
}
@Override
public boolean add(final int[] ngram, final int startPos, final int endPos, final int ngramOrder, final long offset, final long contextOffset,
final int word, final KneserNeyCounts val, final long suffixOffset, final boolean ngramIsNew) {
if (val == null) return true;
final boolean startsWithStart = ngram[startPos] == startIndex;
if (isHighestOrder(ngramOrder) || startsWithStart) {
final long relevantCount = val.tokenCounts;
if (ngramIsNew) {
if (relevantCount == 1)
oneCountOffsets[ngramOrder].put(offset);
else if (relevantCount == 2) //
twoCountOffsets[ngramOrder].put(offset);
} else if (oneCountOffsets[ngramOrder].containsKey(offset)) {
oneCountOffsets[ngramOrder].remove(offset);
if (relevantCount == 1) //
twoCountOffsets[ngramOrder].put(offset);
} else if (twoCountOffsets[ngramOrder].containsKey(offset)) {
twoCountOffsets[ngramOrder].remove(offset);
}
}
if (val.tokenCounts > 0) {
if (!val.isInternal && startsWithStart && ngramOrder < leftDotTypeCounts.length) {
leftDotTypeCounts[ngramOrder].incrementCount(offset, val.tokenCounts);
}
// if (ngramIsNew) {
// if (val.tokenCounts == 1)
// oneCountOffsets[ngramOrder].put(offset);
// else if (val.tokenCounts == 2) //
// twoCountOffsets[ngramOrder].put(offset);
// } else if (oneCountOffsets[ngramOrder].containsKey(offset)) {
// oneCountOffsets[ngramOrder].remove(offset);
// if (val.tokenCounts == 1) //
// twoCountOffsets[ngramOrder].put(offset);
// } else if (twoCountOffsets[ngramOrder].containsKey(offset)) {
// twoCountOffsets[ngramOrder].remove(offset);
// }
//
// } else {
// if (val.isOneCount) oneCountOffsets[ngramOrder].put(offset);
// if (val.isTwoCount) twoCountOffsets[ngramOrder].put(offset);
}
assert !map.isReversed();
if (isHighestOrder(ngramOrder)) {
tokenCounts.incrementCount(offset, val.tokenCounts);
prefixTokenCounts.incrementCount(contextOffset, val.tokenCounts);
}
assert !(val.isInternal && !ngramIsNew);
if (ngramIsNew) {
if (val.isInternal) {
if (val.dotdotTypeCounts > 0) dotdotTypeCounts[ngramOrder].incrementCount(offset, val.dotdotTypeCounts);
if (val.leftDotTypeCounts > 0) leftDotTypeCounts[ngramOrder].incrementCount(offset, val.leftDotTypeCounts);
if (val.rightDotTypeCounts > 0) rightDotTypeCounts[ngramOrder].incrementCount(offset, val.rightDotTypeCounts);
if (val.isOneCount) oneCountOffsets[ngramOrder].put(offset);
if (val.isTwoCount) twoCountOffsets[ngramOrder].put(offset);
} else {
if (ngramOrder > 0) {
if (ngramOrder == 1) {
bigramTypeCounts++;
} else {
final long dotDotOffset = map.getPrefixOffset(suffixOffset, endPos - startPos - 2);//map.getOffsetForNgramInModel(ngram, startPos + 1, endPos - 1);
dotdotTypeCounts[ngramOrder - 2].incrementCount(dotDotOffset, 1);
}
final long leftDotOffset = suffixOffset; //map.getOffsetForNgramInModel(ngram, startPos + 1, endPos);
assert suffixOffset >= 0;
final long oldCount = leftDotOffset >= leftDotTypeCounts[ngramOrder - 1].size() ? 0 : leftDotTypeCounts[ngramOrder - 1].get(leftDotOffset);
if (oldCount == 0) {
oneCountOffsets[ngramOrder - 1].put(leftDotOffset);
} else if (oldCount == 1) {
oneCountOffsets[ngramOrder - 1].remove(leftDotOffset);
twoCountOffsets[ngramOrder - 1].put(leftDotOffset);
} else if (oldCount == 2) {
twoCountOffsets[ngramOrder - 1].remove(leftDotOffset);
}
leftDotTypeCounts[ngramOrder - 1].incrementCount(leftDotOffset, 1);
final long rightDotOffset = contextOffset;//map.getOffsetForNgramInModel(ngram, startPos, endPos - 1);
assert contextOffset >= 0;
rightDotTypeCounts[ngramOrder - 1].incrementCount(rightDotOffset, 1);
}
}
}
return true;
}
@Override
public void setSizeAtLeast(final long size, final int ngramOrder) {
if (isHighestOrder(ngramOrder)) {
tokenCounts.setAndGrowIfNeeded(size - 1, 0);
} else {
if (isSecondHighestOrder(ngramOrder)) prefixTokenCounts.setAndGrowIfNeeded(size - 1, 0);
leftDotTypeCounts[ngramOrder].setAndGrowIfNeeded(size - 1, 0);
rightDotTypeCounts[ngramOrder].setAndGrowIfNeeded(size - 1, 0);
if (!isSecondHighestOrder(ngramOrder)) dotdotTypeCounts[ngramOrder].setAndGrowIfNeeded(size - 1, 0);
}
}
/**
* @param ngramOrder
* @return
*/
private boolean isHighestOrder(final int ngramOrder) {
return ngramOrder == rightDotTypeCounts.length;
}
/**
* @param ngramOrder
* @return
*/
private boolean isSecondHighestOrder(final int ngramOrder) {
return ngramOrder == rightDotTypeCounts.length - 1;
}
@Override
public void setFromOtherValues(final ValueContainer<KneserNeyCounts> other) {
final KneserNeyCountValueContainer other_ = (KneserNeyCountValueContainer) other;
tokenCounts = other_.tokenCounts;
System.arraycopy(other_.dotdotTypeCounts, 0, dotdotTypeCounts, 0, dotdotTypeCounts.length);
System.arraycopy(other_.rightDotTypeCounts, 0, rightDotTypeCounts, 0, rightDotTypeCounts.length);
System.arraycopy(other_.leftDotTypeCounts, 0, leftDotTypeCounts, 0, leftDotTypeCounts.length);
System.arraycopy(other_.oneCountOffsets, 0, oneCountOffsets, 0, oneCountOffsets.length);
System.arraycopy(other_.twoCountOffsets, 0, twoCountOffsets, 0, twoCountOffsets.length);
prefixTokenCounts = other_.prefixTokenCounts;
bigramTypeCounts = other_.bigramTypeCounts;
}
@Override
public void trim() {
tokenCounts.trim();
prefixTokenCounts.trim();
for (int i = 0; i < rightDotTypeCounts.length; ++i) {
rightDotTypeCounts[i].trim();
leftDotTypeCounts[i].trim();
if (i < dotdotTypeCounts.length) dotdotTypeCounts[i].trim();
}
}
@Override
public void setMap(final NgramMap<KneserNeyCounts> map) {
this.map = (HashNgramMap<KneserNeyCounts>) map;
}
@Override
public void clearStorageForOrder(int ngramOrder) {
oneCountOffsets[ngramOrder].clear();
twoCountOffsets[ngramOrder].clear();
if (ngramOrder == rightDotTypeCounts.length) {
tokenCounts = null;
} else if (ngramOrder == rightDotTypeCounts.length - 1) {
prefixTokenCounts = null;
}
if (ngramOrder < rightDotTypeCounts.length) {
rightDotTypeCounts[ngramOrder] = null;
leftDotTypeCounts[ngramOrder] = null;
if (ngramOrder < dotdotTypeCounts.length) dotdotTypeCounts[ngramOrder] = null;
}
}
@Override
public boolean storeSuffixoffsets() {
return true;
}
public long getBigramTypeCounts() {
return bigramTypeCounts;
}
public int getNumOneCountNgrams(int ngramOrder) {
return oneCountOffsets[ngramOrder].size();
}
public int getNumTwoCountNgrams(int ngramOrder) {
return twoCountOffsets[ngramOrder].size();
}
@Override
public int numValueBits(int ngramOrder) {
return 0;
}
}