package edu.berkeley.nlp.lm.values;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.bits.BitList;
import edu.berkeley.nlp.lm.bits.BitStream;
import edu.berkeley.nlp.lm.collections.Indexer;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap;
import edu.berkeley.nlp.lm.collections.LongToIntHashMap.Entry;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
import edu.berkeley.nlp.lm.util.LongRef;
/**
 * Value container that stores n-gram counts indirectly: each n-gram holds a
 * small integer rank, and {@link #countsForRank} maps that rank back to the
 * actual count.
 * <p>
 * The reverse mapping (count to rank) lives in {@link #countIndexer}. That map
 * is {@code transient} and is set to {@code null} by {@link #trim()}, so
 * {@link #getCountRank(long)} can only be used before the container is trimmed
 * (and not on an instance restored through Java serialization).
 */
public final class CountValueContainer extends RankedValueContainer<LongRef>
{
    private static final long serialVersionUID = 964277160049236607L;

    /** Count value for each rank; indexed by the ranks stored in {@code valueRanks}. */
    @PrintMemoryCount
    private final long[] countsForRank;

    /** Maps a count value to its rank; {@code null} after {@link #trim()}. */
    private transient LongToIntHashMap countIndexer;

    /** Sum of the counts of all order-0 entries, computed by {@link #trimAfterNgram(int, long)}. */
    private long unigramSum = 0L;

    /**
     * Builds the rank table from a map whose keys are the distinct count values.
     * <p>
     * Ranks are handed out in the iteration order of
     * {@code countCounter.getObjectsSortedByValue(true)} (presumably most
     * frequent count first -- confirm against {@code LongToIntHashMap}). If the
     * default value is not already a key of {@code countCounter}, it is inserted
     * so that it receives rank {@code defaultValRank} when at least that many
     * counts exist, or the last rank otherwise.
     *
     * @param countCounter
     *            map keyed by count value
     * @param valueRadix
     *            forwarded to the superclass constructor
     * @param storePrefixes
     *            forwarded to the superclass constructor
     * @param numNgramsForEachOrder
     *            forwarded to the superclass constructor
     */
    public CountValueContainer(final LongToIntHashMap countCounter, final int valueRadix, final boolean storePrefixes, final long[] numNgramsForEachOrder) {
        super(valueRadix, storePrefixes, numNgramsForEachOrder);
        // Evaluate once: getDefaultVal() allocates a new LongRef on every call.
        final long defaultCount = getDefaultVal().asLong();
        final boolean hasDefaultVal = countCounter.get(defaultCount, -1) >= 0;
        // Reserve one extra slot when the default value has to be added explicitly.
        countsForRank = new long[countCounter.size() + (hasDefaultVal ? 0 : 1)];
        countIndexer = new LongToIntHashMap();
        int next = 0;
        for (final Entry pair : countCounter.getObjectsSortedByValue(true)) {
            next = appendCount(countIndexer, countsForRank, next, pair.key);
            // Slot the default value in so that its rank equals defaultValRank.
            // NOTE(review): this check runs after an insertion, so it can never
            // fire for defaultValRank <= 0 -- confirm the inherited constant is positive.
            if (!hasDefaultVal && countIndexer.size() == defaultValRank) {
                next = appendCount(countIndexer, countsForRank, next, defaultCount);
            }
        }
        // Fewer distinct counts than defaultValRank: the default value goes last.
        if (!hasDefaultVal && countIndexer.size() < defaultValRank) {
            next = appendCount(countIndexer, countsForRank, next, defaultCount);
        }
        valueWidth = CustomWidthArray.numBitsNeeded(countIndexer.size());
    }

    /**
     * Creates a container that shares (by reference) the rank tables of an
     * existing one.
     *
     * @param valueRadix
     *            forwarded to the superclass constructor
     * @param storePrefixIndexes
     *            forwarded to the superclass constructor
     * @param numNgramsForEachOrder
     *            forwarded to the superclass constructor
     * @param countsForRank
     *            rank-to-count table to share
     * @param countIndexer
     *            count-to-rank map to share
     * @param wordWidth
     *            value assigned to {@code valueWidth}
     */
    private CountValueContainer(int valueRadix, boolean storePrefixIndexes, long[] numNgramsForEachOrder, long[] countsForRank, LongToIntHashMap countIndexer,
        int wordWidth) {
        super(valueRadix, storePrefixIndexes, numNgramsForEachOrder);
        this.countsForRank = countsForRank;
        this.countIndexer = countIndexer;
        this.valueWidth = wordWidth;
    }

