package clear.util.cluster;
import clear.experiment.SRLVerbCluster;
import clear.util.tuple.JIntDoubleTuple;
import clear.util.tuple.JIntIntTuple;
import clear.util.tuple.JObjectDoubleTuple;
import clear.util.tuple.JObjectObjectTuple;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.ObjectDoubleOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
/**
 * Builds clusters of feature vectors ({@code ObjectDoubleOpenHashMap<String>}) using either
 * greedy hierarchical agglomerative merging or a k-means-like assignment pass.
 *
 * <p>Each cluster is identified by a key; merging two clusters joins their keys with ","
 * (see {@link #getJoinedKey(ProbCluster, ProbCluster)}), and {@link #getClusterMaps()} later
 * splits those keys back apart.
 */
public class SRLClusterBuilder {
    /** Number of assignment passes performed by {@link #kmCluster(HashMap)}. */
    private static final int DEFAULT_KM_ITERATIONS = 3;

    /*
     * Tuning notes left by the original author: "0.78, 0.76" (both thresholds) and
     * "0.80 same" (next to the hierarchical threshold).
     */

    /**
     * Minimum pairwise similarity for merging two clusters during hierarchical clustering;
     * candidate pairs scoring below it are never merged.
     */
    double d_hm_lower = 0.78;

    /**
     * Minimum similarity for the k-means pass to attach a candidate cluster to its most
     * similar existing cluster. Overridden by {@link #SRLClusterBuilder(double)}.
     */
    double d_km_lower = 0.78;

    /**
     * Cache of pairwise similarities keyed by {@link #getJoinedKey(ProbCluster, ProbCluster)}.
     * Only consulted when a similarity method is called with {@code useDynamic == true}; it is
     * (re)created by {@link #hmCluster(HashMap)} and
     * {@link #cluster(Prob2dMap, ObjectDoubleOpenHashMap)}.
     */
    ObjectDoubleOpenHashMap<String> d_similarities;

    /** Current clusters, populated by the clustering entry points. */
    ArrayList<ProbCluster> k_clusters;

    /** Creates a builder with the default thresholds. */
    public SRLClusterBuilder() {
    }

    /**
     * Creates a builder with a custom k-means assignment threshold.
     *
     * @param threshold value for {@code d_km_lower}; {@code d_hm_lower} keeps its default
     */
    public SRLClusterBuilder(double threshold) {
        d_km_lower = threshold;
    }

    /**
     * Creates one singleton cluster per map entry.
     *
     * @param map feature vectors keyed by the cluster key to assign
     * @return a new list containing one cluster per key, each holding that key's feature map
     */
    public ArrayList<ProbCluster> getInitClusters(HashMap<String, ObjectDoubleOpenHashMap<String>> map) {
        ArrayList<ProbCluster> clusters = new ArrayList<>();
        for (String key : map.keySet()) {
            ProbCluster cluster = new ProbCluster(key);
            cluster.add(map.get(key));
            clusters.add(cluster);
        }
        return clusters;
    }

    /**
     * Computes the cosine similarity between the centroids of two clusters.
     *
     * @param cluster1 first cluster
     * @param cluster2 second cluster
     * @param useDynamic if {@code true}, read from and store into {@code d_similarities}
     *     (which must already have been created)
     * @return cosine similarity of the two centroids, or 0 if either centroid is a zero vector
     */
    public double getCtrSimilarity(ProbCluster cluster1, ProbCluster cluster2, boolean useDynamic) {
        String key = getJoinedKey(cluster1, cluster2);
        if (useDynamic && d_similarities.containsKey(key)) {
            return d_similarities.get(key);
        }
        double sim = getCosineSimilarity(getCentroid(cluster1), getCentroid(cluster2));
        if (useDynamic) {
            d_similarities.put(key, sim);
        }
        return sim;
    }

    /**
     * Computes the element-wise mean of a cluster's feature maps. A feature missing from a
     * member contributes 0 to the sum, because every total is divided by the full cluster size.
     *
     * @param cluster cluster whose members are averaged
     * @return a new map holding the mean value of every feature seen in any member
     */
    public ObjectDoubleOpenHashMap<String> getCentroid(ProbCluster cluster) {
        ObjectDoubleOpenHashMap<String> centroid = new ObjectDoubleOpenHashMap<>();
        String key;
        for (ObjectDoubleOpenHashMap<String> map : cluster) {
            for (ObjectCursor<String> cur : map.keys()) {
                key = cur.value;
                // get() yields 0 for a feature not yet present, so this accumulates a sum.
                centroid.put(key, centroid.get(key) + map.get(key));
            }
        }
        int size = cluster.size();
        // Only values of existing keys are overwritten here; no keys are added or removed.
        for (ObjectCursor<String> cur : centroid.keys()) {
            key = cur.value;
            centroid.put(key, centroid.get(key) / size);
        }
        return centroid;
    }

