package edu.berkeley.nlp.lm.values;
import java.util.List;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.bits.BitList;
import edu.berkeley.nlp.lm.bits.BitStream;
import edu.berkeley.nlp.lm.collections.Indexer;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap.Entry;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
/**
 * Value container for n-gram probabilities and backoffs that stores, per n-gram, only the rank of
 * its (probability, backoff) pair. The distinct pairs themselves are kept uncompressed in
 * {@link #probsAndBackoffsForRank}, one packed {@code long} per rank.
 */
public final class UncompressedProbBackoffValueContainer extends RankedValueContainer<ProbBackoffPair> implements ProbBackoffValueContainer
{

	private static final long serialVersionUID = 964277160049236607L;

	/**
	 * Packed value for each rank. Each {@code long} holds both the probability and the backoff
	 * (decoded with {@code ProbBackoffPair.probOf} / {@code ProbBackoffPair.backoffOf}), so both are
	 * read from the same array slot, which helps cache locality.
	 */
	@PrintMemoryCount
	final long[] probsAndBackoffsForRank;

	/**
	 * Maps a packed prob/backoff value to its rank. Not serialized, and set to {@code null} by
	 * {@link #trim()}.
	 */
	transient LongToIntHashMap countIndexer;

	/**
	 * Assigns a rank to every distinct packed prob/backoff value in {@code countCounter}, in the
	 * order returned by {@code countCounter.getObjectsSortedByValue(true)} (presumably most frequent
	 * first — confirm against {@code LongToIntHashMap}).
	 * <p>
	 * The default value ({@link #getDefaultVal()}, i.e. NaN/NaN) always receives a rank. If it is
	 * not already a key of {@code countCounter}, it is inserted at rank {@code defaultValRank} when
	 * that rank is reachable, and otherwise appended after the last value.
	 *
	 * @param countCounter map from packed value (as produced by {@code ProbBackoffPair.asLong()}) to
	 *            the count used to order the ranks
	 * @param valueRadix passed through to the superclass
	 * @param storePrefixes passed through to the superclass
	 * @param numNgramsForEachOrder passed through to the superclass
	 */
	public UncompressedProbBackoffValueContainer(final LongToIntHashMap countCounter, final int valueRadix, final boolean storePrefixes,
		long[] numNgramsForEachOrder) {
		super(valueRadix, storePrefixes, numNgramsForEachOrder);
		Logger.startTrack("Storing values");
		final long defaultVal = getDefaultVal().asLong();
		final boolean hasDefaultVal = countCounter.get(defaultVal, -1) >= 0;
		// One slot per distinct value, plus one for the default value if we must add it ourselves.
		probsAndBackoffsForRank = new long[(countCounter.size() + (hasDefaultVal ? 0 : 1))];
		countIndexer = new LongToIntHashMap();
		int k = 0;
		boolean needDefault = !hasDefaultVal;
		for (final Entry pair : countCounter.getObjectsSortedByValue(true)) {
			// Slot the default value in just before the entry that would otherwise take rank defaultValRank.
			if (needDefault && countIndexer.size() == defaultValRank) {
				countIndexer.put(defaultVal, countIndexer.size());
				probsAndBackoffsForRank[k++] = defaultVal;
				needDefault = false;
			}
			countIndexer.put(pair.key, countIndexer.size());
			probsAndBackoffsForRank[k++] = pair.key;
		}
		// defaultValRank was never reached inside the loop: append the default value at the end.
		if (needDefault) {
			countIndexer.put(defaultVal, countIndexer.size());
			probsAndBackoffsForRank[k] = defaultVal;
		}
		valueWidth = CustomWidthArray.numBitsNeeded(countIndexer.size());
		Logger.logss("Storing count indices using " + valueWidth + " bits.");
		Logger.endTrack();
	}

	/**
	 * Creates a container around an existing rank table. The array and map are stored by reference,
	 * not copied, which lets {@link #createFreshValues(long[])} share them between containers.
	 *
	 * @param valueRadix passed through to the superclass
	 * @param storePrefixIndexes passed through to the superclass
	 * @param numNgramsForEachOrder passed through to the superclass
	 * @param probsAndBackoffsForRank packed prob/backoff value for each rank
	 * @param countIndexer map from packed value to rank
	 * @param wordWidth number of bits used to store each rank; assigned to {@code valueWidth}
	 */
	public UncompressedProbBackoffValueContainer(int valueRadix, boolean storePrefixIndexes, long[] numNgramsForEachOrder, long[] probsAndBackoffsForRank,
		LongToIntHashMap countIndexer, int wordWidth) {
		super(valueRadix, storePrefixIndexes, numNgramsForEachOrder);
		this.countIndexer = countIndexer;
		this.probsAndBackoffsForRank = probsAndBackoffsForRank;
		super.valueWidth = wordWidth;
	}

	/**
	 * Returns a new container for {@code numNgramsForEachOrder_} that shares this container's rank
	 * table ({@code probsAndBackoffsForRank}, {@code countIndexer}) and value width.
	 */
	@Override
	public UncompressedProbBackoffValueContainer createFreshValues(long[] numNgramsForEachOrder_) {
		return new UncompressedProbBackoffValueContainer(valueRadix, storeSuffixIndexes, numNgramsForEachOrder_, probsAndBackoffsForRank, countIndexer,
			valueWidth);
	}