    /**
     * Registers {@code count} under the next free rank and stores it in the
     * rank-to-count table.
     *
     * @return the next free position in {@code counts}
     */
    private static int appendCount(final LongToIntHashMap indexer, final long[] counts, final int next, final long count) {
        indexer.put(count, indexer.size());
        counts[next] = count;
        return next + 1;
    }

    /**
     * Returns a new container sized by {@code numNgramsForEachOrder_} that
     * reuses this container's rank tables, radix and value width.
     */
    @Override
    public CountValueContainer createFreshValues(long[] numNgramsForEachOrder_) {
        return new CountValueContainer(valueRadix, storeSuffixIndexes, numNgramsForEachOrder_, countsForRank, countIndexer, valueWidth);
    }

    /**
     * Writes the count stored at {@code index} for the given order into
     * {@code outputVal}.
     */
    @Override
    public void getFromOffset(final long index, final int ngramOrder, @OutputParameter final LongRef outputVal) {
        outputVal.value = getCount(ngramOrder, index, countsForRank);
    }

    /**
     * Writes the count associated with {@code rank} into {@code outputVal}.
     */
    @Override
    protected void getFromRank(final long rank, @OutputParameter final LongRef outputVal) {
        outputVal.value = countsForRank[(int) rank];
    }

    /**
     * Returns the count stored at {@code index} for the given order.
     *
     * @param ngramOrder
     *            index into {@code valueRanks}
     * @param index
     *            position within that order's rank array
     * @return the count value
     */
    public final long getCount(final int ngramOrder, final long index) {
        return getCount(ngramOrder, index, countsForRank);
    }

    /**
     * Looks up the rank stored at {@code index} for the given order and
     * resolves it through {@code array}.
     *
     * @param ngramOrder
     *            index into {@code valueRanks}
     * @param index
     *            position within that order's rank array
     * @param array
     *            rank-to-count table to resolve the rank with
     * @return {@code array[rank]}
     */
    private long getCount(final int ngramOrder, final long index, final long[] array) {
        final int countIndex = (int) valueRanks[ngramOrder].get(index);
        return array[countIndex];
    }

    /**
     * @return a fresh {@link LongRef} holding {@code -1}, the default count
     */
    @Override
    protected LongRef getDefaultVal() {
        return new LongRef(-1L);
    }

    /**
     * Delegates to the superclass and, for order 0, recomputes
     * {@link #getUnigramSum()} as the sum of the counts of every entry in that
     * order's rank array.
     */
    @Override
    public void trimAfterNgram(final int ngramOrder, final long size) {
        super.trimAfterNgram(ngramOrder, size);
        if (ngramOrder == 0) {
            // Use a long counter: the rank array is addressed with long indexes
            // elsewhere, so an int counter could wrap on a very large array.
            final long numEntries = valueRanks[ngramOrder].size();
            // Accumulate locally and assign, so a repeated call for order 0
            // does not add the same counts to the total twice.
            long sum = 0L;
            for (long i = 0; i < numEntries; ++i) {
                sum += countsForRank[(int) valueRanks[ngramOrder].get(i)];
            }
            unigramSum = sum;
        }
    }

    /**
     * @return the total computed by the last {@code trimAfterNgram(0, ...)}
     *         call, or {@code 0} if it has not been called
     */
    public long getUnigramSum() {
        return unigramSum;
    }

    /**
     * @return a fresh {@link LongRef} initialised to {@code -1}
     */
    @Override
    public LongRef getScratchValue() {
        return new LongRef(-1);
    }

    /**
     * Delegates to the superclass and additionally adopts the other
     * container's count-to-rank map.
     *
     * @param o
     *            must be a {@code CountValueContainer}; anything else fails
     *            with a {@link ClassCastException}
     */
    @Override
    public void setFromOtherValues(final ValueContainer<LongRef> o) {
        super.setFromOtherValues(o);
        this.countIndexer = ((CountValueContainer) o).countIndexer;
    }

    /**
     * Delegates to the superclass and drops the count-to-rank map. After this
     * call {@link #getCountRank(long)} can no longer be used.
     */
    @Override
    public void trim() {
        super.trim();
        countIndexer = null;
    }

    /**
     * Returns the rank registered for {@code val}, or {@code -1} if the value
     * has no rank. Throws {@link NullPointerException} once {@link #trim()} has
     * discarded the count-to-rank map.
     */
    @Override
    protected long getCountRank(long val) {
        return countIndexer.get(val, -1);
    }
}