    /**
     * Computes the average cosine similarity over all cross-cluster member pairs.
     *
     * @param cluster1 first cluster
     * @param cluster2 second cluster
     * @param useDynamic if {@code true}, read from and store into {@code d_similarities}
     *     (which must already have been created)
     * @return mean pairwise cosine similarity, or 0 if either cluster has no members
     */
    public double getAvgSimilarity(ProbCluster cluster1, ProbCluster cluster2, boolean useDynamic) {
        String key = getJoinedKey(cluster1, cluster2);
        if (useDynamic && d_similarities.containsKey(key)) {
            return d_similarities.get(key);
        }
        double avg = 0;
        // Guard against 0/0: an empty cluster would otherwise produce NaN.
        if (cluster1.size() > 0 && cluster2.size() > 0) {
            for (ObjectDoubleOpenHashMap<String> map1 : cluster1) {
                for (ObjectDoubleOpenHashMap<String> map2 : cluster2) {
                    avg += getCosineSimilarity(map1, map2);
                }
            }
            avg /= ((double) cluster1.size() * cluster2.size());
        }
        if (useDynamic) {
            d_similarities.put(key, avg);
        }
        return avg;
    }

    /**
     * Computes the cosine similarity of two sparse vectors.
     *
     * <p>If either vector has zero norm (empty map or all-zero values), the similarity is
     * defined as 0. Otherwise it would be NaN, and NaN scores slip past the {@code <} threshold
     * checks used by the clustering loops.
     *
     * @param map1 first vector (feature to value)
     * @param map2 second vector (feature to value)
     * @return dot(map1, map2) / (|map1| * |map2|), or 0 when either norm is 0
     */
    public double getCosineSimilarity(ObjectDoubleOpenHashMap<String> map1, ObjectDoubleOpenHashMap<String> map2) {
        double dot = 0, scala1 = 0, scala2 = 0, val;
        String key;
        for (ObjectCursor<String> cur : map1.keys()) {
            key = cur.value;
            val = map1.get(key);
            if (map2.containsKey(key)) {
                dot += (val * map2.get(key));
            }
            scala1 += (val * val);
        }
        for (ObjectCursor<String> cur : map2.keys()) {
            val = map2.get(cur.value);
            scala2 += (val * val);
        }
        double denominator = Math.sqrt(scala1) * Math.sqrt(scala2);
        if (denominator == 0) {
            return 0;
        }
        return dot / denominator;
    }

    /**
     * Joins the keys of two clusters.
     *
     * @param cluster1 first cluster
     * @param cluster2 second cluster
     * @return {@code cluster1.key + "," + cluster2.key}; used both as the similarity-cache key
     *     and as the key of a merged cluster
     */
    public String getJoinedKey(ProbCluster cluster1, ProbCluster cluster2) {
        return cluster1.key + "," + cluster2.key;
    }

    /**
     * Prints the key and score of each cluster and the number printed.
     *
     * <p>Side effect: sorts {@code k_clusters} in place by {@link ProbCluster}'s natural order.
     * Printing stops at the first cluster with exactly one member, which assumes that order
     * places singletons after larger clusters (TODO confirm against ProbCluster.compareTo).
     */
    public void printCluster() {
        Collections.sort(k_clusters);
        int count = 0;
        for (ProbCluster cluster : k_clusters) {
            if (cluster.size() == 1) {
                break;
            }
            System.out.println(cluster.key + " " + cluster.score);
            count++;
        }
        System.out.println("# of clusters: " + count);
    }

    // ======================== Hierarchical agglomerative clustering ========================

    /**
     * Runs hierarchical agglomerative clustering over the given feature vectors. Afterwards it
     * removes singleton clusters, clears the similarity cache, and prints the result.
     *
     * @param map feature vectors keyed by initial cluster key
     */
    public void hmCluster(HashMap<String, ObjectDoubleOpenHashMap<String>> map) {
        d_similarities = new ObjectDoubleOpenHashMap<>();
        k_clusters = getInitClusters(map);
        hmClusterRec();
        hmClusterTrim();
    }

