package ids.framework;

import ids.clustering.model.Clusters;
import ids.clustering.model.View;
import ids.clustering.utils.ClusterUtils;
import ids.utils.CommonUtils;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/**
 * Two-view k-means clustering.
 *
 * <p>The algorithm alternates between the two views: the memberships obtained in
 * one view are used to compute the centroids of the other view, whose memberships
 * are then fed back. After the loop stops, points on which both views agree are
 * used to build "consensus" centroids, and every point is finally assigned to the
 * cluster whose consensus centroids are closest when the normalised distances of
 * both views are summed.
 *
 * <p>Both views are expected to describe the same points in the same order
 * (row {@code i} of {@code view1.data} and row {@code i} of {@code view2.data}).
 */
@SuppressWarnings("serial")
public class MultiviewKmeans implements Serializable {

    // input parameters
    /** Hard upper bound on the number of alternating iterations. */
    private int maximum_number_iterations = 30;
    /** The stopping criterion is only evaluated after this many iterations. */
    private int minimum_number_iterations = 5;
    /** Number of most recent objective values inspected by the stopping criterion. */
    private int iteration_window = 3;

    // output
    private boolean verbose = false;
    /** Final membership computed by the last call to {@link #Cluster()}, or null. */
    private int[] finalIDX = null;

    // Utilities
    private CommonUtils utils;
    private ClusterUtils clusterUtils;

    // statistics
    /** Objective value of view 1 per iteration of the most recent run. */
    public List<Double> view1_objF = new ArrayList<Double>();
    /** Objective value of view 2 per iteration of the most recent run. */
    public List<Double> view2_objF = new ArrayList<Double>();

    // views data
    private View view1;
    private View view2;
    /** Number of points (taken from view 1). */
    private int n = 0;

    // number of clusters
    int k = 2;

    /**
     * Creates a non-verbose clusterer.
     *
     * @param view1 first view of the data
     * @param view2 second view of the data
     * @param k     number of clusters
     */
    public MultiviewKmeans(View view1, View view2, int k) {
        this(view1, view2, k, false);
    }

    /**
     * Creates a clusterer.
     *
     * <p>A size mismatch between the views is only reported on standard output;
     * the number of points is taken from {@code view1}.
     *
     * @param view1   first view of the data
     * @param view2   second view of the data
     * @param k       number of clusters
     * @param verbose whether to print the per-point assignment table and to pass
     *                the verbose flag on to the utility classes
     */
    public MultiviewKmeans(View view1, View view2, int k, boolean verbose) {
        if (view1.data.length != view2.data.length) {
            System.out.println("Different number of points in views");
        }

        // parameters
        this.view1 = view1;
        this.view2 = view2;
        this.k = k;
        this.n = view1.data.length;
        this.verbose = verbose;

        // initialization
        utils = new CommonUtils(this.verbose);
        clusterUtils = new ClusterUtils(this.verbose);
    }

