package org.apache.samoa.evaluation.measures;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.ArrayList;
import java.util.Arrays;

import org.apache.samoa.moa.cluster.Clustering;
import org.apache.samoa.moa.core.DataPoint;
import org.apache.samoa.moa.evaluation.MeasureCollection;
import org.apache.samoa.moa.evaluation.MembershipMatrix;

/**
 * Measure collection that reports two statistics derived from a {@link MembershipMatrix}: "van Dongen" and
 * "Rand statistic". Both are reported as not enabled by default (see {@link #getDefaultEnabled()}).
 *
 * <p>
 * A third measure, {@link #cindex(Clustering, ArrayList)}, is implemented but is currently not registered in
 * {@link #getNames()} nor published by {@link #evaluateClustering(Clustering, Clustering, ArrayList)}.
 */
public class StatisticalCollection extends MeasureCollection {

  /**
   * A point is counted as a member of a cluster in {@link #cindex(Clustering, ArrayList)} only when its inclusion
   * probability is strictly greater than this value.
   */
  private static final double INCLUSION_THRESHOLD = 0.8;

  /** When {@code true}, intermediate values are printed to standard output. */
  private boolean debug = false;

  /**
   * Returns the names of the measures this collection publishes.
   *
   * @return the measure names, in the same order as {@link #getDefaultEnabled()}
   */
  @Override
  protected String[] getNames() {
    // "C Index" is intentionally not listed; see cindex(...).
    return new String[] { "van Dongen", "Rand statistic" };
  }

  /**
   * Returns the default enabled flag of each measure.
   *
   * @return one flag per entry of {@link #getNames()}; both are {@code false}
   */
  @Override
  protected boolean[] getDefaultEnabled() {
    return new boolean[] { false, false };
  }

