package edu.fudan.ml.types.featurecluster;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Locale;
/**
 * Distance between two {@link ClassData} records based on the Jensen-Shannon
 * divergence of their label vectors.
 *
 * <p>The two label vectors are mixed into an average vector (weighted by the
 * records' {@code getCount()} values). The distance is the sum of each record's
 * KL divergence from that average, each weighted by
 * {@code getCount() / getAllCount()} of that record.
 */
public class JSDistance extends AbstractDistance {

    /**
     * Finite stand-in added to a KL sum when the first distribution has mass at
     * an index where the second has none (the true KL term would be infinite).
     */
    double max = 999999;

    /**
     * Formatter used by {@link #format(double)} to round a value to at most four
     * significant digits in scientific notation. It is built with
     * {@link Locale#ROOT} symbols so the output always uses '.' as the decimal
     * separator and can be parsed back by {@link Double#parseDouble(String)}
     * regardless of the JVM's default locale.
     */
    DecimalFormat df = new DecimalFormat("0.###E0", DecimalFormatSymbols.getInstance(Locale.ROOT));

    /**
     * Computes the JS-based distance between two class records.
     *
     * @param cd1 first record
     * @param cd2 second record
     * @return the weighted JS distance, or {@link Double#MAX_VALUE} when the two
     *     records' label vectors have different lengths
     */
    public double cal(ClassData cd1, ClassData cd2) {
        if (!checkLabelLength(cd1, cd2)) {
            return Double.MAX_VALUE;
        }
        return calJSDistance(cd1, cd2);
    }

    /** Returns true when both records' label arrays have the same length. */
    private boolean checkLabelLength(ClassData cd1, ClassData cd2) {
        return cd1.getLabel().length == cd2.getLabel().length;
    }

    /**
     * Computes p1 * KL(label1 || avg) + p2 * KL(label2 || avg), where
     * pK = getCount() / getAllCount() of record K and avg is the count-weighted
     * average of the two label vectors.
     *
     * <p>NOTE(review): if a record's getAllCount() is 0 the corresponding weight
     * is NaN or Infinity and so is the result; confirm callers never pass such
     * records.
     */
    private double calJSDistance(ClassData cd1, ClassData cd2) {
        double p1 = (double) cd1.getCount() / (double) cd1.getAllCount();
        double p2 = (double) cd2.getCount() / (double) cd2.getAllCount();
        double[] averageLabel = calAverageLabel(cd1, cd2);
        double kl1 = klDistance(cd1.getLabel(), averageLabel);
        double kl2 = klDistance(cd2.getLabel(), averageLabel);
        // The raw value is returned; rounding through format(...) is intentionally not applied.
        return p1 * kl1 + p2 * kl2;
    }

    /**
     * Computes the sum over i of label1[i] * ln(label1[i] / label2[i]).
     *
     * <p>Indices where label1[i] is 0 contribute nothing. Indices where
     * label1[i] is non-zero but label2[i] is 0 contribute {@link #max} in place
     * of an infinite term.
     *
     * @param label1 the distribution being compared
     * @param label2 the reference distribution; must be at least as long as label1
     * @return the (capped) KL-style divergence of label1 from label2
     */
    private double klDistance(double[] label1, double[] label2) {
        double distance = 0;
        for (int i = 0; i < label1.length; i++) {
            if (label1[i] == 0) {
                continue;
            }
            if (label2[i] == 0) {
                distance += max;
            } else {
                distance += label1[i] * Math.log(label1[i] / label2[i]);
            }
        }
        return distance;
    }

    /**
     * Mixes the two records' label vectors into pi * label1 + (1 - pi) * label2,
     * where pi = count1 / (count1 + count2).
     *
     * <p>NOTE(review): if both counts are 0, pi is NaN and so is every entry of
     * the result; confirm this cannot happen for callers.
     *
     * @return a new array with the same length as cd1's label vector
     */
    private double[] calAverageLabel(ClassData cd1, ClassData cd2) {
        double[] label1 = cd1.getLabel();
        double[] label2 = cd2.getLabel();
        int length = label1.length;
        double[] label = new double[length];
        // Widen before adding so the sum of two large integral counts cannot overflow.
        double count1 = cd1.getCount();
        double count2 = cd2.getCount();
        double pi = count1 / (count1 + count2);
        for (int i = 0; i < length; i++) {
            label[i] = pi * label1[i] + (1 - pi) * label2[i];
        }
        return label;
    }

    /**
     * Rounds a value to at most four significant digits via {@link #df}.
     *
     * @param v the value to round
     * @return the rounded value; NaN and infinities are returned unchanged
     */
    private double format(double v) {
        // DecimalFormat renders non-finite values as symbols (e.g. the infinity
        // sign) that Double.parseDouble does not accept, so pass them through.
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            return v;
        }
        return Double.parseDouble(df.format(v));
    }
}