package ids.clustering.algorithm; import ids.clustering.model.Clusters; import ids.clustering.model.Distance; import ids.clustering.utils.ClusterUtils; import ids.utils.CommonUtils; import ids.utils.ConstraintsUtils; import ids.utils.FindMaxDistance; import ids.utils.SearchResult; import java.io.Serializable; import java.util.Arrays; import java.util.logging.Logger; import cern.colt.list.DoubleArrayList; import cern.colt.list.IntArrayList; import cern.colt.matrix.DoubleMatrix1D; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.impl.SparseDoubleMatrix2D; @SuppressWarnings("serial") public class HMRFWKmeansMMD implements Serializable { // logger private Logger log = Logger.getLogger(getClass().getName()); // data private double[][] data; private int n = 0; private int dim = 0; private int k = 0; // number of clusters private boolean verbose = false; private boolean debug = false; // Utilities private CommonUtils utils = new CommonUtils(false); private ClusterUtils clusterUtils = new ClusterUtils(false); private FindMaxDistance mDistance; // distance private Distance distanceFunction = Distance.SQEUCLIDEAN; private double phi_d = 0; // cluster membership private int[] idx; private int[] prev_idx; public int[] getIDX() { return idx; } // cluster centroids private double[][] centroids = null; public double[][] getCentroids() { return centroids; } // value of the objective function private double objF = 0; public double getObjF() { return objF; } // cluster selection private boolean useTC = true; private boolean inferCannotLink = true; // number of iterations private int number_of_iterations = 0; // maximum number of iterations private int max_number_of_iterations = 100; // maximum number of iterations in ICM algorithm private int max_number_of_iterations_icm = 100; // constraints private DoubleMatrix2D Mconstraints = null; private DoubleMatrix2D Cconstraints = null; private double[][] constraints_list = null; // Constructors public HMRFWKmeansMMD(double[][] inputData, int numberOfClusters, HMRFKmeansParams par) { this.data = inputData; if (this.data == null) { System.out.println("Data set is not set."); System.exit(1); } this.k = numberOfClusters; this.n = this.data.length; this.dim = this.data[0].length; if (this.n == 0) { System.out.println("The number of data objects is 0."); System.exit(1); } if (this.dim == 0) { System.out.println("The dimension of data objects is 0."); System.exit(1); } if (this.k < 2) { System.out.println("Number of clusters cannot be set less then 2 (current os " + k + ")"); System.exit(1); } // parse parameters parseParams(par); // Initialization Initialization(); } private void Initialization() { // initialize the membership vector idx = new int[n]; prev_idx = new int[n]; for (int i=0; i<n; i++) { idx[i] = -1; prev_idx[i] = -1; } // find the maximum distance in the data set if (phi_d == -1) { mDistance = new FindMaxDistance(false); phi_d = mDistance.getMaxDistance(data, n, dim, distanceFunction); } if (debug) log.info("Maximum distance in data set is " + phi_d); // deal with cluster initialization if (centroids != null) { System.out.println("Using user specified cluster centroids"); } else { if (constraints_list != null) { // parse the input constraints parseConstraints(); // cluster selection if (useTC) { // use transitive closure to infer must-link constraints ConstraintsUtils constraintsUtils = new ConstraintsUtils(verbose); Mconstraints = constraintsUtils.TransitiveClosure(Mconstraints); // get neighborhoods int[] neighborhoodLambda = constraintsUtils.getNeighborhood(); // number of neighborhoods int lambda = utils.getMaxValue(neighborhoodLambda).getValue() + 1; if (debug) log.info("Number of lambda neighborhoods is " + lambda); // infer cannot-link constraints if (inferCannotLink) constraintsUtils.inferCannotLinkConstraints(Cconstraints, neighborhoodLambda, lambda); // get new centroids based on lambda neigh centroids = clusterSelection(neighborhoodLambda, lambda); } } } // if centroids are still null then generate them randomly if (centroids == null) { centroids = clusterUtils.generateRandomClusterCentroids(data, k); System.out.println("Random centrois has been generated"); } // cluster memberships without any constraints idx = clusterUtils.getClusterMemberships(data, centroids, distanceFunction); } private void parseParams(HMRFKmeansParams par) { // verbose this.verbose = par.verbose; this.debug = par.debug; // distance this.distanceFunction = par.distanceFunction; this.phi_d = par.phi_d; // max iterations this.max_number_of_iterations = par.max_number_of_iterations; this.max_number_of_iterations_icm = par.max_number_of_iterations_icm; // constraints this.constraints_list = par.constraints; // centroids this.centroids = par.centeroids; // cluster selection this.useTC = par.useTC; } private void parseConstraints() { Mconstraints = new SparseDoubleMatrix2D(n, n); Cconstraints = new SparseDoubleMatrix2D(n, n); for (int i = 0; i < constraints_list.length; i++) { int a = (int)constraints_list[i][0] - 1; // -1 since data ID starts from 0 but from 1 in the constant file int b = (int)constraints_list[i][1] - 1; if (((int)constraints_list[i][2])==1) { // must-link constraints Mconstraints.setQuick(a, b, 1.0); Mconstraints.setQuick(b, a, 1.0); if (debug) { log.info("Creating must-link constraint between object " + (int)constraints_list[i][0] + " and " + (int)constraints_list[i][1]); } } else if (((int)constraints_list[i][2])==2) { // cannot-link constraints Cconstraints.setQuick(a, b, 1.0); Cconstraints.setQuick(b, a, 1.0); if (debug) { log.info("Creating cannot-link constraint between object " + (int)constraints_list[i][0] + " and " + (int)constraints_list[i][1]); } } else { log.severe("Cannot find constraints type"); } } // end for loop } // Cluster - main procedure public void cluster() { while (true) { // one cycle number_of_iterations++; if (debug) log.info("Start iteration " + number_of_iterations); // EM algorithm E_step(); M_step(); // stop condition if (Arrays.equals(idx, prev_idx)) break; if (number_of_iterations >= max_number_of_iterations) { if (debug) System.out.println("The maximum number of iterations (" + max_number_of_iterations + ") has been reached"); break; } prev_idx = idx.clone(); } // output if (verbose) { System.out.println(""); System.out.println("HMRF-Kmeans:"); System.out.println("Number of clusters: " + k); System.out.println("Number of samples: " + n + ", with " + dim + " features"); System.out.println("Distance funciton: " + distanceFunction); if (constraints_list != null) System.out.println("Number of constraints: " + constraints_list.length); System.out.println("Number of iterations: " + number_of_iterations); System.out.println("Objective function value: " + objF); System.out.println(""); } } private void clusterInitialization() { if ((useTC)&(constraints_list != null)) { // use transitive closure to infer must-link constraints ConstraintsUtils constraintsUtils = new ConstraintsUtils(verbose); Mconstraints = constraintsUtils.TransitiveClosure(Mconstraints); // get neighborhoods int[] neighborhoodLambda = constraintsUtils.getNeighborhood(); // number of neighborhoods int lambda = utils.getMaxValue(neighborhoodLambda).getValue() + 1; if (debug) log.info("Number of lambda neighborhoods is " + lambda); // infer cannot-link constraints if (inferCannotLink) constraintsUtils.inferCannotLinkConstraints(Cconstraints, neighborhoodLambda, lambda); // get new centroids based on lambda neigh centroids = clusterSelection(neighborhoodLambda, lambda); } } private double[][] clusterSelection(int[] nLambda, int lambda) { double[][] c = new double[k][dim]; System.out.println("Initializing clusters centroids using constraint list"); if (lambda > k) { Clusters cluster_lambda = clusterUtils.getClusterCentoids(data, n, dim, nLambda, lambda, distanceFunction); double[][] c_lambda = cluster_lambda.centroids; // find the largest subgraph and normalized it int[] w = cluster_lambda.clusterSizes; double[] dw = new double[lambda]; SearchResult<Integer> sr = utils.getMaxValue(w); int max_w = sr.getValue(); if (max_w != 0) { for (int i = 0; i < w.length; i++) { dw[i] = (double)w[i]/max_w; } } else { log.severe("ACHTUNG! Division by zero! The maximum neighborhood size is 0."); } // do the weighted farther first algorithm int[] c_index = new int[k]; for (int i = 0; i < k; i++) c_index[i] = -1; int cnc = 0; // current number of clusters c_index[cnc] = sr.getIndex(); // first cluster is the biggest cluster c[cnc] = c_lambda[sr.getIndex()]; // distance between cluster`s centroids double[] dc = new double[lambda]; while ((cnc+1) < k) { for (int i = 0; i < lambda; i++) { // the current candidate double[] x = c_lambda[i]; dc[i] = 0; if (!utils.isContained(c_index, cnc+1, i)) { for (int j = 0; j < cnc+1; j++) { double[] y = c_lambda[c_index[j]]; dc[i] += dw[c_index[j]] * dw[i] * utils.getDistance(x, y, distanceFunction); } } } cnc++; SearchResult<Double> srd = utils.getMaxValue(dc); c_index[cnc] = srd.getIndex(); c[cnc] = c_lambda[srd.getIndex()]; } } else { int[] index = utils.getRandomPermutation(n, k-lambda); double[][] temp = new double[lambda][dim]; Clusters cl = clusterUtils.getClusterCentoids(data, n, dim, nLambda, lambda, distanceFunction); temp = cl.centroids; for (int i = 0; i < lambda; i++) { c[i] = temp[i]; } for (int i = 0; i < k-lambda; i++) { c[lambda + i] = data[index[i]]; } } return c; } private void E_step() { int[] prev_idx_icm = idx.clone(); // number of iterations int v_tt = 0; while (true) { v_tt++; objF = 0; if (debug) log.info("Starting ICM iteration " + v_tt); // randomize points the data set int[] rand_index = utils.getRandomPermutation(n, n); for (int i = 0; i<n; i++) { int current_point_index = rand_index[i]; double f_min = Double.MAX_VALUE; int j_min = 0; // assign points to clusters for (int j=0; j<k; j++) { double f = getObjectiveFunction(current_point_index, j); //System.out.println("Distance is " + f); if (f<f_min) { f_min = f; j_min = j; } } //System.out.println("Adding " + f_min + " to the objective function " + objF); idx[current_point_index] = j_min; objF = objF + f_min; } // break loop if (Arrays.equals(prev_idx_icm, idx)) { if (debug) { log.info("ICM algorithm has been finished in (" + v_tt + ") iterations"); log.info("Value of the objective function is " + objF); } break; } if (v_tt >= max_number_of_iterations_icm) { log.warning("Reached maximum number of iterations (" + max_number_of_iterations_icm + ") in ICM algorithm"); break; } prev_idx_icm = idx.clone(); } if (debug) { log.info("E-Step: the value of the objective function: " + objF); //System.out.println("Membership vector: "); //Arrays.toString(idx); } } private void M_step() { Clusters clusters = clusterUtils.getClusterCentoids(data, n, dim, idx, k, distanceFunction); centroids = clusters.centroids; if (debug) { log.info("M-Step:"); log.info("Cluster centroids: "); utils.printMatrix(centroids, k, dim); } } private double getObjectiveFunction(int point_index, int cluster_index) { double kmeans_part = 0; double must_link_part = 0; double cannot_link_part = 0; // simple k-means kmeans_part = utils.getDistance(data[point_index], centroids[cluster_index], distanceFunction); if (constraints_list == null) return kmeans_part; // get clustering label of the current point in the other domain /* double w_common = 1; if (useOtherDomain) { w_common = utils.getDistance(data[point_index], centroids[cluster_index], distanceFunction)/ (utils.getDistance(data[point_index], centroids[domainIDX[point_index]], distanceFunction) + Double.MIN_VALUE); if (debug) System.out.println("w_common: " + w_common); } */ // get must-link constraints for the current point DoubleMatrix1D target = Mconstraints.viewRow(point_index); IntArrayList cols = new IntArrayList(); DoubleArrayList values = new DoubleArrayList(); target.getNonZeros(cols, values); // must-link constraints part for (int i = 0; i < cols.size(); i++) { // native domain int other_point = cols.get(i); int cluster_index_other_point = idx[other_point]; //double[] centroid_other_point = centroids[cluster_index_other_point]; if ((cluster_index != cluster_index_other_point)&(point_index != other_point)) { /* double w_final = 1; if (useOtherDomain) { // other domain double[] centroid_other_point_domainIDX = centroids[domainIDX[other_point]]; double w = utils.getDistance(data[other_point], centroid_other_point, distanceFunction)/ (utils.getDistance(data[other_point], centroid_other_point_domainIDX, distanceFunction) + Double.MIN_VALUE); w_final = (w_common + w)/2; if (w_final > 1) { if (debug) System.out.println("The current weight is higher than 1: " + w_final + "\nSetting it to 1"); w_final = 1; } } */ double w_final = 0;//0.4391/0.3937; must_link_part += w_final*utils.getDistance(data[point_index], data[other_point], distanceFunction); if (debug) log.info("Must-link: Objective function has been modified - points " + point_index + " and " + other_point + " in cluster " + cluster_index + ". Current value of: " + must_link_part + ", with w = " + w_final); } } // get cannot-link constraints for the current point target = Cconstraints.viewRow(point_index); cols = new IntArrayList(); values = new DoubleArrayList(); target.getNonZeros(cols, values); // cannot-link constraints part for (int i = 0; i < cols.size(); i++) { // native domain int other_point = cols.get(i); int cluster_index_other_point = idx[other_point]; //double[] centroid_other_point = centroids[cluster_index_other_point]; if (cluster_index == cluster_index_other_point) { /* double w_final = 1; if (useOtherDomain) { // other domain double[] centroid_other_point_domainIDX = centroids[domainIDX[other_point]]; double w = utils.getDistance(data[other_point], centroid_other_point, distanceFunction)/ (utils.getDistance(data[other_point], centroid_other_point_domainIDX, distanceFunction) + Double.MIN_VALUE); //w_final = 0.4391/0.3937; if (w_final > 1) { if (debug) System.out.println("The current weight is higher than 1: " + w_final + "\nSetting it to 1"); w_final = 1; } } */ double w_final = 0;//0.4391/0.3937; cannot_link_part += w_final*(phi_d - utils.getDistance(data[point_index], data[other_point], distanceFunction)); if (debug) log.info("Cannot-link: Objective function has been modified - points " + point_index + " and " + other_point + " in cluster " + cluster_index + ". Current value of: " + cannot_link_part + ", with w: " + w_final); } } // return the sum of all components if (debug) { //if (kmeans_part >0) System.out.println("k-means part of the distance is " + kmeans_part); if (must_link_part >0) System.out.println("must-link part of the distance is " + must_link_part); if (cannot_link_part >0) System.out.println("cannot-link part of the distance is " + cannot_link_part); } return kmeans_part; //return kmeans_part + must_link_part + cannot_link_part; } }