package org.seqcode.ml.clustering.affinitypropagation;

import java.io.PrintStream;
import java.util.HashMap;
import java.util.Vector;

import org.seqcode.ml.clustering.Clusterable;
import org.seqcode.ml.clustering.ClusterablePair;

/**
 * Affinity-propagation style clustering over a {@link SimilarityMeasure}.
 *
 * <p>The algorithm alternates damped updates of two message tables keyed by
 * {@link ClusterablePair}: responsibilities ({@code r}) and availabilities
 * ({@code a}). A point {@code j} is considered an exemplar in an iteration when
 * {@code a(j,j) + r(j,j) > 0}. Iteration stops when every point's exemplar flag
 * has been constant over the last {@code convit} iterations (with at least one
 * exemplar), or when the iteration counter reaches {@code maxit}.
 *
 * @author reeder
 */
public class APCluster {

    /**
     * Runs the clustering and stores the result in {@code s} via
     * {@code putAssignments} and {@code putExemplars}.
     *
     * <p>Progress and every message update are written to {@code System.err};
     * the final iteration count is written to {@code System.out}.
     *
     * @param objects the objects to cluster; indexed in parallel with {@code s}
     *                (positions {@code 0 .. s.size()-1} are read)
     * @param s       similarity measure; {@code addNoise()} is called on it
     *                before clustering starts
     * @param lam     damping factor: each new message is
     *                {@code lam*old + (1-lam)*update}
     * @param convit  number of consecutive iterations the exemplar flags must
     *                stay unchanged to declare convergence; must be at least 1
     *                whenever there is at least one object
     * @param maxit   iteration count at which the loop stops even without
     *                convergence
     * @return the sum over all objects of the similarity between the object and
     *         the exemplar it was assigned to
     * @throws IllegalStateException if there is at least one object but no
     *                               exemplar was identified
     */
    public static double cluster(Vector<Clusterable> objects, SimilarityMeasure<Clusterable> s, double lam, int convit, int maxit) {
        s.addNoise();
        final int n = s.size();

        // Message tables, defined only for pairs the similarity measure knows about.
        HashMap<ClusterablePair, Double> a = new HashMap<ClusterablePair, Double>();
        HashMap<ClusterablePair, Double> r = new HashMap<ClusterablePair, Double>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                ClusterablePair ijpair = new ClusterablePair(objects.get(i), objects.get(j));
                if (s.exists(ijpair)) {
                    a.put(ijpair, 0.0);
                    r.put(ijpair, 0.0);
                }
            }
        }
        System.err.println("a size: "+a.size());
        System.err.println("r size: "+r.size());

        // e[j][d] is the exemplar flag of point j in slot d of a circular
        // buffer holding the last convit iterations (zero-initialised by Java).
        int[][] e = new int[n][convit];
        int it = 0;
        int decit = 0;
        boolean done = false;

        while (!done) {
            if (it % 10 == 0) {
                System.err.println("Iteration: "+it);
            }

            updateResponsibilities(objects, s, a, r, lam, n);
            updateAvailabilities(objects, s, a, r, lam, n);

            // Check for convergence: count points that were exemplars in all
            // of the last convit iterations, and points that were in none.
            int decsumc = 0;
            int decsum0 = 0;
            for (int j = 0; j < n; j++) {
                ClusterablePair jjpair = new ClusterablePair(objects.get(j), objects.get(j));
                e[j][decit] = (a.get(jjpair) + r.get(jjpair) > 0) ? 1 : 0;
                int flagSum = 0;
                for (int d = 0; d < convit; d++) {
                    flagSum += e[j][d];
                }
                if (flagSum == convit) {
                    decsumc++;
                } else if (flagSum == 0) {
                    decsum0++;
                }
            }

            if (((decsumc + decsum0 == n) && (decsumc > 0)) || (it >= maxit)) {
                // Leave decit on the slot that was just written: it holds the
                // final exemplar flags.
                done = true;
            } else {
                decit++;
                if (decit >= convit) {
                    decit = 0;
                }
            }
            it++;
        }
        System.out.println("iterations: "+it);

        // Exemplars are the points flagged in the most recent iteration.
        Vector<Integer> exidx = new Vector<Integer>();
        for (int i = 0; i < n; i++) {
            if (e[i][decit] == 1) {
                exidx.add(i);
            }
        }

        int[] assgn = assignToExemplars(objects, s, exidx, n);
        s.putAssignments(assgn);
        s.putExemplars(toIntArray(exidx));

        if (exidx.isEmpty() && n > 0) {
            throw new IllegalStateException(
                    "No exemplars were identified after " + it + " iterations; cannot compute net similarity");
        }

        double netsim = 0.0;
        for (int i = 0; i < n; i++) {
            netsim += s.evaluate(new ClusterablePair(objects.get(i), objects.get(exidx.get(assgn[i]))));
        }
        return netsim;
    }

    /**
     * Performs one damped update of every responsibility {@code r(i,k)} for
     * which the pair exists in {@code s}.
     *
     * <p>For each {@code i}, the largest and second-largest values of
     * {@code a(i,k) + s(i,k)} are found; the update for {@code r(i,k)} is
     * {@code s(i,k)} minus the largest value, except at the arg-max {@code k},
     * where the second-largest value is subtracted instead.
     */
    private static void updateResponsibilities(Vector<Clusterable> objects, SimilarityMeasure<Clusterable> s,
            HashMap<ClusterablePair, Double> a, HashMap<ClusterablePair, Double> r, double lam, int n) {
        for (int i = 0; i < n; i++) {
            double max1 = SimilarityMeasure.NEGINF;
            double max2 = SimilarityMeasure.NEGINF;
            int i1 = 0;
            for (int k = 0; k < n; k++) {
                ClusterablePair ikpair = new ClusterablePair(objects.get(i), objects.get(k));
                if (!s.exists(ikpair)) {
                    continue;
                }
                double tmp = a.get(ikpair) + s.evaluate(ikpair);
                if (tmp > max1) {
                    max2 = max1;
                    max1 = tmp;
                    i1 = k;
                } else if (tmp > max2) {
                    max2 = tmp;
                }
            }
            System.err.println("MAX1: " + max1 + " MAX2: " + max2);

            for (int k = 0; k < n; k++) {
                ClusterablePair ikpair = new ClusterablePair(objects.get(i), objects.get(k));
                if (!s.exists(ikpair)) {
                    continue;
                }
                // The arg-max candidate competes against the runner-up;
                // everyone else competes against the maximum.
                double competitor = (k == i1) ? max2 : max1;
                double message = lam*r.get(ikpair) + (1.0-lam)*(s.evaluate(ikpair) - competitor);
                System.err.println("R " + objects.get(i).name() + " to " + objects.get(k).name() + ": " + message);
                r.put(ikpair, message);
            }
        }
    }

    /**
     * Performs one damped update of every availability {@code a(i,k)} for
     * which the pair exists in {@code s}.
     *
     * <p>For each candidate {@code k}, the positive responsibilities
     * {@code r(i',k)} with {@code i' != k} are summed. The self-availability
     * {@code a(k,k)} is updated with that sum; for {@code i != k} the update is
     * {@code r(k,k)} plus the sum excluding {@code i}'s own contribution, and
     * is applied only when negative (otherwise the update term is zero).
     */
    private static void updateAvailabilities(Vector<Clusterable> objects, SimilarityMeasure<Clusterable> s,
            HashMap<ClusterablePair, Double> a, HashMap<ClusterablePair, Double> r, double lam, int n) {
        for (int k = 0; k < n; k++) {
            double positiveSum = 0.0;
            for (int i = 0; i < n; i++) {
                ClusterablePair ikpair = new ClusterablePair(objects.get(i), objects.get(k));
                if (!s.exists(ikpair)) {
                    continue;
                }
                if ((i != k) && (r.get(ikpair) > 0)) {
                    positiveSum += r.get(ikpair);
                }
            }

            // r(k,k) is not modified during this step, so look it up once.
            // Kept boxed so a missing entry only fails where it is actually used.
            Double selfResponsibility = r.get(new ClusterablePair(objects.get(k), objects.get(k)));

            for (int i = 0; i < n; i++) {
                ClusterablePair ikpair = new ClusterablePair(objects.get(i), objects.get(k));
                if (!s.exists(ikpair)) {
                    continue;
                }
                double tmp2 = positiveSum;
                double ikr = r.get(ikpair);
                if ((i != k) && (ikr > 0)) {
                    tmp2 -= ikr;
                }
                if (i != k) {
                    tmp2 += selfResponsibility;
                }

                double message;
                if (i == k || tmp2 < 0) {
                    message = lam*a.get(ikpair) + (1.0-lam)*tmp2;
                } else {
                    message = lam*a.get(ikpair);
                }
                System.err.println("A " + objects.get(i).name() + " to " + objects.get(k).name() + ": " + message);
                a.put(ikpair, message);
            }
        }
    }

    /**
     * Assigns every object to the exemplar with the highest similarity, then
     * forces each exemplar to be assigned to itself.
     *
     * @return for each object index, the position within {@code exidx} of its
     *         exemplar (0 when no exemplar scored above
     *         {@code SimilarityMeasure.NEGINF})
     */
    private static int[] assignToExemplars(Vector<Clusterable> objects, SimilarityMeasure<Clusterable> s,
            Vector<Integer> exidx, int n) {
        int[] assgn = new int[n];
        for (int i = 0; i < n; i++) {
            double best = SimilarityMeasure.NEGINF;
            for (int j = 0; j < exidx.size(); j++) {
                ClusterablePair ijpair = new ClusterablePair(objects.get(i), objects.get(exidx.get(j)));
                double sim = s.evaluate(ijpair);
                if (sim > best) {
                    best = sim;
                    assgn[i] = j;
                }
            }
        }
        // Every exemplar belongs to its own cluster. Iterate over all exemplars
        // actually collected: when the loop stopped on maxit rather than on
        // convergence, this can exceed the count of stable exemplars.
        for (int j = 0; j < exidx.size(); j++) {
            assgn[exidx.get(j)] = j;
        }
        return assgn;
    }

    /**
     * Copies a vector of boxed integers into a primitive array.
     *
     * @param vec the values to copy; must not contain {@code null}
     * @return a new array with the same elements in the same order
     */
    public static int[] toIntArray(Vector<Integer> vec) {
        int[] toreturn = new int[vec.size()];
        for (int i = 0; i < toreturn.length; i++) {
            toreturn[i] = vec.get(i);
        }
        return toreturn;
    }

    /** Debug helper: prints a matrix to {@code System.out}, one row per line. */
    private static void print(double[][] a) {
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                System.out.print(a[i][j]+" ");
            }
            System.out.println();
        }
    }

    /** Debug helper: prints a matrix to {@code System.out}, one row per line. */
    private static void print(int[][] a) {
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                System.out.print(a[i][j]+" ");
            }
            System.out.println();
        }
    }

    /**
     * Command-line entry point. Clusters with damping 0.5, a convergence
     * window of 50 iterations and an iteration limit of 500.
     *
     * argv[0] = similarities file
     * argv[1] = preference value
     * argv[2] = results file
     * argv[3] = indices file
     *
     * @param argv command-line arguments as listed above
     * @throws IllegalArgumentException if fewer than four arguments are given
     * @throws Exception if reading the input or writing the output fails
     */
    public static void main(String[] argv) throws Exception {
        if (argv.length < 4) {
            throw new IllegalArgumentException(
                    "Usage: APCluster <similarities file> <preference value> <results file> <indices file>");
        }

        FileSimilarityMeasure<Clusterable> fsm =
                new FileSimilarityMeasure<Clusterable>(argv[0], " ", Double.parseDouble(argv[1]));
        double netsim = cluster(fsm.objects(), fsm, 0.5, 50, 500);

        PrintStream outstream = new PrintStream(argv[2]);
        try {
            fsm.printExemplars(outstream);
            outstream.println();
            fsm.printAssignments(outstream);
            outstream.println();
            outstream.println("Net Similarity: "+netsim);
        } finally {
            outstream.close();
        }

        PrintStream indexstream = new PrintStream(argv[3]);
        try {
            fsm.printClusterCenterIndices(indexstream);
        } finally {
            // Closing an already-closed PrintStream is a no-op, so this is safe
            // even if the callee closes the stream itself.
            indexstream.close();
        }
    }
}