package edu.stanford.nlp.stats;
import java.util.Random;
/**
 * A simple Dirichlet distribution whose outcomes are the keys of a parameter
 * {@link Counter}; samples are {@link Multinomial}s over those keys. It is used
 * as the conjugate prior of a {@link Multinomial}.
 *
 * @author Jenny Finkel
 */
public class Dirichlet<E> implements ConjugatePrior<Multinomial<E>, E> {

  private static final long serialVersionUID = 1L;

  /** Lanczos approximation coefficients (g = 7, n = 9) used by {@link #logGamma(double)}. */
  private static final double[] LANCZOS_COEFFICIENTS = {
      0.99999999999980993,
      676.5203681218851,
      -1259.1392167224028,
      771.32342877765313,
      -176.61502916214059,
      12.507343278686905,
      -0.13857109526572012,
      9.9843695780195716e-6,
      1.5056327351493116e-7
  };

  /** Concentration parameter per outcome; a private copy of the constructor argument. */
  private final Counter<E> parameters;

  /**
   * Creates a Dirichlet distribution with the given concentration parameters.
   * The counter is copied, so later changes to the argument do not affect this object.
   *
   * @param parameters concentration parameter per outcome; every value must be
   *     non-negative (and not NaN) and their total must be positive
   * @throws IllegalArgumentException if a parameter is negative or NaN, or the
   *     total mass is not positive
   */
  public Dirichlet(Counter<E> parameters) {
    checkParameters(parameters);
    this.parameters = new ClassicCounter<>(parameters);
  }

  /**
   * Validates concentration parameters.
   *
   * @throws IllegalArgumentException if any value is negative or NaN, or the total is not positive
   */
  private static <F> void checkParameters(Counter<F> parameters) {
    for (F o : parameters.keySet()) {
      // Written as !(x >= 0) rather than x < 0 so that NaN is rejected as well.
      if (!(parameters.getCount(o) >= 0.0)) {
        throw new IllegalArgumentException("Parameters must be non-negative!");
      }
    }
    if (parameters.totalCount() <= 0.0) {
      throw new IllegalArgumentException("Parameters must have positive mass!");
    }
  }

  /**
   * Draws a multinomial from this distribution.
   *
   * @param random source of randomness
   * @return a sampled multinomial over this distribution's outcomes
   */
  public Multinomial<E> drawSample(Random random) {
    return drawSample(random, parameters);
  }

  /**
   * Draws a multinomial from the Dirichlet with the given parameters. One Gamma
   * variate is drawn per outcome, and the draws are normalized by their sum.
   *
   * @param random source of randomness
   * @param parameters concentration parameter per outcome
   * @return the sampled multinomial
   */
  public static <F> Multinomial<F> drawSample(Random random, Counter<F> parameters) {
    Counter<F> multParameters = new ClassicCounter<>();
    double sum = 0.0;
    for (F o : parameters.keySet()) {
      double parameter = Gamma.drawSample(random, parameters.getCount(o));
      sum += parameter;
      multParameters.setCount(o, parameter);
    }
    for (F o : multParameters.keySet()) {
      multParameters.setCount(o, multParameters.getCount(o) / sum);
    }
    return new Multinomial<>(multParameters);
  }

  /**
   * Array-based Dirichlet sampling that avoids {@link Counter} overhead.
   *
   * @param random source of randomness
   * @param parameters concentration parameters, one per dimension
   * @return a new array with the sampled probabilities, in the same order as
   *     {@code parameters}
   */
  public static double[] drawSample(Random random, double[] parameters) {
    double sum = 0.0;
    double[] result = new double[parameters.length];
    for (int i = 0; i < parameters.length; ++i) {
      double parameter = Gamma.drawSample(random, parameters[i]);
      sum += parameter;
      result[i] = parameter;
    }
    // NOTE(review): if every Gamma draw were exactly 0, this division would yield NaN.
    for (int i = 0; i < parameters.length; ++i) {
      result[i] /= sum;
    }
    return result;
  }

  /**
   * Samples from Beta(a, b). It does this by drawing a two-outcome Dirichlet
   * and returning the weight assigned to the {@code true} outcome.
   *
   * @param a parameter of the {@code true} outcome
   * @param b parameter of the {@code false} outcome
   * @param random source of randomness
   * @return a sample in [0, 1]
   * @throws IllegalArgumentException if a parameter is negative or NaN, or a + b is not positive
   */
  public static double sampleBeta(double a, double b, Random random) {
    Counter<Boolean> c = new ClassicCounter<>();
    c.setCount(true, a);
    c.setCount(false, b);
    Multinomial<Boolean> beta = (new Dirichlet<>(c)).drawSample(random);
    return beta.probabilityOf(true);
  }

  /**
   * Prior predictive probability of {@code object}: its parameter divided by
   * the total parameter mass.
   */
  public double getPredictiveProbability(E object) {
    return parameters.getCount(object) / parameters.totalCount();
  }