    /** Repeats merge rounds until a round performs no merge. */
    private void hmClusterRec() {
        boolean cont = true;
        for (int i = 0; cont; i++) {
            System.out.println("== Iteration: " + i + " ==");
            cont = hmClusterAux();
        }
    }

    /** Removes singleton clusters, clears the similarity cache, and prints the clusters. */
    protected void hmClusterTrim() {
        ArrayList<ProbCluster> remove = new ArrayList<>();
        for (ProbCluster cluster : k_clusters) {
            if (cluster.size() == 1) {
                remove.add(cluster);
            }
        }
        k_clusters.removeAll(remove);
        d_similarities.clear();
        printCluster();
    }

    /**
     * Performs one greedy merge round. All cluster pairs are scored, then pairs are visited
     * in sorted order. A pair is merged if it meets {@code d_hm_lower} and neither cluster
     * has already taken part in a merge this round.
     *
     * @return true if at least one merge was performed
     */
    private boolean hmClusterAux() {
        ArrayList<JObjectDoubleTuple<JIntIntTuple>> list = new ArrayList<>();
        ProbCluster cluster1, cluster2;
        double score;
        for (int i = 0; i < k_clusters.size() - 1; i++) {
            cluster1 = k_clusters.get(i);
            for (int j = i + 1; j < k_clusters.size(); j++) {
                cluster2 = k_clusters.get(j);
                score = getAvgSimilarity(cluster1, cluster2, true);
                list.add(new JObjectDoubleTuple<>(new JIntIntTuple(i, j), score));
            }
        }
        IntOpenHashSet sClustered = new IntOpenHashSet();
        ArrayList<ProbCluster> sRemove = new ArrayList<>();
        JIntIntTuple idx;
        // The early break below assumes this sorts by descending score (TODO confirm
        // JObjectDoubleTuple.compareTo).
        Collections.sort(list);
        for (int i = 0; i < list.size(); i++) {
            JObjectDoubleTuple<JIntIntTuple> tup = list.get(i);
            if (tup.value < d_hm_lower) {
                break;
            }
            idx = tup.object;
            // Each cluster may take part in at most one merge per round.
            if (sClustered.contains(idx.int1) || sClustered.contains(idx.int2)) {
                continue;
            }
            sClustered.add(idx.int1);
            sClustered.add(idx.int2);
            cluster1 = k_clusters.get(idx.int1);
            cluster2 = k_clusters.get(idx.int2);
            cluster1.addAll(cluster2);
            // The joined key is computed before set() replaces cluster1's key.
            cluster1.set(getJoinedKey(cluster1, cluster2), tup.value);
            sRemove.add(cluster2);
        }
        k_clusters.removeAll(sRemove);
        return !sClustered.isEmpty();
    }

    // ======================== K-mean clustering ========================

    /**
     * Runs {@link #kmCluster(HashMap, int)} with the default number of passes (3).
     *
     * @param map feature vectors to attach to the existing clusters
     */
    public void kmCluster(HashMap<String, ObjectDoubleOpenHashMap<String>> map) {
        kmCluster(map, DEFAULT_KM_ITERATIONS);
    }

