/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * SemiSupClustererEvaluation.java * Copyright (C) 2002 Sugato Basu, Misha Bilenko * */ package weka.clusterers; import java.util.*; import java.io.*; import weka.core.*; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; /** * Class for evaluating clustering models - extends ClusterEvaluation.java<p> * Implements different clustering evaluation metrics * * @author Sugato Basu, Misha Bilenko */ public class SemiSupClustererEvaluation extends ClusterEvaluation { /** Purity of the clustering */ protected double m_Purity; /** Entropy of the clustering */ protected double m_Entropy; /** Objective function of the clustering */ protected double m_Objective; /** MI Metric the clustering */ protected double m_MIMetric; /** KL Divergence of the clustering */ protected double m_KLDivergence; /** The number of underlying classes */ protected int m_NumClasses; /** The number of produced clusters */ protected int m_NumClusters; /** All labeled training instances */ protected Instances m_LabeledTrain; /** All unlabaled training instances */ protected Instances m_UnlabeledTrain; /** All test instances */ protected Instances m_Test; /** Training pairs */ protected ArrayList m_labeledTrainPairs; /** The weight of all incorrectly categorized test instances. */ protected double m_WeightTestIncorrect; /** The weight of all correctly categorized test instances. */ protected double m_WeightTestCorrect; /** The weight of all uncategorized test instances. */ protected double m_WeightTestUnclassified; /** The weight of test instances that had a class assigned to them. */ protected double m_WeightTestWithClass; /** Array for storing the confusion matrix. */ protected double [][] m_ConfusionMatrix; /** The names of the classes. */ protected String [] m_ClassNames; /** Is the class nominal or numeric? */ protected boolean m_ClassIsNominal; /** If the class is not nominal, we do not need the confusion matrix but do pairs counts directly */ protected int m_totalPairs; protected int m_goodPairs; protected int m_trueGoodPairs; /** The total cost of predictions (includes instance weights) */ protected double m_TotalCost; public String toSummaryString() { return super.toString(); } /** * Returns a string describing this evaluator * @return a description of the evaluator suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return " A clusterer evaluator that evaluates results of running a " + "semi-supervised clustering algorithm."; } public SemiSupClustererEvaluation (Instances test, int numClasses, int numClusters) { m_NumClasses = numClasses; m_NumClusters = numClusters; m_ClassIsNominal = test.classAttribute().isNominal(); if (m_ClassIsNominal) { m_ConfusionMatrix = new double [m_NumClusters][m_NumClasses]; m_ClassNames = new String [m_NumClasses]; for(int i = 0; i < m_NumClasses; i++) { m_ClassNames[i] = test.classAttribute().value(i); } } } public SemiSupClustererEvaluation (ArrayList labeledTrainPairs, Instances test, int numClasses, int numClusters) { this (test,numClasses,numClusters); m_labeledTrainPairs = labeledTrainPairs; } /** * Evaluates the semi-sup clusterer on a given set of test instances * * @param clusterer semi-supervised clusterer * @param testInstances set of test instances for evaluation * @exception Exception if model could not be evaluated successfully */ public void evaluateModel (Clusterer clusterer, Instances testInstances, Instances unlabeledTest) throws Exception { if (m_ClassIsNominal) { m_Test = testInstances; m_Objective = ((SemiSupClusterer) clusterer).objectiveFunction(); // Assuming transductive clustering here ... will need to generalize in future System.out.println("Evaluating cluster results ..."); for (int i = 0; i < unlabeledTest.numInstances(); i++) { evaluateModelOnce(clusterer, unlabeledTest.instance(i), (int) (testInstances.instance(i)).classValue()); } } else { // string-based class attributes int numInstances = testInstances.numInstances(); Attribute classAttr = testInstances.classAttribute(); int [][] sharedClass = new int[numInstances][numInstances]; HashSet dontCareSet = new HashSet(); final int HAVE_SHARED_CLASS = 0; final int NO_SHARED_CLASS = 1; final int DONT_CARE = 2; m_totalPairs = 0; m_goodPairs = 0; // calculate the number of true pairs m_trueGoodPairs = 0; HashSet [] classSets = new HashSet[numInstances]; for (int i = 0; i < numInstances; i++) { System.out.println("Classattr: " + classAttr); String classList = testInstances.instance(i).stringValue(classAttr); if (classList.length() != 0) { // skip unassigned instances // parse the list of classes into a hashset HashSet classSet = new HashSet(); StringTokenizer tokenizer = new StringTokenizer(classList, "_"); while (tokenizer.hasMoreTokens()) { classSet.add(tokenizer.nextToken()); } classSets[i] = classSet; for (int j = 0; j < i; j++) { if (classSets[j] != null) { // skip unassigned instances HashSet prevSet = (HashSet) classSets[j]; Iterator iterator = prevSet.iterator(); boolean shareClass = false; // go through previously assigned instance's classes and see if current class list contains any while (iterator.hasNext() && !shareClass) { String classString = (String) iterator.next(); if (classSet.contains(classString)) { shareClass = true; } } if (shareClass) { m_trueGoodPairs++; sharedClass[i][j] = sharedClass[j][i] = HAVE_SHARED_CLASS; } else { sharedClass[i][j] = sharedClass[j][i] = NO_SHARED_CLASS; } } } } else { // all pairs with this instance are don't care dontCareSet.add(new Integer(i)); for (int j = 0; j < numInstances; j++) { sharedClass[i][j] = sharedClass[j][i] = DONT_CARE; } } } // now cluster and evaluate precision ArrayList[] classLists = new ArrayList[m_NumClasses]; for (int i = 0; i < classLists.length; i++) { classLists[i] = new ArrayList(); } for (int i = 0; i < unlabeledTest.numInstances(); i++) { if (!dontCareSet.contains(new Integer(i))) { int clusterIdx = clusterer.clusterInstance(unlabeledTest.instance(i)); // go through all instances previously assigned to the same cluster and check whether they have common classes for (int j = 0; j < classLists[clusterIdx].size(); j++) { int sameClusterInstanceIdx = ((Integer) classLists[clusterIdx].get(j)).intValue(); if (sharedClass[j][sameClusterInstanceIdx] == HAVE_SHARED_CLASS) { m_goodPairs++; } m_totalPairs++; } classLists[clusterIdx].add(new Integer(i)); } } } } /** * Evaluates the semi-sup clusterer on a given test instance * * @param clusterer semi-supervised clusterer * @param test test instance for evaluation * @exception Exception if model could not be evaluated successfully */ public void evaluateModelOnce (Clusterer clusterer, Instance testWithoutLabel, int classValue) throws Exception { double [] pred; if (m_ClassIsNominal) { if (clusterer instanceof DistributionClusterer) { pred = ((DistributionClusterer) clusterer).distributionForInstance(testWithoutLabel); } else { pred = makeDistribution(clusterer.clusterInstance(testWithoutLabel)); } updateStatsForClusterer(pred, classValue); } } /** * Convert a single prediction into a probability distribution * with all zero probabilities except the predicted value which * has probability 1.0; * * @param predictedClass the index of the predicted class * @return the probability distribution */ protected double [] makeDistribution(int predictedCluster) { double [] result = new double [m_NumClasses]; if (m_ClassIsNominal) { result[predictedCluster] = 1.0; } else { result[0] = predictedCluster; } return result; } /** * Updates all the statistics about a clusterer performance for * the current test instance. * * @param distrib the probabilities assigned to each class * @param test the test instance * @exception Exception if the class of the instance is not set */ protected void updateStatsForClusterer(double [] distrib, int classValue) { for (int i=0; i<distrib.length; i++) { // System.out.println("Adding value to distrib: " + i + " with classValue: " + classValue); m_ConfusionMatrix[i][classValue] += distrib[i]; } } public final double objectiveFunction() { return m_Objective; } public final double purity() { return m_Purity; } public final double entropy() { return m_Entropy; } public final double klDivergence() { return m_KLDivergence; } public final double mutualInformation() { if (m_ClassIsNominal) { double [] clusterTotals = new double[m_NumClusters]; double [] classTotals = new double[m_NumClasses]; for (int i=0; i<m_NumClusters; i++) { for (int j=0; j<m_NumClasses; j++) { clusterTotals[i] += m_ConfusionMatrix[i][j]; classTotals[j] += m_ConfusionMatrix[i][j]; } } try { System.out.println(toMatrixString("\nConfusion matrix:")); } catch(Exception e) { e.printStackTrace(); } // calculate MI from counts m_MIMetric = 0.0; int numInstances = m_Test.numInstances(); double MI = 0; for (int i=0; i<m_NumClusters; i++) { for (int j=0; j<m_NumClasses; j++) { if(m_ConfusionMatrix[i][j] !=0 && clusterTotals[i] != 0 && classTotals[i] != 0) { if (clusterTotals[i] != 0 && classTotals[j] != 0) { MI += (1.0 * m_ConfusionMatrix[i][j]/numInstances) * Math.log((1.0 * m_ConfusionMatrix[i][j] * numInstances) / (clusterTotals[i] * classTotals[j])); } } } } double classEntropy = 0, clusterEntropy = 0; for (int i=0; i<m_NumClusters; i++) { if (clusterTotals[i] != 0) { clusterEntropy -= (1.0 * clusterTotals[i])/numInstances * Math.log(1.0 * clusterTotals[i]/numInstances); } } for (int j=0; j<m_NumClasses; j++) { if (classTotals[j] != 0) { classEntropy -= (1.0 * classTotals[j])/numInstances * Math.log(1.0 * classTotals[j]/numInstances); } } m_MIMetric = 2*MI / (classEntropy + clusterEntropy); System.out.println("Final MI is: " + m_MIMetric + "\t" + classEntropy + "\t" + clusterEntropy); } return m_MIMetric; } /** * Outputs the performance statistics as a classification confusion * matrix. For each class value, shows the distribution of * predicted class values. * * @param title the title for the confusion matrix * @return the confusion matrix as a String * @exception Exception if the class is numeric */ public String toMatrixString(String title) throws Exception { StringBuffer text = new StringBuffer(); char [] IDChars = {'a','b','c','d','e','f','g','h','i','j', 'k','l','m','n','o','p','q','r','s','t', 'u','v','w','x','y','z'}; int IDWidth; boolean fractional = false; // Find the maximum value in the matrix // and check for fractional display requirement double maxval = 0; for(int i = 0; i < m_NumClusters; i++) { for(int j = 0; j < m_NumClasses; j++) { double current = m_ConfusionMatrix[i][j]; if (current < 0) { current *= -10; } if (current > maxval) { maxval = current; } double fract = current - Math.rint(current); if (!fractional && ((Math.log(fract) / Math.log(10)) >= -2)) { fractional = true; } } } IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10) + (fractional ? 3 : 0)), (int)(Math.log(m_NumClasses) / Math.log(IDChars.length))); text.append(title).append("\n"); for(int i = 0; i < m_NumClasses; i++) { if (fractional) { text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3)) .append(" "); } else { text.append(" ").append(num2ShortID(i,IDChars,IDWidth)); } } text.append(" <-- classes; rows=clusters\n"); for(int i = 0; i< m_NumClusters; i++) { for(int j = 0; j < m_NumClasses; j++) { text.append(" ").append( Utils.doubleToString(m_ConfusionMatrix[i][j], IDWidth, (fractional ? 2 : 0))); } text.append(" | ").append(num2ShortID(i,IDChars,IDWidth)) .append(" = ").append(m_ClassNames[i]).append("\n"); } return text.toString(); } /** * Method for generating indices for the confusion matrix. * * @param num integer to format * @return the formatted integer as a string */ private String num2ShortID(int num,char [] IDChars,int IDWidth) { char ID [] = new char [IDWidth]; int i; for(i = IDWidth - 1; i >=0; i--) { ID[i] = IDChars[num % IDChars.length]; num = num / IDChars.length - 1; if (num < 0) { break; } } for(i--; i >= 0; i--) { ID[i] = ' '; } return new String(ID); } public final double pairwisePrecision() { if (m_ClassIsNominal) { int [] clusterTotals = new int[m_NumClusters]; int [] goodPairTotals = new int[m_NumClusters]; m_totalPairs = 0; m_goodPairs = 0; for (int i = 0; i < m_NumClusters; i++) { for (int j = 0; j < m_NumClasses; j++) { goodPairTotals[i] += m_ConfusionMatrix[i][j] * (m_ConfusionMatrix[i][j] - 1) / 2; clusterTotals[i] += m_ConfusionMatrix[i][j]; } } for (int i = 0; i < m_NumClusters; i++) { m_totalPairs += clusterTotals[i] * (clusterTotals[i] - 1) / 2; m_goodPairs += goodPairTotals[i]; } } return (m_goodPairs+0.0)/m_totalPairs; } public final double pairwiseRecall() { if (m_ClassIsNominal) { int [] classTotals = new int[m_NumClasses]; int [] goodPairTotals = new int[m_NumClasses]; m_trueGoodPairs = 0; m_goodPairs = 0; for (int i = 0; i < m_NumClasses; i++) { for (int j = 0; j < m_NumClusters; j++) { goodPairTotals[i] += m_ConfusionMatrix[j][i] * (m_ConfusionMatrix[j][i] - 1) / 2; classTotals[i] += m_ConfusionMatrix[j][i]; } } for (int i = 0; i < m_NumClasses; i++) { m_trueGoodPairs += classTotals[i] * (classTotals[i] - 1) / 2; m_goodPairs += goodPairTotals[i]; } } return (m_goodPairs+0.0)/m_trueGoodPairs; } public final double pairwiseFMeasure() { double fmeasure = 0; if (m_ClassIsNominal) { int [] clusterTotals = new int[m_NumClusters]; int [] classTotals = new int[m_NumClasses]; int [] goodPairTotals = new int[m_NumClusters]; int totalClassPairs = 0; int totalClusterPairs = 0; int goodPairs = 0; for (int i = 0; i < m_NumClusters; i++) { for (int j = 0; j < m_NumClasses; j++) { goodPairTotals[i] += m_ConfusionMatrix[i][j] * (m_ConfusionMatrix[i][j] - 1) / 2; clusterTotals[i] += m_ConfusionMatrix[i][j]; classTotals[j] += m_ConfusionMatrix[i][j]; } } for (int i = 0; i < m_NumClusters; i++) { totalClusterPairs += clusterTotals[i] * (clusterTotals[i] - 1) / 2; goodPairs += goodPairTotals[i]; } for (int i = 0; i < m_NumClasses; i++) { totalClassPairs += classTotals[i] * (classTotals[i] - 1) / 2; } double precision = (goodPairs+0.0)/totalClusterPairs; double recall = (goodPairs+0.0)/totalClassPairs; if (precision > 0) { // avoid divide by zero in the p=0&r=0 case fmeasure = 2 * (precision * recall) / (precision + recall); } System.out.println("Final F-Measure is: " + fmeasure + "; Precision=" + precision + " Recall=" + recall + "\n"); } else { // the class is not nominal fmeasure = 2.0 * m_goodPairs / (m_totalPairs + m_trueGoodPairs); } return fmeasure; } public final double numSameClassPairs() { int numSameClassPairs = 0; for (int i = 0; i < m_labeledTrainPairs.size(); i++) { InstancePair pair = (InstancePair) m_labeledTrainPairs.get(i); if (pair.linkType == InstancePair.MUST_LINK) { numSameClassPairs++; } } return numSameClassPairs; } public final double numDiffClassPairs() { int numDiffClassPairs = 0; for (int i = 0; i < m_labeledTrainPairs.size(); i++) { InstancePair pair = (InstancePair) m_labeledTrainPairs.get(i); if (pair.linkType == InstancePair.CANNOT_LINK) { numDiffClassPairs++; } } return numDiffClassPairs; } }