package it.unito.geosummly.experiments;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.Vector;
import jp.ndca.similarity.distance.Jaccard;
import org.apache.commons.io.FileUtils;
public class ClusterOutputValidation {
public static void main(String[] args)
{
ClusterOutputValidation main = new ClusterOutputValidation();
List<HashMap<String, Vector<Integer>>> holdout =
main.computeCrossFoldValidation("output/evaluation/clustering_output_validation/10-holdout/holdout_results_all_clusters.log");
Jaccard jacc = new Jaccard();
Double jaccOnLabels = 0.0;
Double jaccOnSet = 0.0;
HashMap<String,Vector<Double>> jaccOnSets = new HashMap<>();
int iterations = 0;
for (int i=0; i<holdout.size()-1; i++)
{
HashMap<String, Vector<Integer>> ho1 = holdout.get(i);
for(int j=i+1; j<holdout.size(); j++)
{
System.out.printf("pair (%s,%s)\n", i,j);
HashMap<String, Vector<Integer>> ho2 = holdout.get(j);
System.out.printf("\tjaccard_labels=%s\n", jacc.calc(ho1.keySet().toArray(), ho2.keySet().toArray()) );
//get cluster names from first set
Set<String> cluster_names = new HashSet<String>();
cluster_names.addAll(ho1.keySet());
cluster_names.addAll(ho2.keySet());
Double jaccOnPair = 0.0;
for(String name : cluster_names) {
Vector<Integer> ho1_objects= (ho1.get(name) == null) ? new Vector<Integer>() : ho1.get(name);
Vector<Integer> ho2_objects= (ho2.get(name) == null) ? new Vector<Integer>() : ho2.get(name);
System.out.printf("\tjaccard_on_set(%s)=%s\n", name, jacc.calc(ho1_objects, ho2_objects));
if(!jaccOnSets.containsKey(name)){
Vector<Double> v = new Vector<>();
v.add(jacc.calc(ho1_objects, ho2_objects));
jaccOnSets.put(name, v);
}
else {
Vector<Double> v = jaccOnSets.get(name);
jaccOnSets.remove(name);
v.add(jacc.calc(ho1_objects, ho2_objects));
jaccOnSets.put(name, v);
}
jaccOnPair += jacc.calc(ho1_objects, ho2_objects);
}
System.out.printf("\tjaccard_on_set_average=%s\n", jaccOnPair/cluster_names.size());
jaccOnLabels += jacc.calc(ho1.keySet().toArray(), ho2.keySet().toArray() );
jaccOnSet += jaccOnPair/cluster_names.size();
iterations++;
}
}
System.out.println("#####\n# Totals\n###");
System.out.println("avg_jaccard_labels=" + jaccOnLabels/iterations );
System.out.println("avg_jaccard_objects=" + jaccOnSet/iterations );
for(Entry<String,Vector<Double>> entry: jaccOnSets.entrySet()) {
Double counter = 0.0;
for(Double d : entry.getValue()) {
counter += d;
}
System.out.printf("avg_jaccard_cluster(%s)=%s\n", entry.getKey(), counter/entry.getValue().size());
}
}
public List<HashMap<String, Vector<Integer>>> computeCrossFoldValidation(String path)
{
List<HashMap<String, Vector<Integer>>> holdout = new ArrayList<>();
try {
String blob = FileUtils.readFileToString(new File(path));
String[] chunks = blob.split("_END_HO.*");
for (String chunk : chunks)
{
String[] lines = chunk.trim().split("\n");
// HO level
HashMap<String, Vector<Integer>> ho = new HashMap<>();
for (String line : lines)
{
String[] items = line.trim().split(";");
String cluster_name = items[0];
String[] os = items[1].trim().split(" ");
Vector<Integer> objects = new Vector<>();
for (String o : os) {
objects.add(Integer.parseInt(o));
}
ho.put(cluster_name, objects);
}
holdout.add(ho);
}
return holdout;
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}