/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package tr.gov.ulakbim.jDenetX.evaluation;

import tr.gov.ulakbim.jDenetX.cluster.Clustering;
import tr.gov.ulakbim.jDenetX.gui.visualization.DataPoint;

import java.util.ArrayList;
import java.util.HashMap;

/**
 * Measure collection that scores a found clustering against a ground-truth
 * clustering with precision, recall and F1.
 *
 * <p>Each found cluster is matched to the ground-truth class that contributes
 * the most labeled points to it; the per-cluster scores are then averaged over
 * all found clusters that contain at least one labeled point.
 *
 * @author jansen
 */
public class F1 extends MeasureCollection {

    private static final long serialVersionUID = 1L;

    /** Minimum inclusion probability for a point to count as a member of a found cluster. */
    private final double pointInclusionProbThreshold = 0.5;

    public F1() {
        super();
    }

    /**
     * @return the names of the measures maintained by this collection
     */
    @Override
    protected String[] getNames() {
        String[] names = {"Precision", "Recall", "F1"};
        return names;
    }

    /**
     * Computes precision, recall and F1 of {@code clustering} with respect to
     * {@code trueClustering} on the given points and records the averaged values
     * via {@code addValue}. The per-cluster values are additionally stored on each
     * found cluster through {@code setMeasureValue}.
     *
     * <p>Points whose class value is {@code -1}, or whose class value does not
     * correspond to any ground-truth cluster, are not counted toward any class.
     *
     * @param clustering     the clustering to evaluate
     * @param trueClustering the ground-truth clustering; each cluster's ground-truth
     *                       value is used as its class label
     * @param points         the points used for the evaluation
     */
    public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
        final int numFoundClusters = clustering.size();

        // Nothing was found: record zeros, exactly as the averaging below would
        // for an empty clustering, without touching the points at all.
        if (numFoundClusters == 0) {
            addValue("F1", 0.0);
            addValue("Precision", 0.0);
            addValue("Recall", 0.0);
            return;
        }

        // Map each ground-truth class label to the index of its ground-truth cluster.
        HashMap<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        final int numGTCluster = trueClustering.size();
        for (int c = 0; c < numGTCluster; c++) {
            labelMap.put((int) trueClustering.get(c).getGroundTruth(), c);
        }

        // classWeightsFoundClusters[c][k]: labeled points of class k inside found cluster c.
        int[][] classWeightsFoundClusters = new int[numFoundClusters][numGTCluster];
        // classWeightsHidden[k]: labeled points of class k overall (real class distribution).
        int[] classWeightsHidden = new int[numGTCluster];

        // Map points to clusters.
        for (int p = 0; p < points.size(); p++) {
            DataPoint point = points.get(p);

            int workLabel = -1;
            if (point.classValue() != -1) {
                // A label without a matching ground-truth cluster yields null here;
                // treat such a point like an unlabeled one instead of failing on unboxing.
                Integer mappedLabel = labelMap.get((int) point.classValue());
                if (mappedLabel != null) {
                    workLabel = mappedLabel.intValue();
                }
            }
            if (workLabel == -1) {
                // Unlabeled points contribute to neither count.
                continue;
            }

            for (int c = 0; c < numFoundClusters; c++) {
                double prob = clustering.get(c).getInclusionProbability(point);
                // We need to change this in case we get real probabilities and not just 0 / 1.
                if (prob >= pointInclusionProbThreshold) {
                    classWeightsFoundClusters[c][workLabel]++;
                }
            }
            classWeightsHidden[workLabel]++;
        }

        // Figure out F1 per cluster.
        double f1Total = 0.0;
        double precisionTotal = 0.0;
        double recallTotal = 0.0;
        int realClusters = 0;

        // F1 as defined in P3C, try using F1 optimization.
        for (int i = 0; i < numFoundClusters; i++) {
            int maxWeight = 0;
            int maxWeightIndex = -1;
            int clusterWeight = 0;
            for (int j = 0; j < numGTCluster; j++) {
                int weight = classWeightsFoundClusters[i][j];
                // Strict comparison: on ties the class with the lowest index wins.
                if (weight > maxWeight) {
                    maxWeight = weight;
                    maxWeightIndex = j;
                }
                clusterWeight += weight;
            }

            if (maxWeightIndex == -1) {
                // The cluster holds no labeled points, so it is left out of the averages.
                continue;
            }
            realClusters++;

            // maxWeight > 0 here, hence both denominators are positive.
            double precision = (double) maxWeight / (double) clusterWeight;
            double recall = (double) maxWeight / (double) classWeightsHidden[maxWeightIndex];
            double f1 = 0.0;
            if (precision > 0 || recall > 0) {
                f1 = 2 * precision * recall / (precision + recall);
            }

            clustering.get(i).setMeasureValue("Precision", Double.toString(precision));
            clustering.get(i).setMeasureValue("Recall", Double.toString(recall));
            clustering.get(i).setMeasureValue("F1", Double.toString(f1));

            precisionTotal += precision;
            recallTotal += recall;
            f1Total += f1;
        }

        if (realClusters > 0) {
            f1Total /= realClusters;
            recallTotal /= realClusters;
            precisionTotal /= realClusters;
        }

        addValue("F1", f1Total);
        addValue("Precision", precisionTotal);
        addValue("Recall", recallTotal);
    }
}