package ids.utils; import java.io.Serializable; import org.apache.commons.math3.special.Erf; @SuppressWarnings("serial") public class MMDUtils implements Serializable { private double mmd = 0; private double eps = 0; // Constructors public MMDUtils() {} public MMDUtils(double[] x, double[] y, double sigma, double alpha) { findMMD(x, y, sigma, alpha); } public MMDUtils(double[][] x, double[][] y, double sigma, double alpha) { findMMD(x, y, sigma, alpha); } public void findMMD(double[][] x, double[][] y, double sigma, double alpha) { double m = x.length; if (m != y.length) { System.out.println("Data sets has different number of points"); } double mmd_sq = 0; double sigma_sq = 0; double t = 0; double th = 0; // find MMD for (int i = 0; i < m; i++) { t = 0; for (int j = 0; j < m; j++) { if (i != j) { t = t + getGaussianKernel(x[i], x[j], sigma) + getGaussianKernel(y[i], y[j], sigma) - getGaussianKernel(x[i], y[j], sigma) - getGaussianKernel(x[j], y[i], sigma); } } mmd_sq = mmd_sq + 1/m/(m-1)*t; sigma_sq = sigma_sq + t*t; } sigma_sq = 4*sigma_sq/m/m/(m-1)/(m-1) - 4/m*mmd_sq*mmd_sq; th = Math.sqrt(2*sigma_sq)*Erf.erfInv(1-2*alpha); this.mmd = mmd_sq; this.eps = th; if (this.mmd <= this.eps) { System.out.println("Distributions are the same"); } else { System.out.println("Distributions are different"); } } public void findMMD(double[] x, double[] y, double sigma, double alpha) { double m = x.length; if (m != y.length) { System.out.println("Data sets has different number of points"); } double mmd_sq = 0; double sigma_sq = 0; double t = 0; double th = 0; // find MMD for (int i = 0; i < m; i++) { t = 0; for (int j = 0; j < m; j++) { if (i != j) { t = t + getGaussianKernel(x[i], x[j], sigma) + getGaussianKernel(y[i], y[j], sigma) - getGaussianKernel(x[i], y[j], sigma) - getGaussianKernel(x[j], y[i], sigma); } } mmd_sq = mmd_sq + 1/m/(m-1)*t; sigma_sq = sigma_sq + t*t; } sigma_sq = 4*sigma_sq/m/m/(m-1)/(m-1) - 4/m*mmd_sq*mmd_sq; th = Math.sqrt(2*sigma_sq)*Erf.erf(1-2*alpha); this.mmd = mmd_sq; this.eps = th; if (this.mmd <= this.eps) { System.out.println("Distributions are the same"); } else { System.out.println("Distributions are different"); } } public double getMMD() { return this.mmd; } public double getEps() { return this.eps; } // Gaussian Kernel private double getGaussianKernel(double x[], double y[], double sigma) { if (sigma == 0) { System.out.println("Sigma is set to 0"); return -1.0; } int n = x.length; if (n != y.length) { System.out.println("Vector x and y have different length"); return -1.0; } double c[] = new double[n]; double ss = 0; for (int i = 0; i < n; i++) { c[i] = x[i] - y[i]; ss = ss + c[i]*c[i]; } return Math.exp(-ss/2/sigma/sigma); } private double getGaussianKernel(double x, double y, double sigma) { if (sigma == 0) { System.out.println("Sigma is set to 0"); return -1.0; } return Math.exp(-(x-y)*(x-y)/2/sigma/sigma); } }