package edu.berkeley.nlp.lm.values;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.bits.BitUtils;
import edu.berkeley.nlp.lm.collections.LongRepresentable;
/**
 * A pair of floats holding a log probability ({@link #prob}) and a backoff weight
 * ({@link #backoff}).
 * <p>
 * The {@link #ProbBackoffPair(float, float)} constructor truncates both values to a reduced number
 * of mantissa bits (taken from {@code ConfigOptions.roundBits}). A pair can be packed into a single
 * {@code long} with {@link #asLong()} / {@link #floatsToLong(float, float)} and unpacked with
 * {@link #probOf(long)} / {@link #backoffOf(long)}.
 * <p>
 * NOTE: {@link #prob} and {@link #backoff} are public and mutable, and both participate in
 * {@link #hashCode()} and {@link #equals(Object)}. Do not mutate an instance while it is stored as a
 * key in a hash-based collection. Assigning the fields directly also bypasses rounding.
 */
public class ProbBackoffPair implements Comparable<ProbBackoffPair>, LongRepresentable<ProbBackoffPair>
{
	/** Mask selecting the 23 explicit mantissa bits of an IEEE-754 single-precision float. */
	static final int MANTISSA_MASK = 0x7fffff;

	/** Mask selecting the sign and exponent bits (everything outside {@link #MANTISSA_MASK}). */
	static final int REST_MASK = ~MANTISSA_MASK;

	/** The (rounded, when set through the float constructor) log probability. */
	public float prob;

	/** The (rounded, when set through the float constructor) backoff weight. */
	public float backoff;

	/**
	 * Hashes the raw bit patterns of both fields, consistent with {@link #equals(Object)}.
	 */
	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + Float.floatToIntBits(prob);
		result = prime * result + Float.floatToIntBits(backoff);
		return result;
	}

	/**
	 * Two pairs are equal when they have the same runtime class and bit-identical
	 * ({@link Float#floatToIntBits(float)}) {@code prob} and {@code backoff} values. Under this
	 * comparison {@code 0.0f} and {@code -0.0f} differ, and all NaNs compare equal.
	 */
	@Override
	public boolean equals(final Object obj) {
		if (this == obj) return true;
		if (obj == null) return false;
		if (getClass() != obj.getClass()) return false;
		final ProbBackoffPair other = (ProbBackoffPair) obj;
		if (Float.floatToIntBits(prob) != Float.floatToIntBits(other.prob)) return false;
		if (Float.floatToIntBits(backoff) != Float.floatToIntBits(other.backoff)) return false;
		return true;
	}

	/**
	 * Reconstructs a pair from a value produced by {@link #asLong()} or
	 * {@link #floatsToLong(float, float)}.
	 * <p>
	 * The unpacked floats pass through the rounding constructor again. Rounding an
	 * already-rounded value with the same {@code roundBits} leaves it unchanged, because the highest
	 * set mantissa bit, and therefore the mask, is preserved.
	 *
	 * @param probBackoff the packed representation
	 */
	public ProbBackoffPair(final long probBackoff) {
		this(probOf(probBackoff), backoffOf(probBackoff));
	}

	/**
	 * Creates a pair, truncating both values to {@code ConfigOptions.roundBits} mantissa bits.
	 *
	 * @param logProb the log probability
	 * @param backoff the backoff weight
	 */
	public ProbBackoffPair(final float logProb, final float backoff) {
		this.prob = round(logProb, ConfigOptions.roundBits);
		this.backoff = round(backoff, ConfigOptions.roundBits);
	}

	/**
	 * Truncates {@code f} by clearing its low mantissa bits.
	 * <p>
	 * Keeps the highest set bit of the explicit 23-bit mantissa plus the {@code mantissaBits} bits
	 * immediately below it, and zeroes everything lower. Sign and exponent are untouched, so the
	 * result's magnitude never exceeds {@code |f|}. The retained window is anchored at the highest
	 * <em>set</em> mantissa bit, not at the implicit leading bit.
	 * <p>
	 * Infinities are returned as-is. NaN stays NaN, because its canonical mantissa bit is always
	 * retained.
	 *
	 * @param f the value to round
	 * @param mantissaBits how many mantissa bits to keep below the highest set mantissa bit
	 * @return {@code f} with its low mantissa bits cleared
	 */
	private static float round(final float f, final int mantissaBits) {
		if (Float.isInfinite(f)) return f;
		final int bits = Float.floatToIntBits(f);
		final int mantissa = bits & MANTISSA_MASK;
		final int rest = bits & REST_MASK;
		final int highestBit = Integer.highestOneBit(mantissa);
		// Build a mask covering highestBit and the mantissaBits bits directly below it.
		int mask = highestBit;
		for (int i = 0; i < mantissaBits; ++i) {
			mask >>>= 1;
			mask |= highestBit;
		}
		final int maskedMantissa = mantissa & mask;
		final float newFloat = Float.intBitsToFloat(rest | maskedMantissa);
		// The truncation error is proportional to |f| (it is < |f| * 2^-mantissaBits). A fixed
		// absolute tolerance alone therefore fails spuriously for large-magnitude values, so
		// either bound is accepted.
		assert Float.isNaN(f) || (Math.abs(f - newFloat) <= Math.max(1e-3f, Math.scalb(Math.abs(f), -mantissaBits))) : "Rounding went bad for float " + f + " and rounded " + newFloat;
		return newFloat;
	}

	@Override
	public String toString() {
		return "[FloatPair first=" + prob + ", second=" + backoff + "]";
	}

	/**
	 * Orders by {@code prob}, then by {@code backoff}, using {@link Float#compare(float, float)}.
	 * This is consistent with {@link #equals(Object)}.
	 */
	@Override
	public int compareTo(final ProbBackoffPair arg0) {
		final int c = Float.compare(prob, arg0.prob);
		if (c != 0) return c;
		return Float.compare(backoff, arg0.backoff);
	}

	/**
	 * Packs this pair into a {@code long}.
	 *
	 * @return {@code floatsToLong(prob, backoff)}
	 */
	@Override
	public long asLong() {
		return floatsToLong(prob, backoff);
	}

	/**
	 * Packs two floats into one {@code long} by combining their raw IEEE-754 bit patterns.
	 * No rounding is applied here.
	 * <p>
	 * {@link #probOf(long)} reads {@code prob} back from the low 32 bits, and
	 * {@link #backoffOf(long)} reads {@code backoff} from the high 32 bits. This presumes that
	 * {@code BitUtils.combineInts} places its first argument in the low word.
	 * NOTE(review): confirm against BitUtils.
	 *
	 * @param prob the log probability
	 * @param backoff the backoff weight
	 * @return the packed representation
	 */
	public static long floatsToLong(final float prob, final float backoff) {
		final int probBits = Float.floatToIntBits(prob);
		final int backoffBits = Float.floatToIntBits(backoff);
		return BitUtils.combineInts(probBits, backoffBits);
	}

	/**
	 * Extracts the log probability from a packed value.
	 *
	 * @param key a value produced by {@link #floatsToLong(float, float)}
	 * @return the float stored in the low 32 bits of {@code key}
	 */
	public static float probOf(long key) {
		return Float.intBitsToFloat(BitUtils.getLowInt(key));
	}

	/**
	 * Extracts the backoff weight from a packed value.
	 *
	 * @param key a value produced by {@link #floatsToLong(float, float)}
	 * @return the float stored in the high 32 bits of {@code key}
	 */
	public static float backoffOf(long key) {
		return Float.intBitsToFloat(BitUtils.getHighInt(key));
	}
}