    /**
     * Attaches singleton clusters built from {@code map} to the existing {@code k_clusters}.
     *
     * <p>In each pass, every not-yet-attached candidate is compared against all current
     * clusters. If its best average similarity reaches {@code d_km_lower}, it is merged into
     * that cluster at the end of the pass. Iteration stops early once a pass attaches nothing.
     * Requires {@code k_clusters} to have been populated beforehand (e.g. by
     * {@link #hmCluster(HashMap)}).
     *
     * @param map feature vectors to attach
     * @param maxIterations maximum number of assignment passes
     */
    public void kmCluster(HashMap<String, ObjectDoubleOpenHashMap<String>> map, int maxIterations) {
        ArrayList<ProbCluster> nClusters = getInitClusters(map);
        ArrayList<JIntIntTuple> list = new ArrayList<>();
        IntOpenHashSet skip = new IntOpenHashSet();
        for (int k = 0; k < maxIterations; k++) {
            System.out.println("== Iteration: " + k + " ==");
            list.clear();
            for (int i = 0; i < nClusters.size(); i++) {
                if (skip.contains(i)) {
                    continue;
                }
                ProbCluster nCluster = nClusters.get(i);
                JIntDoubleTuple max = new JIntDoubleTuple(-1, -1);
                for (int j = 0; j < k_clusters.size(); j++) {
                    double sim = getAvgSimilarity(nCluster, k_clusters.get(j), false);
                    if (max.d < sim) {
                        max.set(j, sim);
                    }
                }
                if (max.d >= d_km_lower) {
                    list.add(new JIntIntTuple(max.i, i));
                }
            }
            // Merges are applied after scoring so that all candidates in a pass see the same clusters.
            for (JIntIntTuple tup : list) {
                ProbCluster kCluster = k_clusters.get(tup.int1);
                ProbCluster nCluster = nClusters.get(tup.int2);
                kCluster.addAll(nCluster);
                kCluster.set(getJoinedKey(kCluster, nCluster), 1);
                skip.add(tup.int2);
            }
            if (list.isEmpty()) {
                break;
            }
        }
        printCluster();
    }

    /**
     * Maps every id inside the cluster keys to a 1-based cluster number.
     *
     * <p>Each cluster key is split on "," into ids, and each id is split on ":". Ids whose
     * prefix equals {@code SRLVerbCluster.FLAG_LOCAL} go into the first map; all others go into
     * the second. Both maps are keyed by the integer after the ":". Cluster numbers follow the
     * current order of {@code k_clusters}, which {@link #printCluster()} re-sorts.
     *
     * @return (local-id map, other-id map), each mapping id to {@code index + 1}
     */
    public JObjectObjectTuple<IntIntOpenHashMap, IntIntOpenHashMap> getClusterMaps() {
        IntIntOpenHashMap lMap = new IntIntOpenHashMap();
        IntIntOpenHashMap gMap = new IntIntOpenHashMap();
        for (int i = 0; i < k_clusters.size(); i++) {
            ProbCluster cluster = k_clusters.get(i);
            for (String id : cluster.key.split(",")) {
                String[] key = id.split(":");
                if (key[0].equals(SRLVerbCluster.FLAG_LOCAL)) {
                    lMap.put(Integer.parseInt(key[1]), i + 1);
                } else {
                    gMap.put(Integer.parseInt(key[1]), i + 1);
                }
            }
        }
        return new JObjectObjectTuple<>(lMap, gMap);
    }

    /**
     * Runs hierarchical merge rounds over feature-weighted vectors from {@code map}. Unlike
     * {@link #hmCluster(HashMap)}, singleton clusters are kept and nothing is printed at the end.
     *
     * @param map per-key feature vectors; not modified
     * @param mWeights per-feature weights; features with weight {@code <= 0} (or absent) keep
     *     their original value
     */
    public void cluster(Prob2dMap map, ObjectDoubleOpenHashMap<String> mWeights) {
        d_similarities = new ObjectDoubleOpenHashMap<>();
        k_clusters = getInitClusters(map, mWeights);
        hmClusterRec();
    }

    /**
     * Builds one singleton cluster per key of {@code map}, holding a weighted copy of that
     * key's feature vector. Called from {@link #cluster(Prob2dMap, ObjectDoubleOpenHashMap)}.
     *
     * <p>The weighted values go into a fresh map. Earlier code multiplied the caller's map in
     * place, so clustering the same input twice compounded the weights.
     *
     * @param map per-key feature vectors; not modified
     * @param mWeights per-feature weights; only positive weights are applied
     * @return a new list of singleton clusters
     */
    private ArrayList<ProbCluster> getInitClusters(Prob2dMap map, ObjectDoubleOpenHashMap<String> mWeights) {
        ArrayList<ProbCluster> clusters = new ArrayList<>();
        for (String key : map.keySet()) {
            ObjectDoubleOpenHashMap<String> source = map.getProb2dMap(key);
            ObjectDoubleOpenHashMap<String> weighted = new ObjectDoubleOpenHashMap<>();
            for (ObjectCursor<String> cur : source.keys()) {
                String feature = cur.value;
                double value = source.get(feature);
                double weight = mWeights.get(feature);
                weighted.put(feature, (weight > 0) ? value * weight : value);
            }
            ProbCluster cluster = new ProbCluster(key);
            cluster.add(weighted);
            clusters.add(cluster);
        }
        return clusters;
    }
}