package edu.stanford.nlp.stats;
import java.util.Random;
/**
* A multinomial distribution. Pretty straightforward. Specify the parameters with
* a Counter. It is assumed that the Counter's keySet() contains all of the parameters (i.e., there are no other
* possible values which are set to 0). It makes a copy of the Counter, so the parameters cannot be changed,
* and it normalizes the values if they are not already normalized.
*
* @author Jenny Finkel
*/
public class Multinomial<E> implements ProbabilityDistribution<E> {

  private static final long serialVersionUID = -697457414113362926L;

  /**
   * Normalized probability of every outcome. This is a private copy built by the
   * constructor; it is never handed out directly (see {@link #getParameters()}).
   */
  private final Counter<E> parameters;

  /**
   * Builds a multinomial from the given (possibly unnormalized) weights. The weights
   * are copied and divided by their total, so later changes to the argument do not
   * affect this distribution.
   *
   * @param parameters non-negative weights, one per outcome
   * @throws IllegalArgumentException if the total mass is not a positive finite number,
   *         or if any individual weight is negative
   */
  public Multinomial(Counter<E> parameters) {
    double totalMass = parameters.totalCount();
    // Written as !(x > 0) rather than (x <= 0) so that a NaN total is rejected as well.
    if (!(totalMass > 0.0)) {
      throw new IllegalArgumentException("total mass must be positive!");
    }
    // An infinite total would turn the normalized values into 0 or NaN.
    if (Double.isInfinite(totalMass)) {
      throw new IllegalArgumentException("total mass must be finite!");
    }
    this.parameters = new ClassicCounter<>();
    for (E object : parameters.keySet()) {
      double oldCount = parameters.getCount(object);
      if (oldCount < 0.0) {
        throw new IllegalArgumentException("no negative parameters allowed!");
      }
      this.parameters.setCount(object, oldCount / totalMass);
    }
  }

  /**
   * Returns a fresh copy of the normalized parameters; modifying the returned
   * counter does not affect this distribution.
   *
   * @return a copy of the normalized probabilities
   */
  public Counter<E> getParameters() {
    return new ClassicCounter<>(parameters);
  }

  /**
   * Returns the probability of the given outcome.
   *
   * @param object the outcome to look up
   * @return the normalized probability stored for {@code object}
   * @throws IllegalArgumentException if {@code object} is not one of this
   *         distribution's outcomes
   */
  public double probabilityOf(E object) {
    if (!parameters.keySet().contains(object)) {
      throw new IllegalArgumentException("Not a valid object for this multinomial!");
    }
    return parameters.getCount(object);
  }

  /**
   * Returns the natural logarithm of the probability of the given outcome.
   *
   * @param object the outcome to look up
   * @return {@code Math.log(probabilityOf(object))}; negative infinity for an outcome
   *         whose probability is zero
   * @throws IllegalArgumentException if {@code object} is not one of this
   *         distribution's outcomes
   */
  public double logProbabilityOf(E object) {
    return Math.log(probabilityOf(object));
  }

  /**
   * Draws one outcome at random, with each outcome chosen in proportion to its
   * probability.
   *
   * @param random the source of randomness
   * @return the sampled outcome
   */
  public E drawSample(Random random) {
    double r = random.nextDouble();
    double sum = 0.0;
    E lastPositive = null;
    boolean sawPositive = false;
    for (E object : parameters.keySet()) {
      double probability = parameters.getCount(object);
      if (probability <= 0.0) {
        // Zero-mass outcomes must never be drawn, not even when r == 0.0.
        continue;
      }
      lastPositive = object;
      sawPositive = true;
      sum += probability;
      // Strict comparison: outcome i owns the half-open interval [sum_{i-1}, sum_i).
      if (r < sum) {
        return object;
      }
    }
    // The accumulated probabilities can round to slightly less than 1.0, so a draw
    // very close to 1.0 may fall past the last interval; it belongs to the last
    // outcome that carries positive mass.
    if (sawPositive) {
      return lastPositive;
    }
    throw new IllegalStateException("This point should never be reached");
  }

  /**
   * Two multinomials are equal when their normalized parameter counters are equal.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof Multinomial)) {
      return false;
    }
    Multinomial<?> otherMultinomial = (Multinomial<?>) o;
    return parameters.equals(otherMultinomial.parameters);
  }

  /** Cached hash code; -1 means it has not been computed yet. */
  private int hashCode = -1;

  /**
   * Returns a hash code derived from the parameters, computed on first use and
   * cached afterwards.
   */
  @Override
  public int hashCode() {
    if (hashCode == -1) {
      hashCode = parameters.hashCode() + 17;
    }
    return hashCode;
  }
}