package edu.berkeley.nlp.lm.values;
import java.util.List;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.bits.BitList;
import edu.berkeley.nlp.lm.bits.BitStream;
import edu.berkeley.nlp.lm.collections.Indexer;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap.Entry;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
/**
 * Value container for (prob, backoff) pairs that stores, per n-gram, a single
 * packed rank instead of two floats.
 * <p>
 * Every distinct prob float and every distinct backoff float is assigned an
 * index ({@link #probsForRank} / {@link #backoffsForRank} map the index back
 * to the float). The two indices are packed into one {@code long}: the
 * backoff index occupies the low {@code backoffWidth} bits and the prob index
 * the bits above it (see {@link #combine(int, int)}).
 * <p>
 * The two {@link Indexer}s are only needed while values are being added; they
 * are {@code transient} and are dropped by {@link #trim()}.
 */
public final class CompressibleProbBackoffValueContainer extends RankedValueContainer<ProbBackoffPair> implements ProbBackoffValueContainer
{

	private static final long serialVersionUID = 964277160049236607L;

	/** Backoff float for each backoff index. */
	@PrintMemoryCount
	final float[] backoffsForRank;

	/** Prob float for each prob index. */
	@PrintMemoryCount
	final float[] probsForRank;

	/** Number of low bits of a packed rank that hold the backoff index; -1 until set by a constructor. */
	private int backoffWidth = -1;

	/** Maps a prob float to its index; set to null by {@link #trim()}. */
	private transient Indexer<Float> probIndexer = new Indexer<Float>();

	/** Maps a backoff float to its index; set to null by {@link #trim()}. */
	private transient Indexer<Float> backoffIndexer = new Indexer<Float>();

	/**
	 * Builds the prob and backoff index tables from a counter over packed
	 * (prob, backoff) values.
	 * <p>
	 * The counts of all pairs sharing a prob (respectively backoff) are summed,
	 * and floats are then indexed in the order returned by
	 * {@code getObjectsSortedByValue(true)} (presumably most frequent first --
	 * confirm against {@link LongToIntHashMap}). If the default value is not
	 * among the counted pairs, its components are additionally indexed (see
	 * {@link #indexFloats}).
	 *
	 * @param countCounter
	 *            counter whose keys are decoded with
	 *            {@link ProbBackoffPair#probOf} / {@link ProbBackoffPair#backoffOf}
	 *            and whose values are accumulated as counts
	 * @param valueRadix
	 *            passed through to the superclass
	 * @param storePrefixes
	 *            passed through to the superclass
	 * @param numNgramsForEachOrder
	 *            passed through to the superclass
	 */
	public CompressibleProbBackoffValueContainer(final LongToIntHashMap countCounter, final int valueRadix, final boolean storePrefixes,
		long[] numNgramsForEachOrder) {
		super(valueRadix, storePrefixes, numNgramsForEachOrder);
		Logger.startTrack("Storing values");

		// getDefaultVal() always builds the same (NaN, NaN) pair, so one instance serves every use below.
		final ProbBackoffPair defaultVal = getDefaultVal();
		final boolean hasDefaultVal = countCounter.get(defaultVal.asLong(), -1) >= 0;

		// Aggregate the pair counts into per-prob and per-backoff counts, keyed by the raw float bits.
		final List<Entry> pairsSortedByCount = countCounter.getObjectsSortedByValue(true);
		final LongToIntHashMap probSorter = new LongToIntHashMap();
		final LongToIntHashMap backoffSorter = new LongToIntHashMap();
		for (final Entry pair : pairsSortedByCount) {
			probSorter.incrementCount(floatKey(ProbBackoffPair.probOf(pair.key)), pair.value);
			backoffSorter.incrementCount(floatKey(ProbBackoffPair.backoffOf(pair.key)), pair.value);
		}

		indexFloats(probSorter, probIndexer, !hasDefaultVal, defaultVal.prob);
		indexFloats(backoffSorter, backoffIndexer, !hasDefaultVal, defaultVal.backoff);

		probsForRank = toFloatArray(probIndexer);
		backoffsForRank = toFloatArray(backoffIndexer);

		backoffWidth = CustomWidthArray.numBitsNeeded(backoffIndexer.size());
		valueWidth = CustomWidthArray.numBitsNeeded(probIndexer.size()) + backoffWidth;

		Logger.logss("Storing count indices using " + valueWidth + " bits.");
		Logger.endTrack();
	}

	/**
	 * Widens the raw bits of a float to a non-negative long map key (the
	 * upper 32 bits are cleared so the sign bit does not extend).
	 */
	private static long floatKey(final float f) {
		return Float.floatToIntBits(f) & ((1L << Integer.SIZE) - 1);
	}