	/**
	 * Returns the probability stored for the n-gram at {@code index} of order {@code ngramOrder}.
	 */
	@Override
	public final float getProb(final int ngramOrder, final long index) {
		return getCount(ngramOrder, index, false);
	}

	/**
	 * Returns the undecoded rank stored for the n-gram at {@code index} of order {@code ngramOrder}.
	 */
	public final long getInternalVal(final int ngramOrder, final long index) {
		return valueRanks[ngramOrder].get(index);
	}

	/**
	 * Returns the probability for {@code index}, reading its rank directly from the given per-order
	 * rank array instead of looking the array up by n-gram order.
	 */
	public final float getProb(final CustomWidthArray valueRanksForOrder, final long index) {
		return getCount(valueRanksForOrder, index, false);
	}

	/**
	 * Decodes the prob/backoff pair for the n-gram at {@code index} of order {@code ngramOrder} into
	 * {@code outputVal}.
	 */
	@Override
	public void getFromOffset(final long index, final int ngramOrder, @OutputParameter final ProbBackoffPair outputVal) {
		final long rank = getRank(ngramOrder, index);
		getFromRank(rank, outputVal);
	}

	/**
	 * Looks up the rank of the n-gram at {@code index} of order {@code ngramOrder} and decodes one
	 * half of its packed value.
	 *
	 * @param ngramOrder order of the n-gram
	 * @param index offset of the n-gram within that order
	 * @param backoff {@code true} for the backoff, {@code false} for the probability
	 * @return the decoded probability or backoff
	 */
	private float getCount(final int ngramOrder, final long index, final boolean backoff) {
		final long rank = getRank(ngramOrder, index);
		return getFromRank(rank, backoff);
	}

	/**
	 * Same as {@link #getCount(int, long, boolean)}, but reads the rank directly from
	 * {@code valueRanksForOrder}.
	 */
	private float getCount(final CustomWidthArray valueRanksForOrder, final long index, final boolean backoff) {
		final long rank = valueRanksForOrder.get(index);
		return getFromRank(rank, backoff);
	}

	/**
	 * Decodes either the probability or the backoff stored for {@code rank}.
	 *
	 * @param rank index into {@link #probsAndBackoffsForRank}
	 * @param backoff {@code true} for the backoff, {@code false} for the probability
	 * @return the decoded value
	 */
	private float getFromRank(final long rank, final boolean backoff) {
		final long packed = probsAndBackoffsForRank[(int) rank];
		return backoff ? ProbBackoffPair.backoffOf(packed) : ProbBackoffPair.probOf(packed);
	}

	/**
	 * Returns the backoff stored for the n-gram at {@code index} of order {@code ngramOrder}.
	 */
	@Override
	public final float getBackoff(final int ngramOrder, final long index) {
		return getCount(ngramOrder, index, true);
	}

	/**
	 * Returns the backoff for {@code index}, reading its rank directly from the given per-order rank
	 * array.
	 */
	public final float getBackoff(final CustomWidthArray valueRanksForNgramOrder, final long index) {
		return getCount(valueRanksForNgramOrder, index, true);
	}

	/**
	 * Returns the default value, a pair whose probability and backoff are both NaN.
	 */
	@Override
	protected ProbBackoffPair getDefaultVal() {
		return new ProbBackoffPair(Float.NaN, Float.NaN);
	}

	/**
	 * Decodes both halves of the packed value stored for {@code rank} into {@code outputVal}.
	 */
	@Override
	protected void getFromRank(final long rank, @OutputParameter final ProbBackoffPair outputVal) {
		// Read the packed value once and decode both halves from it.
		final long packed = probsAndBackoffsForRank[(int) rank];
		outputVal.prob = ProbBackoffPair.probOf(packed);
		outputVal.backoff = ProbBackoffPair.backoffOf(packed);
	}

	/**
	 * Returns a freshly allocated pair (NaN, NaN) for use as an output parameter.
	 */
	@Override
	public ProbBackoffPair getScratchValue() {
		return new ProbBackoffPair(Float.NaN, Float.NaN);
	}

	/**
	 * Copies the superclass state from {@code o} and adopts its {@code countIndexer}.
	 * <p>
	 * NOTE(review): {@code probsAndBackoffsForRank} is final and is not taken from {@code o}; this
	 * presumably relies on both containers sharing the same array (as {@link #createFreshValues}
	 * arranges) — confirm for other callers.
	 *
	 * @throws ClassCastException if {@code o} is not an {@code UncompressedProbBackoffValueContainer}
	 */
	@Override
	public void setFromOtherValues(final ValueContainer<ProbBackoffPair> o) {
		super.setFromOtherValues(o);
		this.countIndexer = ((UncompressedProbBackoffValueContainer) o).countIndexer;
	}

	/**
	 * Trims superclass storage and drops the value-to-rank map. After this call
	 * {@link #getCountRank(long)} throws {@code NullPointerException}, because {@code countIndexer}
	 * is {@code null}.
	 */
	@Override
	public void trim() {
		super.trim();
		countIndexer = null;
	}

	/**
	 * Returns the rank assigned to packed value {@code val}, or {@code -1} if it has none.
	 */
	@Override
	protected long getCountRank(long val) {
		return countIndexer.get(val, -1);
	}

	/**
	 * Always returns {@code true}; presumably tells the superclass to allocate its per-order rank
	 * arrays — confirm against {@code RankedValueContainer}.
	 */
	@Override
	protected boolean useValueStoringArray() {
		return true;
	}
}