package ids.framework; import ids.clustering.algorithm.HMRFKmeansParams; import ids.clustering.algorithm.HMRFKmeansU; import ids.clustering.model.Clusters; import ids.clustering.model.Distance; import ids.clustering.model.Domain; import ids.clustering.model.ObjectiveFunctionType; import ids.clustering.utils.ClusterUtils; import ids.utils.CommonUtils; public class EnsembleBased { private Domain domain1; private Domain domain2; int numberOfInstancesDomain1 = 1; int numberOfInstancesDomain2 = 1; private int[][] groupIDX; private int[] finalIDX; public boolean debug = false; // Utilities CommonUtils utils = new CommonUtils(false); ClusterUtils clusterUtils = new ClusterUtils(false); public EnsembleBased(Domain d1, Domain d2, int run_d1, int run_d2) { this.domain1 = d1; this.domain2 = d2; this.numberOfInstancesDomain1 = run_d1; this.numberOfInstancesDomain2 = run_d2; // cluster groupIDX = cluster(); } public EnsembleBased(Domain d1, int run_d1) { this.domain1 = d1; this.domain2 = null; this.numberOfInstancesDomain1 = run_d1; this.numberOfInstancesDomain2 = 0; // cluster groupIDX = clusterDomain1(); } private int[][] clusterDomain1() { // number of data objects int n = domain1.data.length; // total number of algorithms that will be used in clustering int totalNumberOfInstances = numberOfInstancesDomain1; // final IDX holder int[][] algorithmIDX = new int[n][totalNumberOfInstances]; // do domain 1 first for (int i = 0; i < numberOfInstancesDomain1; i++) { HMRFKmeansParams domain1_par = new HMRFKmeansParams(); domain1_par.distanceFunction = domain1.distance; domain1_par.obj_type = ObjectiveFunctionType.CENTROIDS; domain1_par.verbose = true; HMRFKmeansU kmeans = new HMRFKmeansU(domain1.data, domain1.k, domain1_par); kmeans.cluster(); int[] idx = kmeans.getIDX(); for (int j = 0; j < n; j++) algorithmIDX[j][i] = idx[j]; } return algorithmIDX; } /** * Cluster domain 1 and domain 2 * @return */ private int[][] cluster() { if (domain1.data.length != domain2.data.length) { System.out.println("Data in domains has been the same length"); return null; } int n = domain1.data.length; // total number of algorithms that will be used in clustering int totalNumberOfInstances = numberOfInstancesDomain1 + numberOfInstancesDomain2; // final IDX holder int[][] algorithmIDX = new int[n][totalNumberOfInstances]; int instancesCounter = -1; // do domain 1 first for (int i = 0; i < numberOfInstancesDomain1; i++) { instancesCounter++; HMRFKmeansParams domain1_par = new HMRFKmeansParams(); domain1_par.distanceFunction = domain1.distance; domain1_par.obj_type = ObjectiveFunctionType.CENTROIDS; domain1_par.verbose = true; HMRFKmeansU kmeans = new HMRFKmeansU(domain1.data, domain1.k, domain1_par); kmeans.cluster(); int[] idx = kmeans.getIDX(); for (int j = 0; j < n; j++) algorithmIDX[j][instancesCounter] = idx[j]; } // same thing for domain 2 for (int i = 0; i < numberOfInstancesDomain2; i++) { instancesCounter++; HMRFKmeansParams domain2_par = new HMRFKmeansParams(); domain2_par.distanceFunction = domain2.distance; domain2_par.obj_type = ObjectiveFunctionType.CENTROIDS; domain2_par.verbose = true; HMRFKmeansU kmeans = new HMRFKmeansU(domain2.data, domain2.k, domain2_par); kmeans.cluster(); int[] idx = kmeans.getIDX(); for (int j = 0; j < n; j++) algorithmIDX[j][instancesCounter] = idx[j]; } return algorithmIDX; } /** * This function use voting method as a consensus function * @return final cluster membership */ public int[] Do_Voting() { if (groupIDX == null) cluster(); finalIDX = null; int n = groupIDX.length; int totalNumberOfInstances = groupIDX[0].length; int[] res = new int[n]; // solve clustering correspondence if (debug) System.out.print("Solving cluster correspondence problem.."); int[][] temp = new int[n][totalNumberOfInstances]; int[] firstColumn = utils.getColumn(groupIDX, 0); utils.fillColumn(temp, firstColumn, 0); for (int index = 1; index < totalNumberOfInstances; index++) { int[] currentColumn = utils.getColumn(groupIDX, index); int[] q = clusterUtils.findClusterCorrespondence(firstColumn, currentColumn); if (q != null) { utils.fillColumn(temp, q, index); } else { utils.fillColumn(temp, currentColumn, index); } } if (debug) System.out.println("done."); // find mode for (int i = 0; i < n; i++) res[i] = utils.getMode(temp[i]); // output if (debug) { int v_tt = 20; if (v_tt > n) v_tt = n; for (int i = 0; i < v_tt; i++) { System.out.printf("%d.\t", i); for (int j = 0; j < totalNumberOfInstances; j++) { System.out.printf("%d\t", temp[i][j]); } System.out.printf("%d\n", res[i]); } } finalIDX = res; return finalIDX; } /** * Use clustering method (kmodes) as a consensus function * @param final_k * @return final cluster membership */ public int[] Do_Clustering(int final_k) { if (groupIDX == null) cluster(); finalIDX = null; int n = groupIDX.length; int[] res = new int[n]; HMRFKmeansParams par = new HMRFKmeansParams(Distance.MATCH, ObjectiveFunctionType.CENTROIDS); par.verbose = debug; HMRFKmeansU kmeans = new HMRFKmeansU(groupIDX, final_k, par); kmeans.cluster(); res = kmeans.getIDX(); if (debug) { int v_tt = 20; if (v_tt > n) v_tt = n; System.out.println("Centroids:"); utils.printMatrix(kmeans.getCentroids()); System.out.println("Membership matrix"); for (int i = 0; i < v_tt; i++) { System.out.printf("%d.\t", i); for (int j = 0; j < groupIDX[0].length; j++) { System.out.printf("%d\t", groupIDX[i][j]); } System.out.printf("%d\n", res[i]); } } finalIDX = res; return finalIDX; } /** * Returns domain`s centroids * @return */ public double[][] getCentroidD1() { return getDomainCentroids(domain1); } public double[][] getCentroidD2() { return getDomainCentroids(domain2); }; private double[][] getDomainCentroids(Domain d) { if (finalIDX != null) { Clusters cl = clusterUtils.getClusterCentoids(d.data, finalIDX, d.k, d.distance); return cl.centroids; } System.out.println("PLese cluster first"); return null; } public void clear() { this.groupIDX = null; this.finalIDX = null; } public void printMembershipMatrix() { if ((groupIDX != null)&&(finalIDX != null)) { int n = domain1.data.length; int m = groupIDX[0].length; // print header System.out.printf("#\t"); for (int j = 0; j < m; j++) System.out.printf("%d\t", j); System.out.printf("Final IDX\n"); // print data for (int i = 0; i < n; i++) { System.out.printf("%d.\t", i); for (int j = 0; j < m; j++) { System.out.printf("%d\t", groupIDX[i][j]); } System.out.printf("%d\n", finalIDX[i]); } } } }