    /**
     * Runs the clustering and returns the final cluster index of every point.
     *
     * <p>The result is also cached so that the centroid getters can reuse it.
     * The objective histories {@link #view1_objF} and {@link #view2_objF} are
     * reset at the start of each run.
     *
     * @return array of length n holding the cluster index chosen for each point
     */
    public int[] Cluster() {
        // The stopping criterion indexes the histories by iteration number, so
        // they must only contain values of the current run.
        view1_objF.clear();
        view2_objF.clear();

        // initialize centroids for the second view and do the first E-step
        double[][] centroids1 = null;
        double[][] centroids2 = clusterUtils.generateRandomClusterCentroids(view2.data, k);
        int[] idx1 = null;
        int[] idx2 = clusterUtils.getClusterMemberships(view2.data, centroids2, view2.distance);

        // value of the objective functions
        double obj_view1 = 0;
        double obj_view2 = 0;
        // Start from the largest double so any real objective value is an improvement.
        double view1_min = Double.MAX_VALUE;
        double view2_min = Double.MAX_VALUE;

        // run
        int t = 0;
        while (t < maximum_number_iterations) {
            t++;

            // VIEW 1
            // M-step: centroids of view 1 from the memberships of view 2
            Clusters cl1 = clusterUtils.getClusterCentoids(view1.data, idx2, k, view1.distance);
            centroids1 = cl1.centroids;
            // E-step
            idx1 = clusterUtils.getClusterMemberships(view1.data, centroids1, view1.distance);

            // VIEW 2
            // M-step: centroids of view 2 from the memberships of view 1
            Clusters cl2 = clusterUtils.getClusterCentoids(view2.data, idx1, k, view2.distance);
            centroids2 = cl2.centroids;
            // E-step
            idx2 = clusterUtils.getClusterMemberships(view2.data, centroids2, view2.distance);

            // re-compute the objective function for each view
            obj_view1 = clusterUtils.getKMeansObjectiveFunction(view1.data, centroids1, idx1, view1.distance);
            obj_view2 = clusterUtils.getKMeansObjectiveFunction(view2.data, centroids2, idx2, view2.distance);
            System.out.printf("Iteration %d:\tView1: %5.4f\tView2: %5.4f\n", t, obj_view1, obj_view2);

            // save the values
            view1_objF.add(obj_view1);
            view2_objF.add(obj_view2);

            // exit condition: stop as soon as the minimum over the last
            // iteration_window objective values fails to improve in either view
            if (t > minimum_number_iterations) {
                double[] view1_temp = new double[iteration_window];
                double[] view2_temp = new double[iteration_window];
                int counter = -1;
                for (int i = t; i > (t - iteration_window); i--) {
                    counter++;
                    view1_temp[counter] = view1_objF.get(i - 1);
                    view2_temp[counter] = view2_objF.get(i - 1);
                }

                // find minimum
                double temp1 = utils.getMin(view1_temp);
                double temp2 = utils.getMin(view2_temp);

                if (temp1 < view1_min) {
                    view1_min = temp1;
                } else {
                    break;
                }

                if (temp2 < view2_min) {
                    view2_min = temp2;
                } else {
                    break;
                }
            }
        }

        // output
        System.out.println("Multiview Spherical K-means: Done on iteration: " + t);
        System.out.printf("Value of Obj. Function Domain 1: %5.4f (min: %5.4f)\n", obj_view1, view1_min);
        System.out.printf("Value of Obj. Function Domain 2: %5.4f (min: %5.4f)\n", obj_view2, view2_min);

        // find consensus membership: points on which the views disagree are
        // marked with Integer.MIN_VALUE
        int[] consensus = new int[n];
        for (int i = 0; i < n; i++) {
            consensus[i] = Integer.MIN_VALUE;
            if (idx1[i] == idx2[i]) {
                consensus[i] = idx1[i];
            }
        }

        // find consensus centroids
        Clusters consCl1 = clusterUtils.getClusterCentoids(view1.data, consensus, k, view1.distance);
        Clusters consCl2 = clusterUtils.getClusterCentoids(view2.data, consensus, k, view2.distance);
        double[][] cons_centroids1 = consCl1.centroids;
        double[][] cons_centroids2 = consCl2.centroids;

        // find consensus IDX
        int[] common_idx = new int[n];

        // find maximum distance in each data set; -1.0 marks "not computed yet"
        if (view1.pd_max == -1.0) {
            view1.pd_max = utils.getMaxDistance(view1.data, view1.distance);
            if (verbose) {
                System.out.printf("The maximum distance in View 1 is %5.4f\n", view1.pd_max);
            }
        }
        if (view2.pd_max == -1.0) {
            view2.pd_max = utils.getMaxDistance(view2.data, view2.distance);
            if (verbose) {
                System.out.printf("The maximum distance in View 2 is %5.4f\n", view2.pd_max);
            }
        }

        // header
        if (verbose) {
            System.out.printf("Point\t");
            for (int j = 0; j < k; j++) {
                System.out.printf("%d\t", j);
            }
            System.out.printf("Final Index\n");
        }

        // run
        for (int i = 0; i < n; i++) {
            if (verbose) {
                System.out.printf("%d\t", i);
            }

            double f_min = Double.MAX_VALUE;
            int index_min = 0;
            for (int j = 0; j < k; j++) {
                // Each view's distance is scaled by that view's own maximum
                // distance so both terms are comparable before they are summed.
                double current_f =
                        utils.getDistance(view1.data[i], cons_centroids1[j], view1.distance) / view1.pd_max
                        + utils.getDistance(view2.data[i], cons_centroids2[j], view2.distance) / view2.pd_max;
                if (current_f < f_min) {
                    f_min = current_f;
                    index_min = j;
                }
                if (verbose) {
                    System.out.printf("%5.4f\t", current_f);
                }
            }
            common_idx[i] = index_min;

            if (verbose) {
                System.out.printf("%d\n", common_idx[i]);
            }
        }

        this.finalIDX = common_idx;
        return common_idx;
    }

    /**
     * Returns the centroids of view 1 for the final membership, running
     * {@link #Cluster()} first if no result is available yet.
     *
     * @return centroids computed from view 1 and the final membership
     */
    public double[][] getCentroidsView1() {
        if (finalIDX == null) {
            this.Cluster();
        }
        Clusters cl = clusterUtils.getClusterCentoids(view1.data, finalIDX, k, view1.distance);
        return cl.centroids;
    }

    /**
     * Returns the centroids of view 2 for the final membership, running
     * {@link #Cluster()} first if no result is available yet.
     *
     * @return centroids computed from view 2 and the final membership
     */
    public double[][] getCentroidsView2() {
        if (finalIDX == null) {
            this.Cluster();
        }
        Clusters cl = clusterUtils.getClusterCentoids(view2.data, finalIDX, k, view2.distance);
        return cl.centroids;
    }

    // Objective functions

    /**
     * Returns the recorded objective values of view 1 as a primitive array.
     *
     * @return copy of {@link #view1_objF}, one entry per recorded iteration
     */
    public double[] getObjFunctionView1() {
        int r = view1_objF.size();
        double[] res = new double[r];
        for (int i = 0; i < r; i++) {
            res[i] = view1_objF.get(i);
        }
        return res;
    }

    /**
     * Returns the recorded objective values of view 2 as a primitive array.
     *
     * @return copy of {@link #view2_objF}, one entry per recorded iteration
     */
    public double[] getObjFunctionView2() {
        int r = view2_objF.size();
        double[] res = new double[r];
        for (int i = 0; i < r; i++) {
            res[i] = view2_objF.get(i);
        }
        return res;
    }
}