package com.yahoo.labs.samoa.evaluation.measures;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2010 RWTH Aachen University, Germany
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.ArrayList;
import java.util.Arrays;

import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.core.DataPoint;
import com.yahoo.labs.samoa.moa.evaluation.MeasureCollection;
import com.yahoo.labs.samoa.moa.evaluation.MembershipMatrix;

/**
 * Measure collection that reports two agreement statistics between a found
 * clustering and the class labels of the evaluated points: "van Dongen" and
 * "Rand statistic". Both are derived from the cluster/class contingency
 * counts exposed by {@link MembershipMatrix}.
 *
 * <p>A third measure, {@link #cindex(Clustering, ArrayList)}, is implemented
 * but is currently not registered in {@link #getNames()} and is not reported
 * by {@link #evaluateClustering(Clustering, Clustering, ArrayList)}.
 */
public class StatisticalCollection extends MeasureCollection {

    /**
     * Minimum inclusion probability a point must exceed to be counted as a
     * member of a cluster when computing {@link #cindex(Clustering, ArrayList)}.
     */
    private static final double MEMBERSHIP_THRESHOLD = 0.8;

    /** When {@code true}, intermediate values are printed to standard output. */
    private boolean debug = false;

    /**
     * @return the names of the measures provided by this collection, in the
     *         same order as {@link #getDefaultEnabled()}
     */
    @Override
    protected String[] getNames() {
        return new String[]{"van Dongen", "Rand statistic"};
    }

    /**
     * @return the default enabled flag for each measure of {@link #getNames()};
     *         both measures are disabled by default
     */
    @Override
    protected boolean[] getDefaultEnabled() {
        return new boolean[]{false, false};
    }

    /**
     * Computes the "van Dongen" and "Rand statistic" values for the given
     * clustering and records them through {@code addValue}.
     *
     * @param clustering     the clustering to evaluate
     * @param trueClustering not used by this implementation
     * @param points         the points the membership matrix is built from
     * @throws Exception declared by the overridden method
     */
    @Override
    public void evaluateClustering(Clustering clustering, Clustering trueClustering,
            ArrayList<DataPoint> points) throws Exception {
        MembershipMatrix mm = new MembershipMatrix(clustering, points);
        int numClasses = mm.getNumClasses();
        // NOTE(review): one row more than clustering.size() is read from the matrix;
        // presumably MembershipMatrix keeps an extra row (e.g. for noise) - confirm.
        int numCluster = clustering.size() + 1;
        int n = mm.getTotalEntries();

        // --- van Dongen ----------------------------------------------------
        // Sum, over all clusters, of the largest class count inside the cluster.
        double dongenMaxFC = 0;
        for (int i = 0; i < numCluster; i++) {
            double max = 0;
            for (int j = 0; j < numClasses; j++) {
                double weight = mm.getClusterClassWeight(i, j);
                if (weight > max) {
                    max = weight;
                }
            }
            dongenMaxFC += max;
        }

        // Sum, over all classes, of the largest cluster count inside the class.
        double dongenMaxHC = 0;
        for (int j = 0; j < numClasses; j++) {
            double max = 0;
            for (int i = 0; i < numCluster; i++) {
                double weight = mm.getClusterClassWeight(i, j);
                if (weight > max) {
                    max = weight;
                }
            }
            dongenMaxHC += max;
        }

        // Denominator is computed in double so that 2*n cannot overflow int.
        double dongen = (dongenMaxFC + dongenMaxHC) / (2.0 * n);

        if (debug) {
            System.out.println("Dongen HC:"+dongenMaxHC+" FC:"+dongenMaxFC+" Total:"+dongen+" n "+n);
        }
        addValue("van Dongen", dongen);

        // --- Rand statistic ------------------------------------------------
        // http://www.cais.ntu.edu.sg/~qihe/menu4.html
        // m1: point pairs sharing a class.
        double m1 = 0;
        for (int j = 0; j < numClasses; j++) {
            m1 += pairCount(mm.getClassSum(j));
        }

        // m2: point pairs sharing a cluster.
        double m2 = 0;
        for (int i = 0; i < numCluster; i++) {
            m2 += pairCount(mm.getClusterSum(i));
        }

        // m: point pairs sharing both class and cluster.
        double m = 0;
        for (int i = 0; i < numCluster; i++) {
            for (int j = 0; j < numClasses; j++) {
                m += pairCount(mm.getClusterClassWeight(i, j));
            }
        }

        // Total number of pairs. n is widened to double first: the int product
        // n*(n-1) silently overflows for n > 46341.
        double M = pairCount(n);
        double rand = (M - m1 - m2 + 2 * m) / M;

        addValue("Rand statistic", rand);
    }