	/**
	 * Adds every float keyed in {@code sorter} to {@code indexer}, in the
	 * order of {@code sorter.getObjectsSortedByValue(true)}.
	 * <p>
	 * When {@code addDefault} is set, {@code defaultComponent} is indexed as
	 * soon as the indexer holds exactly {@code defaultValRank} entries, or
	 * after the loop if the indexer ended up with fewer than
	 * {@code defaultValRank} entries.
	 */
	private void indexFloats(final LongToIntHashMap sorter, final Indexer<Float> indexer, final boolean addDefault, final float defaultComponent) {
		for (final Entry entry : sorter.getObjectsSortedByValue(true)) {
			indexer.getIndex(Float.intBitsToFloat((int) entry.key));
			if (addDefault && indexer.size() == defaultValRank) {
				indexer.getIndex(defaultComponent);
			}
		}
		if (addDefault && indexer.size() < defaultValRank) {
			indexer.getIndex(defaultComponent);
		}
	}

	/** Copies the indexed floats, in index order as returned by {@code getObjects()}, into a primitive array. */
	private static float[] toFloatArray(final Indexer<Float> indexer) {
		final float[] values = new float[indexer.size()];
		int next = 0;
		for (float f : indexer.getObjects()) {
			values[next++] = f;
		}
		return values;
	}

	/**
	 * Packs a prob index and a backoff index into one rank.
	 *
	 * @param dprobIndex
	 *            non-negative prob index; stored above the low
	 *            {@code backoffWidth} bits
	 * @param dbackoffIndex
	 *            non-negative backoff index; stored in the low
	 *            {@code backoffWidth} bits
	 * @return the packed rank
	 */
	private long combine(int dprobIndex, int dbackoffIndex) {
		assert dprobIndex >= 0;
		assert dbackoffIndex >= 0;
		return (((long) dprobIndex) << backoffWidth) | dbackoffIndex;
	}

	/** Extracts the backoff index (low {@code backoffWidth} bits) from a packed rank. */
	private int backoffRankOf(long val) {
		final long lowBitsMask = (1L << backoffWidth) - 1;
		return (int) (val & lowBitsMask);
	}

	/** Extracts the prob index (bits above the low {@code backoffWidth} bits) from a packed rank. */
	private int probRankOf(long val) {
		return (int) (val >>> backoffWidth);
	}

	/**
	 * Creates a container that reuses already-built index tables. The arrays
	 * and indexers are stored by reference, not copied.
	 *
	 * @param valueRadix
	 *            passed through to the superclass
	 * @param storePrefixIndexes
	 *            passed through to the superclass
	 * @param numNgramsForEachOrder
	 *            passed through to the superclass
	 * @param probsForRank
	 *            prob float for each prob index
	 * @param backoffsForRank
	 *            backoff float for each backoff index
	 * @param probIndexer
	 *            indexer from prob float to prob index
	 * @param wordWidth
	 *            value assigned to the superclass's {@code valueWidth}
	 * @param backoffIndexer
	 *            indexer from backoff float to backoff index
	 * @param backoffWidth
	 *            number of low bits of a packed rank holding the backoff index
	 */
	public CompressibleProbBackoffValueContainer(int valueRadix, boolean storePrefixIndexes, long[] numNgramsForEachOrder, float[] probsForRank,
		float[] backoffsForRank, Indexer<Float> probIndexer, int wordWidth, Indexer<Float> backoffIndexer, int backoffWidth) {
		super(valueRadix, storePrefixIndexes, numNgramsForEachOrder);
		this.probsForRank = probsForRank;
		this.backoffsForRank = backoffsForRank;
		this.probIndexer = probIndexer;
		this.backoffIndexer = backoffIndexer;
		super.valueWidth = wordWidth;
		this.backoffWidth = backoffWidth;
	}

	/**
	 * Creates a new container for the given n-gram counts that shares this
	 * container's rank tables, indexers and bit widths.
	 */
	@Override
	public CompressibleProbBackoffValueContainer createFreshValues(long[] numNgramsForEachOrder_) {
		return new CompressibleProbBackoffValueContainer(valueRadix, storeSuffixIndexes, numNgramsForEachOrder_, probsForRank, backoffsForRank, probIndexer,
			valueWidth, backoffIndexer, backoffWidth);
	}

	/** Returns the prob stored for the n-gram at {@code index} of order {@code ngramOrder}. */
	public final float getProb(final int ngramOrder, final long index) {
		return getCount(ngramOrder, index, false);
	}

	/** Returns the raw packed rank stored at {@code index} in the rank array for {@code ngramOrder}. */
	public final long getInternalVal(final int ngramOrder, final long index) {
		return valueRanks[ngramOrder].get(index);
	}

	/** Returns the prob for the packed rank stored at {@code index} of the given rank array. */
	public final float getProb(final CustomWidthArray valueRanksForOrder, final long index) {
		return getCount(valueRanksForOrder, index, false);
	}

	/** Looks up the rank for the given offset and order and writes the decoded prob and backoff into {@code outputVal}. */
	@Override
	public void getFromOffset(final long index, final int ngramOrder, @OutputParameter final ProbBackoffPair outputVal) {
		getFromRank(getRank(ngramOrder, index), outputVal);
	}

	/**
	 * Decodes one component of the value stored for an n-gram.
	 *
	 * @param ngramOrder
	 *            order passed to {@code getRank}
	 * @param index
	 *            offset passed to {@code getRank}
	 * @param backoff
	 *            {@code true} to return the backoff, {@code false} to return
	 *            the prob
	 * @return the requested float
	 */
	private float getCount(final int ngramOrder, final long index, final boolean backoff) {
		return getFromRank(getRank(ngramOrder, index), backoff);
	}

