package dr.math.distributions;
import dr.math.GammaFunction;
/**
 * Dirichlet distribution defined by a vector of positive parameters (called "counts" here).
 *
 * <p>The log normalizing constant
 * {@code lnGamma(sum(counts)) - sum_i lnGamma(counts[i])} is computed once at construction
 * and reused by every call to {@link #logPdf(double[])}. The parameter vector is copied on
 * construction, so later changes to the caller's array cannot desynchronise the cached
 * constant, the cached sum and the parameters themselves.
 *
 * <p>No validation of the parameter values is performed; non-positive counts are passed
 * straight to {@code GammaFunction.lnGamma}.
 *
 * @author Marc A. Suchard
 */
public class DirichletDistribution implements MultivariateDistribution {

    public static final String TYPE = "dirichletDistribution";

    /** Private copy of the parameter vector supplied to the constructor. */
    private final double[] counts;

    /** Sum of all entries of {@link #counts}. */
    private final double countSum;

    /** Number of components, i.e. {@code counts.length}. */
    private final int dim;

    /** Cached value of {@code lnGamma(countSum) - sum_i lnGamma(counts[i])}. */
    private final double logNormalizingConstant;

    /**
     * Creates a Dirichlet distribution with the given parameter vector.
     *
     * @param counts the parameter vector; it is copied, so the caller keeps ownership
     *               of the array that was passed in
     * @throws NullPointerException if {@code counts} is {@code null}
     */
    public DirichletDistribution(double[] counts) {
        // Defensive copy: the cached sum and normalizing constant below are only valid
        // for the values seen at construction time.
        this.counts = counts.clone();
        this.dim = this.counts.length;

        double sum = 0.0;
        for (int i = 0; i < dim; i++) {
            sum += this.counts[i];
        }
        this.countSum = sum;

        this.logNormalizingConstant = computeNormalizingConstant();
    }

    /**
     * Computes the log normalizing constant from {@link #counts} and {@link #countSum}.
     *
     * @return {@code lnGamma(countSum) - sum_i lnGamma(counts[i])}
     */
    private double computeNormalizingConstant() {
        double logConstant = GammaFunction.lnGamma(countSum);
        for (int i = 0; i < dim; i++) {
            logConstant -= GammaFunction.lnGamma(counts[i]);
        }
        return logConstant;
    }

    /**
     * Evaluates the log density at {@code x}.
     *
     * <p>Every component must lie strictly inside (0, 1); otherwise the result is
     * {@link Double#NEGATIVE_INFINITY}. This method does not check that the components
     * of {@code x} sum to one.
     *
     * @param x the point at which to evaluate the density; must have the same length as
     *          the parameter vector
     * @return the log density, or {@link Double#NEGATIVE_INFINITY} if any component is
     *         {@code <= 0} or {@code >= 1}
     * @throws IllegalArgumentException if {@code x.length} differs from the dimension
     */
    public double logPdf(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("data array is of the wrong dimension");
        }

        double logPDF = logNormalizingConstant;
        for (int i = 0; i < dim; i++) {
            // Test the support before taking the logarithm so we never evaluate
            // Math.log on a non-positive value. A NaN component fails both comparisons
            // and therefore still propagates NaN into the result.
            if (x[i] <= 0.0 || x[i] >= 1.0) {
                return Double.NEGATIVE_INFINITY;
            }
            logPDF += (counts[i] - 1) * Math.log(x[i]);
        }
        return logPDF;
    }

    /**
     * Scale matrix accessor; this implementation provides none.
     *
     * @return always {@code null}
     */
    public double[][] getScaleMatrix() {
        return null;
    }

    /**
     * Returns the mean vector of the distribution.
     *
     * @return a newly allocated array whose i-th entry is {@code counts[i] / countSum}
     */
    public double[] getMean() {
        double[] mean = new double[dim];
        for (int i = 0; i < dim; i++) {
            mean[i] = counts[i] / countSum;
        }
        return mean;
    }

    /**
     * Returns the type identifier of this distribution.
     *
     * @return {@link #TYPE}
     */
    public String getType() {
        return TYPE;
    }
}