  /** Natural log of {@link #getPredictiveProbability(Object)}. */
  public double getPredictiveLogProbability(E object) {
    return Math.log(getPredictiveProbability(object));
  }

  /**
   * Returns the posterior Dirichlet after observing {@code counts}. Its
   * parameters are this distribution's parameters plus the counts. This
   * distribution is not modified.
   */
  public Dirichlet<E> getPosteriorDistribution(Counter<E> counts) {
    Counter<E> newParameters = new ClassicCounter<>(parameters);
    Counters.addInPlace(newParameters, counts);
    return new Dirichlet<>(newParameters);
  }

  /**
   * Posterior predictive probability of {@code object} after observing
   * {@code counts}: (alpha_object + count_object) / (sum alpha + sum counts).
   */
  public double getPosteriorPredictiveProbability(Counter<E> counts, E object) {
    double numerator = parameters.getCount(object) + counts.getCount(object);
    double denominator = parameters.totalCount() + counts.totalCount();
    return numerator / denominator;
  }

  /** Natural log of {@link #getPosteriorPredictiveProbability(Counter, Object)}. */
  public double getPosteriorPredictiveLogProbability(Counter<E> counts, E object) {
    return Math.log(getPosteriorPredictiveProbability(counts, object));
  }

  /**
   * Dirichlet probability density at {@code object}.
   *
   * @param object a multinomial over (at least) this distribution's outcomes
   * @return {@code exp(logProbabilityOf(object))}; may be infinite when some
   *     parameter is below 1 and the corresponding probability is 0
   */
  public double probabilityOf(Multinomial<E> object) {
    return Math.exp(logProbabilityOf(object));
  }

  /**
   * Unnormalized log-density for array-based Metropolis sampling:
   * sum over i of (params[i] - 1) * log(mult[i]), without the normalizing Gamma terms.
   * Entries with {@code mult[i] <= 0} contribute nothing, rather than +/- infinity.
   * {@code mult} must be at least as long as {@code params}.
   */
  public static double unnormalizedLogProbabilityOf(double[] mult, double[] params) {
    double sum = 0.0;
    for (int i = 0; i < params.length; ++i) {
      if (mult[i] > 0) {
        sum += (params[i] - 1) * Math.log(mult[i]);
      }
    }
    return sum;
  }

  /**
   * Natural log of the Dirichlet density at {@code object}:
   * log Gamma(sum alpha) - sum log Gamma(alpha_i) + sum (alpha_i - 1) log p_i,
   * where p_i is {@code object.probabilityOf(i)} for every outcome i in this
   * distribution's parameters.
   *
   * <p>Outcomes with alpha_i == 0 are treated as degenerate: all of their mass
   * sits at p_i == 0. Such an outcome contributes nothing when p_i is 0, and
   * makes the density 0 (log-density negative infinity) otherwise.
   *
   * @param object a multinomial over (at least) this distribution's outcomes
   * @return the log-density; may be infinite at the boundary of the simplex
   */
  public double logProbabilityOf(Multinomial<E> object) {
    double logDensity = logGamma(parameters.totalCount());
    for (E outcome : parameters.keySet()) {
      double alpha = parameters.getCount(outcome);
      double p = object.probabilityOf(outcome);
      if (alpha == 0.0) {
        if (p > 0.0) {
          return Double.NEGATIVE_INFINITY;
        }
        continue;
      }
      logDensity -= logGamma(alpha);
      // When alpha == 1 the factor p^0 is 1 even at p == 0, but 0 * log(0) would be NaN.
      if (alpha != 1.0) {
        logDensity += (alpha - 1.0) * Math.log(p);
      }
    }
    return logDensity;
  }

  /**
   * Natural log of the Gamma function for positive {@code x}. It uses the
   * Lanczos approximation (g = 7), with the reflection formula for x < 0.5.
   */
  private static double logGamma(double x) {
    if (x < 0.5) {
      // Reflection: Gamma(x) * Gamma(1 - x) = pi / sin(pi * x).
      return Math.log(Math.PI / Math.abs(Math.sin(Math.PI * x))) - logGamma(1.0 - x);
    }
    double z = x - 1.0;
    double series = LANCZOS_COEFFICIENTS[0];
    for (int i = 1; i < LANCZOS_COEFFICIENTS.length; i++) {
      series += LANCZOS_COEFFICIENTS[i] / (z + i);
    }
    double t = z + 7.5; // z + g + 0.5 with g = 7
    return 0.5 * Math.log(2.0 * Math.PI) + (z + 0.5) * Math.log(t) - t + Math.log(series);
  }

  /** Formats up to 50 parameters via {@code Counters.toBiggestValuesFirstString}. */
  @Override
  public String toString() {
    return Counters.toBiggestValuesFirstString(parameters, 50);
  }
}