  /**
   * Computes the "van Dongen" and "Rand statistic" values for the given clustering and records them through
   * {@code addValue}.
   *
   * @param clustering
   *          the clustering to evaluate
   * @param trueClustering
   *          not used by this implementation
   * @param points
   *          the data points the membership matrix is built from
   * @throws Exception
   *           declared by the overridden method
   */
  @Override
  public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points)
      throws Exception {
    MembershipMatrix mm = new MembershipMatrix(clustering, points);
    int numClasses = mm.getNumClasses();
    // NOTE(review): one more row than there are clusters - presumably the membership matrix keeps an extra
    // row for points that fall into no cluster; confirm against MembershipMatrix.
    int numCluster = clustering.size() + 1;
    int n = mm.getTotalEntries();

    addValue("van Dongen", vanDongen(mm, numCluster, numClasses, n));
    addValue("Rand statistic", randStatistic(mm, numCluster, numClasses, n));
    // The C index is not published: addValue("C Index", cindex(clustering, points));
  }

  /**
   * Computes the van Dongen value: the sum of the per-cluster maxima plus the sum of the per-class maxima of the
   * membership matrix, divided by {@code 2 * n}.
   */
  private double vanDongen(MembershipMatrix mm, int numCluster, int numClasses, int n) {
    // Sum over clusters of the largest class weight found in that cluster.
    double dongenMaxFC = 0;
    for (int i = 0; i < numCluster; i++) {
      double max = 0;
      for (int j = 0; j < numClasses; j++) {
        double weight = mm.getClusterClassWeight(i, j);
        if (weight > max) {
          max = weight;
        }
      }
      dongenMaxFC += max;
    }

    // Sum over classes of the largest cluster weight found for that class.
    double dongenMaxHC = 0;
    for (int j = 0; j < numClasses; j++) {
      double max = 0;
      for (int i = 0; i < numCluster; i++) {
        double weight = mm.getClusterClassWeight(i, j);
        if (weight > max) {
          max = weight;
        }
      }
      dongenMaxHC += max;
    }

    // 2.0 * n keeps the denominator in floating point so it cannot overflow int.
    // A normalized variant, 1 - (2n - FC - HC) / (2n - maxClusterSum - maxClassSum), is not used.
    double dongen = (dongenMaxFC + dongenMaxHC) / (2.0 * n);

    if (debug) {
      System.out.println("Dongen HC:" + dongenMaxHC + " FC:" + dongenMaxFC + " Total:" + dongen + " n " + n);
    }
    return dongen;
  }

  /**
   * Computes the Rand statistic {@code (M - m1 - m2 + 2m) / M}, where {@code M} is the number of pairs among
   * {@code n} entries, {@code m1} the pairs counted per class sum, {@code m2} the pairs counted per cluster sum and
   * {@code m} the pairs counted per matrix cell.
   *
   * <p>
   * See http://www.cais.ntu.edu.sg/~qihe/menu4.html
   */
  private double randStatistic(MembershipMatrix mm, int numCluster, int numClasses, int n) {
    double m1 = 0;
    for (int j = 0; j < numClasses; j++) {
      m1 += pairs(mm.getClassSum(j));
    }

    double m2 = 0;
    for (int i = 0; i < numCluster; i++) {
      m2 += pairs(mm.getClusterSum(i));
    }

    double m = 0;
    for (int i = 0; i < numCluster; i++) {
      for (int j = 0; j < numClasses; j++) {
        m += pairs(mm.getClusterClassWeight(i, j));
      }
    }

    // Widen before multiplying: n * (n - 1) in int arithmetic overflows once n exceeds 46341.
    double totalPairs = (double) n * (n - 1) / 2.0;

    // A normalized variant, (m - m1*m2/M) / (m1/2 + m2/2 - m1*m2/M), is not used.
    return (totalPairs - m1 - m2 + 2 * m) / totalPairs;
  }

  /** Returns {@code v * (v - 1) / 2}, the number of unordered pairs among {@code v} items. */
  private static double pairs(double v) {
    return v * (v - 1) / 2.0;
  }

  /**
   * Computes {@code (mean - min) / (max - min)} over all pairwise distances between points that belong to the same
   * cluster. A point belongs to a cluster when its inclusion probability for that cluster is greater than
   * {@code 0.8}; a point may therefore be counted in several clusters or in none.
   *
   * @param clustering
   *          the clustering whose clusters are inspected
   * @param points
   *          the data points to assign to clusters
   * @return the computed value, or {@code 0} when no within-cluster pair exists or when the largest and smallest
   *         within-cluster distances coincide
   */
  public double cindex(Clustering clustering, ArrayList<DataPoint> points) {
    int numClusters = clustering.size();
    double withinClustersDistance = 0;
    // long: the pair count grows quadratically with the cluster size.
    long numDistancesWithin = 0;

    double[] minWithinClusters = new double[numClusters];
    double[] maxWithinClusters = new double[numClusters];
    ArrayList<ArrayList<Integer>> pointsInClusters = new ArrayList<>(numClusters);
    for (int c = 0; c < numClusters; c++) {
      pointsInClusters.add(new ArrayList<Integer>());
      minWithinClusters[c] = Double.MAX_VALUE;
      // Most negative finite double; Double.MIN_VALUE is the smallest POSITIVE double and would hide a
      // maximum of exactly 0.
      maxWithinClusters[c] = -Double.MAX_VALUE;
    }

    // Record, per cluster, the indices of the points it contains.
    for (int p = 0; p < points.size(); p++) {
      for (int c = 0; c < numClusters; c++) {
        if (clustering.get(c).getInclusionProbability(points.get(p)) > INCLUSION_THRESHOLD) {
          pointsInClusters.get(c).add(p);
        }
      }
    }

    // Accumulate within-cluster distances and track the per-cluster extremes.
    for (int c = 0; c < numClusters; c++) {
      ArrayList<Integer> pointsInC = pointsInClusters.get(c);
      for (int p = 0; p < pointsInC.size(); p++) {
        DataPoint point = points.get(pointsInC.get(p));
        for (int p1 = p + 1; p1 < pointsInC.size(); p1++) {
          DataPoint point1 = points.get(pointsInC.get(p1));
          double dist = point.getDistance(point1);
          numDistancesWithin++;
          withinClustersDistance += dist;
          if (minWithinClusters[c] > dist) {
            minWithinClusters[c] = dist;
          }
          if (maxWithinClusters[c] < dist) {
            maxWithinClusters[c] = dist;
          }
        }
      }
    }

    // Reduce the per-cluster extremes to global ones.
    double minWithin = Double.MAX_VALUE;
    double maxWithin = -Double.MAX_VALUE;
    for (int c = 0; c < numClusters; c++) {
      if (minWithinClusters[c] < minWithin) {
        minWithin = minWithinClusters[c];
      }
      if (maxWithinClusters[c] > maxWithin) {
        maxWithin = maxWithinClusters[c];
      }
    }

    double cindex = 0;
    if (numDistancesWithin != 0) {
      double meanWithinClustersDistance = withinClustersDistance / numDistancesWithin;
      double range = maxWithin - minWithin;
      // When all distances are equal the range is 0; keep the default of 0 instead of producing 0/0 = NaN.
      if (range > 0) {
        cindex = (meanWithinClustersDistance - minWithin) / range;
      }
    }

    if (debug) {
      System.out.println("Min:" + Arrays.toString(minWithinClusters));
      System.out.println("Max:" + Arrays.toString(maxWithinClusters));
      System.out.println("totalWithin:" + numDistancesWithin);
    }
    return cindex;
  }
}