package dr.geo.distributions;
import dr.geo.math.Space;
import dr.math.LogTricks;
/**
* @author Marc Suchard
*/
public class VonMisesFisherDistribution extends HyperSphereDistribution {
public VonMisesFisherDistribution(int dim, Space space, double[] mean, double kappa) {
super(dim, space, mean, kappa);
}
public double logPdf(double[] x) {
return logPdf(x, mean, kappa, space);
}
public String getType() {
return "von Mises-Fisher";
}
protected int getAllowableDim() {
return 3;
}
private static double logNormalizationConstant(double kappa) {
// 'sinh' has some numerical instability for small arguments
if (kappa < 1E-10) {
return -Math.log(2) - LOG_2_PI;
}
return Math.log(kappa) - LOG_2_PI - LogTricks.logDiff(kappa, -kappa);
// return Math.log(kappa) - LOG_2_PI - Math.log(Math.exp(+kappa) - Math.exp(-kappa));
}
public static double logPdf(double[] x, double[] mean, double kappa, Space space) {
return logNormalizationConstant(kappa) + kappa * HyperSphereDistribution.innerProduct(x, mean, space);
}
public static void main(String[] arg) {
// Test in cartesian coordinates
double kappa1 = 1;
double[] mean1 = { 1,0,0 };
double[] x1 = { 0,1,0 };
System.err.println("logP = "+logPdf(x1, mean1, kappa1, Space.CARTESIAN)+" ?= -2.692464\n");
// Test in (lat,long) coordinates
double kappa2 = 1;
double[] mean2 = { 0,0 };
double[] x2 = { 0,180 };
System.err.println("logP = "+logPdf(x2, mean2, kappa2, Space.LAT_LONG )+" ?= -3.692464\n");
// Test in (lat,long) coordinates
double kappa3 = 2;
double[] mean3 = { 90,0 };
double[] x3 = { 0,180 };
System.err.println("logP = "+logPdf(x3, mean3, kappa3, Space.LAT_LONG )+" ?= -3.126244");
}
}
// R test code
//
//x = c(1,0,0)
//y = c(0,1,0)
//
//vmf = function(x, mean, kappa, p) {
// order = p/2 - 1
// norm = kappa^(order) / (2*pi)^(p/2) / besselI(kappa,order)
// norm * exp(kappa * sum(x * mean))
//}
//
//log(vmf(x,y,1,3))
//
//pt1 = c(0,1,0) # lat_long (0,0)
//pt2 = c(0,-1,0) # lat_long (0,180)
//
//log(vmf(pt1,pt2,1,3))
//
//pt3 = c(0,0,1) # lat_long (90,0)
//pt4 = c(0,-1,0) # lat_long (0,180)
//
//log(vmf(pt3,pt4,2,3))