	/** Same as {@link #getCount(int, long, boolean)}, but reads the packed rank directly from the given array. */
	private float getCount(final CustomWidthArray valueRanksForOrder, final long index, final boolean backoff) {
		return getFromRank(valueRanksForOrder.get(index), backoff);
	}

	/** Unpacks {@code val} and returns the backoff float if {@code backoff} is set, otherwise the prob float. */
	private float getFromRank(final long val, final boolean backoff) {
		if (backoff) {
			return backoffsForRank[backoffRankOf(val)];
		}
		return probsForRank[probRankOf(val)];
	}

	/** Returns the backoff stored for the n-gram at {@code index} of order {@code ngramOrder}. */
	public final float getBackoff(final int ngramOrder, final long index) {
		return getCount(ngramOrder, index, true);
	}

	/** Returns the backoff for the packed rank stored at {@code index} of the given rank array. */
	public final float getBackoff(final CustomWidthArray valueRanksForNgramOrder, final long index) {
		return getCount(valueRanksForNgramOrder, index, true);
	}

	/** Returns a fresh pair whose prob and backoff are both NaN. */
	@Override
	protected ProbBackoffPair getDefaultVal() {
		return new ProbBackoffPair(Float.NaN, Float.NaN);
	}

	/** Unpacks {@code rank} and writes the corresponding prob and backoff into {@code outputVal}. */
	@Override
	protected void getFromRank(final long rank, @OutputParameter final ProbBackoffPair outputVal) {
		outputVal.prob = getFromRank(rank, false);
		outputVal.backoff = getFromRank(rank, true);
	}

	/** Returns a fresh pair (both components NaN) that callers can use as an output buffer. */
	@Override
	public ProbBackoffPair getScratchValue() {
		return new ProbBackoffPair(Float.NaN, Float.NaN);
	}

	/**
	 * Delegates to the superclass and then adopts the other container's
	 * indexers. {@code o} is cast to this class, so passing any other
	 * implementation throws {@link ClassCastException}.
	 */
	@Override
	public void setFromOtherValues(final ValueContainer<ProbBackoffPair> o) {
		super.setFromOtherValues(o);
		final CompressibleProbBackoffValueContainer other = (CompressibleProbBackoffValueContainer) o;
		this.backoffIndexer = other.backoffIndexer;
		this.probIndexer = other.probIndexer;
	}

	/**
	 * Delegates to the superclass and then releases both indexers. After this
	 * call {@link #getCountRank(long)} can no longer be used, since it
	 * dereferences the indexers.
	 */
	@Override
	public void trim() {
		super.trim();
		probIndexer = null;
		backoffIndexer = null;
	}

	/** Returns the packed rank for a packed (prob, backoff) value by looking both components up in the indexers. */
	@Override
	protected long getCountRank(long val) {
		final int probIndex = probIndexer.getIndex(ProbBackoffPair.probOf(val));
		final int backoffIndex = backoffIndexer.getIndex(ProbBackoffPair.backoffOf(val));
		return combine(probIndex, backoffIndex);
	}

	/**
	 * Encodes the value at {@code offset} with {@code valueCoder}: the prob
	 * index always, followed by the backoff index unless {@code ngramOrder}
	 * is the last order ({@code numNgramsForEachOrder.length - 1}).
	 */
	@Override
	public BitList getCompressed(final long offset, final int ngramOrder) {
		final long rank = getRank(ngramOrder, offset);
		final BitList bits = valueCoder.compress(probRankOf(rank));
		if (ngramOrder < numNgramsForEachOrder.length - 1) {
			bits.addAll(valueCoder.compress(backoffRankOf(rank)));
		}
		return bits;
	}

	/**
	 * Reads one encoded value from {@code bits}, mirroring
	 * {@link #getCompressed(long, int)}: a prob index, then a backoff index
	 * unless {@code ngramOrder} is the last order.
	 *
	 * @param bits
	 *            stream to read from; it is always advanced past the value
	 * @param ngramOrder
	 *            order of the n-gram being decoded
	 * @param justConsume
	 *            if {@code true}, only advance the stream and leave
	 *            {@code outputVal} untouched
	 * @param outputVal
	 *            receives the decoded prob and backoff (backoff is 0 for the
	 *            last order); may be {@code null}, in which case nothing is
	 *            written
	 */
	@Override
	public final void decompress(final BitStream bits, final int ngramOrder, final boolean justConsume, @OutputParameter final ProbBackoffPair outputVal) {
		// The last order carries no backoff, so no backoff index was written for it.
		final boolean hasBackoff = ngramOrder < numNgramsForEachOrder.length - 1;
		final long probRank = valueCoder.decompress(bits);
		final long backoffRank = hasBackoff ? valueCoder.decompress(bits) : -1;
		if (justConsume) return;
		if (outputVal != null) {
			outputVal.prob = probsForRank[(int) probRank];
			outputVal.backoff = hasBackoff ? backoffsForRank[(int) backoffRank] : 0;
		}
	}
}