    /**
     * Computes a normalized within-cluster distance value: the mean pairwise
     * distance between points of the same cluster, rescaled by the smallest and
     * largest within-cluster distance observed over all clusters.
     *
     * <p>A point counts as a member of a cluster when its inclusion probability
     * for that cluster is greater than {@code 0.8}; a point may therefore be a
     * member of several clusters or of none.
     *
     * @param clustering the clustering whose clusters are examined
     * @param points     the points to assign to clusters
     * @return {@code (mean - min) / (max - min)} of the within-cluster
     *         distances, or {@code 0} when no cluster holds at least two points
     *         or when all within-cluster distances are equal
     */
    public double cindex(Clustering clustering, ArrayList<DataPoint> points) {
        int numClusters = clustering.size();

        double[] minWithinClusters = new double[numClusters];
        double[] maxWithinClusters = new double[numClusters];
        ArrayList<ArrayList<Integer>> pointsInClusters = new ArrayList<>(numClusters);
        for (int c = 0; c < numClusters; c++) {
            pointsInClusters.add(new ArrayList<Integer>());
            minWithinClusters[c] = Double.MAX_VALUE;
            // Double.MIN_VALUE is the smallest POSITIVE double and therefore not a
            // valid seed for a running maximum (a distance of 0 would never win).
            maxWithinClusters[c] = Double.NEGATIVE_INFINITY;
        }

        // Collect, per cluster, the indices of the points that belong to it.
        for (int p = 0; p < points.size(); p++) {
            DataPoint candidate = points.get(p);
            for (int c = 0; c < numClusters; c++) {
                if (clustering.get(c).getInclusionProbability(candidate) > MEMBERSHIP_THRESHOLD) {
                    pointsInClusters.get(c).add(p);
                }
            }
        }

        // Accumulate all pairwise within-cluster distances plus per-cluster min/max.
        double withinClustersDistance = 0;
        int numDistancesWithin = 0;
        for (int c = 0; c < numClusters; c++) {
            ArrayList<Integer> pointsInC = pointsInClusters.get(c);
            for (int p = 0; p < pointsInC.size(); p++) {
                DataPoint point = points.get(pointsInC.get(p));
                for (int p1 = p + 1; p1 < pointsInC.size(); p1++) {
                    DataPoint point1 = points.get(pointsInC.get(p1));
                    double dist = point.getDistance(point1);
                    numDistancesWithin++;
                    withinClustersDistance += dist;
                    if (minWithinClusters[c] > dist) {
                        minWithinClusters[c] = dist;
                    }
                    if (maxWithinClusters[c] < dist) {
                        maxWithinClusters[c] = dist;
                    }
                }
            }
        }

        // Reduce the per-cluster extremes to global extremes.
        double minWithin = Double.MAX_VALUE;
        double maxWithin = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < numClusters; c++) {
            if (minWithinClusters[c] < minWithin) {
                minWithin = minWithinClusters[c];
            }
            if (maxWithinClusters[c] > maxWithin) {
                maxWithin = maxWithinClusters[c];
            }
        }

        double cindex = 0;
        // Skip the normalization when there are no pairs or the range is empty;
        // dividing by a zero range would otherwise produce NaN or infinity.
        if (numDistancesWithin != 0 && maxWithin > minWithin) {
            double meanWithinClustersDistance = withinClustersDistance / numDistancesWithin;
            cindex = (meanWithinClustersDistance - minWithin) / (maxWithin - minWithin);
        }

        if (debug) {
            System.out.println("Min:"+Arrays.toString(minWithinClusters));
            System.out.println("Max:"+Arrays.toString(maxWithinClusters));
            System.out.println("totalWithin:"+numDistancesWithin);
        }
        return cindex;
    }

    /** Returns the number of unordered pairs that can be drawn from {@code count} items. */
    private static double pairCount(double count) {
        return count * (count - 1) / 2.0;
    }
}