/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SeededKMeans.java
 * Copyright (C) 2002 Sugato Basu
 */

package weka.clusterers;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Seeded K-Means clustering class, supporting seeded and constrained
 * semi-supervised clustering.
 *
 * Valid options are:<p>
 *
 * -N <number of clusters> <br>
 * Specify the number of clusters to generate. <p>
 *
 * -R <random seed> <br>
 * Specify the random number seed. <p>
 *
 * -S <seeding method> <br>
 * The seeding method can be "seeded" (seeded KMeans) or "constrained"
 * (constrained KMeans). <p>
 *
 * -A <algorithm> <br>
 * The algorithm can be "simple" (simple KMeans) or "spherical"
 * (spherical KMeans). <p>
 *
 * -M <metric-class> <br>
 * Specifies the name of the distance metric class that should be used. <p>
 *
 * @author Sugato Basu (sugato@cs.utexas.edu)
 * @see Clusterer
 * @see OptionHandler
 */
public class SeededKMeans extends Clusterer
  implements OptionHandler, SemiSupClusterer, ActiveLearningClusterer {

  /** Name of clusterer */
  String m_name = "SeededKMeans";

  /** holds the clusters */
  protected ArrayList m_FinalClusters = null;

  /** holds the instance indices in the clusters */
  protected ArrayList m_IndexClusters = null;

  /** holds the ([seed instance] -> [clusterLabel of seed instance]) mapping */
  protected HashMap m_SeedHash = null;

  /** distance metric */
  protected Metric m_metric = new WeightedDotP();

  /** has the metric been constructed? A fix for multiple calls to buildClusterer */
  protected boolean m_metricBuilt = false;

  /** starting index of test data in unlabeledData if transductive clustering */
  protected int m_StartingIndexOfTest = -1;

  /** indicates whether instances are sparse */
  protected boolean isSparseInstance = false;

  /** Is the objective function increasing or decreasing?
Depends on type * of metric used: for similarity-based metric, increasing, for distance-based - decreasing */ protected boolean m_objFunDecreasing = false; /** Name of metric */ protected String m_metricName = new String("WeightedDotP"); /** Points that are to be skipped in the clustering process * because they are collapsed to zero */ protected HashSet m_skipHash = new HashSet(); /** Index of the current element in the E-step */ protected int m_currIdx = 0; /** keep track of the number of iterations completed before convergence */ protected int m_Iterations = 0; /* Define possible seeding methods */ public static final int SEEDING_CONSTRAINED = 1; public static final int SEEDING_SEEDED = 2; public static final Tag [] TAGS_SEEDING = { new Tag(SEEDING_CONSTRAINED, "Constrained seeding"), new Tag(SEEDING_SEEDED, "Initial seeding only") }; /** seeding method, by default seeded */ protected int m_SeedingMethod = SEEDING_SEEDED; /** Define possible algorithms */ public static final int ALGORITHM_SIMPLE = 1; public static final int ALGORITHM_SPHERICAL = 2; public static final Tag[] TAGS_ALGORITHM = { new Tag(ALGORITHM_SIMPLE, "Simple K-Means"), new Tag(ALGORITHM_SPHERICAL, "Spherical K-Means") }; /** algorithm, by default spherical */ protected int m_Algorithm = ALGORITHM_SPHERICAL; /** min difference of objective function values for convergence*/ protected double m_ObjFunConvergenceDifference = 1e-5; /** value of objective function */ protected double m_Objective = Integer.MAX_VALUE; /** returns objective function */ public double objectiveFunction() { return m_Objective; } /** Verbose? */ protected boolean m_Verbose = false; /** * training instances with labels */ protected Instances m_TotalTrainWithLabels; /** * training instances */ protected Instances m_Instances; /** * number of clusters to generate, default is 3 */ protected int m_NumClusters = 3; /** * m_FastMode = true => fast computation of meanOrMode in centroid calculation, useful for high-D data sets * m_FastMode = false => usual computation of meanOrMode in centroid calculation */ protected boolean m_FastMode = true; /** * holds the cluster centroids */ protected Instances m_ClusterCentroids; /** * holds the global centroids */ protected Instance m_GlobalCentroid; /** * holds the default perturbation value for randomPerturbInit */ protected double m_DefaultPerturb = 0.7; /** weight of the concentration */ protected double m_Concentration = 10.0; /** number of extra phase1 runs */ protected double m_ExtraPhase1RunFraction = 50; /** * temporary variable holding cluster assignments while iterating */ protected int [] m_ClusterAssignments; /** * holds the random Seed, useful for randomPerturbInit */ protected int m_randomSeed = 1; /** semisupervision */ protected boolean m_Seedable = true; /* Constructor */ public SeededKMeans() { } /* Constructor */ public SeededKMeans(Metric metric) { m_metric = metric; m_metricName = m_metric.getClass().getName(); m_objFunDecreasing = metric.isDistanceBased(); } /** * We always want to implement SemiSupClusterer from a class extending Clusterer. * We want to be able to return the underlying parent class. * @return parent Clusterer class */ public Clusterer getThisClusterer() { return this; } /** * Cluster given instances to form the specified number of clusters. * * @param data instances to be clustered * @param num_clusters number of clusters to create * @exception Exception if something goes wrong. 
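 *
 * <p>A minimal usage sketch (assumes a numeric dataset <code>data</code>
 * with no class attribute; the metric choice here is the class default and
 * purely illustrative):
 * <pre>
 *   SeededKMeans km = new SeededKMeans(new WeightedDotP());
 *   km.buildClusterer(data, 3);                    // form 3 clusters
 *   int c = km.clusterInstance(data.instance(0));  // assign a point
 * </pre>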
 */
public void buildClusterer(Instances data, int num_clusters) throws Exception {
  setNumClusters(num_clusters);
  if (m_Algorithm == ALGORITHM_SPHERICAL && m_metric instanceof WeightedDotP) {
    // since instances and clusters are already normalized, we don't need to
    // normalize again while computing similarity - saves time
    ((WeightedDotP)m_metric).setLengthNormalized(false);
  }
  if (data.instance(0) instanceof SparseInstance) {
    isSparseInstance = true;
  }
  buildClusterer(data);
}

/**
 * Clusters unlabeledData and labeledData (with labels removed),
 * using labeledData as seeds
 *
 * @param labeledData labeled instances to be used as seeds
 * @param unlabeledData unlabeled instances
 * @param classIndex attribute index in labeledData which holds class info
 * @param numClusters number of clusters
 * @param startingIndexOfTest from where test data starts in unlabeledData,
 * useful if clustering is transductive
 * @exception Exception if something goes wrong.
 */
public void buildClusterer(Instances labeledData, Instances unlabeledData,
                           int classIndex, int numClusters,
                           int startingIndexOfTest) throws Exception {
  m_StartingIndexOfTest = startingIndexOfTest;
  buildClusterer(labeledData, unlabeledData, classIndex, numClusters);
}

/**
 * Clusters unlabeledData and labeledData (with labels removed),
 * using labeledData as seeds
 *
 * @param labeledData labeled instances to be used as seeds
 * @param unlabeledData unlabeled instances
 * @param classIndex attribute index in labeledData which holds class info
 * @param totalTrainWithLabels training instances with class labels, used to
 * set the number of clusters (= number of classes)
 * @param startingIndexOfTest from where test data starts in unlabeledData,
 * useful if clustering is transductive
 * @exception Exception if something goes wrong.
 */
public void buildClusterer(Instances labeledData, Instances unlabeledData,
                           int classIndex, Instances totalTrainWithLabels,
                           int startingIndexOfTest) throws Exception {
  m_StartingIndexOfTest = startingIndexOfTest;
  m_TotalTrainWithLabels = totalTrainWithLabels;
  buildClusterer(labeledData, unlabeledData, classIndex,
                 totalTrainWithLabels.numClasses());
}

/**
 * Clusters unlabeledData and labeledData (with labels removed),
 * using labeledData as seeds
 *
 * @param labeledData labeled instances to be used as seeds
 * @param unlabeledData unlabeled instances
 * @param classIndex attribute index in labeledData which holds class info
 * @param numClusters number of clusters
 * @exception Exception if something goes wrong.
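 *
 * <p>A hedged usage sketch for this seeded entry point (the instance sets
 * and cluster count are illustrative; as in main() below, the seeds keep
 * their class attribute while the unlabeled set has it removed):
 * <pre>
 *   SeededKMeans km = new SeededKMeans(new WeightedDotP());
 *   km.setSeedable(true);
 *   km.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING));
 *   km.buildClusterer(seeds, unlabeled, seeds.classIndex(), 3);
 * </pre>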
*/ public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters) throws Exception { if (m_Algorithm == ALGORITHM_SPHERICAL) { if (labeledData != null) { for (int i=0; i<labeledData.numInstances(); i++) { normalize(labeledData.instance(i)); } } for (int i=0; i<unlabeledData.numInstances(); i++) { normalize(unlabeledData.instance(i)); } } Instances clusterData = new Instances(unlabeledData, 0);; if (getSeedable()) { // remove labels of labeledData before putting in seedHash clusterData = new Instances(labeledData); System.out.println("Numattributes: " + clusterData.numAttributes()); clusterData.deleteClassAttribute(); // create seedHash from labeledData Seeder seeder = new Seeder(clusterData, labeledData); setSeedHash(seeder.getAllSeeds()); } // add unlabeled data to labeled data (labels removed), not the // other way around, so that the labels in the hash table entries // and m_TotalTrainWithLabels are consistent for (int i=0; i<unlabeledData.numInstances(); i++) { clusterData.add(unlabeledData.instance(i)); } System.out.println("combinedData has size: " + clusterData.numInstances() + "\n"); // learn metric using labeled data, then cluster both the labeled and unlabeled data if (labeledData != null) { m_metric.buildMetric(labeledData); } else { m_metric.buildMetric(unlabeledData.numAttributes()); } m_metricBuilt = true; buildClusterer(clusterData, numClusters); } /** * Reset all values that have been learned */ public void resetClusterer() throws Exception{ if (m_metric instanceof LearnableMetric) ((LearnableMetric)m_metric).resetMetric(); m_SeedHash = null; m_ClusterCentroids = null; } /** * We can have clusterers that don't utilize seeding */ public boolean seedable() { return m_Seedable; } /** Initializes the cluster centroids - initial M step */ protected void initializeClusterer() { Random random = new Random(m_randomSeed); boolean globalCentroidComputed = false; if (m_Verbose) { // System.out.println("SeedHash is: " + m_SeedHash); } System.out.println("Initializing "); // makes initial cluster assignments for (int i = 0; i < m_Instances.numInstances(); i++) { Instance inst = m_Instances.instance(i); if (m_SeedHash != null && m_SeedHash.containsKey(inst)) { m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue(); if (m_ClusterAssignments[i] < 0) { m_ClusterAssignments[i] = -1; // For randomPerturbInit if (m_Verbose) { System.out.println("Invalid cluster specification for seed instance " + i + ": " + inst + ", making random initial assignment"); } } else { if (m_Verbose) { System.out.println("Seed instance " + i + ": " + inst + " assigned to cluster: " + m_ClusterAssignments[i]); } } } else { m_ClusterAssignments[i] = -1; // For randomPerturbInit } } Instances [] tempI = new Instances[m_NumClusters]; m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); boolean [] clusterSeeded = new boolean[m_NumClusters]; for (int i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i clusterSeeded[i] = false; // initialize all clusters to be unseeded } for (int i = 0; i < m_Instances.numInstances(); i++) { if (m_ClusterAssignments[i] >= 0) { // seeded cluster clusterSeeded[m_ClusterAssignments[i]] = true; tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i)); } } // Calculates initial cluster centroids for (int i = 0; i < m_NumClusters; i++) { double [] values = new double[m_Instances.numAttributes()]; if (clusterSeeded[i] == true) { if (m_FastMode && 
isSparseInstance) { values = meanOrMode(tempI[i]); // uses fast meanOrMode } else { for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode } } } else { // finds global centroid if has not been already computed if (!globalCentroidComputed) { double [] globalValues = new double[m_Instances.numAttributes()]; if (m_FastMode && isSparseInstance) { globalValues = meanOrMode(m_Instances); // uses fast meanOrMode } else { for (int j = 0; j < m_Instances.numAttributes(); j++) { globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode } } // global centroid is dense in SPKMeans m_GlobalCentroid = new Instance(1.0, globalValues); m_GlobalCentroid.setDataset(m_Instances); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric)m_metric).normalizeInstanceWeighted(m_GlobalCentroid); } catch (Exception e) { e.printStackTrace(); } } globalCentroidComputed = true; if (m_Verbose) { System.out.println("Global centroid is: " + m_GlobalCentroid); } } // randomPerturbInit if (m_Verbose) { System.out.println("RandomPerturbInit seeding for centroid " + i); } for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = m_GlobalCentroid.value(j) * (1 + m_DefaultPerturb * (random.nextFloat() - 0.5)); } } // cluster centroids are dense in SPKMeans m_ClusterCentroids.add(new Instance(1.0, values)); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } catch (Exception e) { e.printStackTrace(); } } } } /** E-step of the KMeans clustering algorithm -- find best cluster assignments */ protected void findBestAssignments() throws Exception{ m_Objective = 0; int moved=0; for (int i = 0; i < m_Instances.numInstances(); i++) { m_currIdx = i; Instance inst = m_Instances.instance(i); boolean assigned = false; // Constrained KMeans algorithm if(m_SeedingMethod == SEEDING_CONSTRAINED) { if (m_SeedHash == null) { System.err.println("Needs seed information for constrained SeededKMeans"); } else if(m_SeedHash.containsKey(inst)) { // Seeded instances m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue(); assigned = true; if (m_Verbose) { System.out.println("Assigning cluster " + m_ClusterAssignments[i] + " for seed instance " + i + ": " + inst); } } } try { if (!assigned) { // Unseeded instances int newAssignment = assignClusterToInstance(inst); if (newAssignment != m_ClusterAssignments[i]) { moved++; if (m_Verbose) { System.out.println("Reassigning instance " + i + " old cluster=" + m_ClusterAssignments[i] + " new cluster=" + newAssignment); } } m_ClusterAssignments[i] = newAssignment; } // Update objective function if (!m_objFunDecreasing) { // objective function increases monotonically double newSimilarity = m_metric.similarity(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += newSimilarity; } else { // objective function decreases monotonically double newDistance = m_metric.distance(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += newDistance * newDistance; } } catch (Exception e) { System.out.println("Could not find distance. 
Exception: " + e); e.printStackTrace(); } } if(m_Verbose) { System.out.println("\nAfter iteration " + m_Iterations + ":\n"); /* for (int k=0; k<m_ClusterCentroids.numInstances(); k++) { System.out.println (" Centroid " + k + " is " + m_ClusterCentroids.instance(k)); } */ } System.out.println("Number of points moved in this E-step: " + moved); } /** M-step of the KMeans clustering algorithm -- updates cluster centroids */ protected void updateClusterCentroids() { // M-step: update cluster centroids Instances [] tempI = new Instances[m_NumClusters]; m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); for (int i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i } for (int i = 0; i < m_Instances.numInstances(); i++) { tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i)); if (m_Verbose) { System.out.println("Instance " + i + " added to cluster " + m_ClusterAssignments[i]); } } // Calculates cluster centroids for (int i = 0; i < m_NumClusters; i++) { double [] values = new double[m_Instances.numAttributes()]; if (m_FastMode && isSparseInstance) { values = meanOrMode(tempI[i]); // uses fast meanOrMode } else { for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode } } // cluster centroids are dense in SPKMeans m_ClusterCentroids.add(new Instance(1.0, values)); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } catch (Exception e) { e.printStackTrace(); } } } } /** calculates objective function */ protected void calculateObjectiveFunction() throws Exception { m_Objective = 0; for (int i=0; i<m_Instances.numInstances(); i++) { if (m_objFunDecreasing) { double dist = m_metric.distance(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += dist*dist; } else { //m_Objective += similarity(i, m_ClusterAssignments[i]); m_Objective += m_metric.similarity(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); } } } /** * Generates a clusterer. Instances in data have to be * either all sparse or all non-sparse * * @param data set of instances serving as training data * @exception Exception if the clusterer has not been * generated successfully */ public void buildClusterer(Instances data) throws Exception { setInstances(data); // Don't rebuild the metric if it was already trained if (!m_metricBuilt) { m_metric.buildMetric(data); } m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); m_ClusterAssignments = new int [m_Instances.numInstances()]; if (m_Verbose && m_SeedHash != null) { System.out.println("Using seeding ..."); } if (m_Instances.checkForNominalAttributes() && m_Instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n"); } initializeClusterer(); // Initializes cluster centroids (initial M-step) System.out.println("Done initializing clustering ..."); getIndexClusters(); printIndexClusters(); if (m_Verbose) { for (int i=0; i<m_NumClusters; i++) { System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i)); } } boolean converged = false; m_Iterations = 0; double oldObjective = m_objFunDecreasing ? 
Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; while (!converged) { // E-step: updates m_Objective System.out.println("Doing E-step ..."); findBestAssignments(); // M-step System.out.println("Doing M-step ..."); updateClusterCentroids(); m_Iterations++; calculateObjectiveFunction(); // Convergence check if(Math.abs(oldObjective - m_Objective) > m_ObjFunConvergenceDifference) { if (m_objFunDecreasing ? (oldObjective < m_Objective) : (oldObjective > m_Objective)) { converged = true; System.out.println("\nOSCILLATING, oldObjective=" + oldObjective + " newObjective=" + m_Objective); System.out.println("Seeding=" + m_Seedable + " SeedingMethod=" + m_SeedingMethod ); } else { converged = false; System.out.println("Objective function is: " + m_Objective); } } else { converged = true; System.out.println("Old Objective function was: " + oldObjective); System.out.println("Final Objective function is: " + m_Objective); } oldObjective = m_Objective; } } public InstancePair[] bestPairsForActiveLearning(int numActive) throws Exception { throw new Exception("Not implemented for SeededKMeans"); } /** Returns the indices of the best numActive instances for active learning */ public int[] bestInstancesForActiveLearning(int numActive) throws Exception{ int numInstances = m_Instances.numInstances(); int [] clusterSizes = new int[m_NumClusters]; int [] activeLearningPoints = new int[numActive]; int [] clusterAssignments = new int[numInstances]; Instance [] sumOfClusterInstances = new Instance[m_NumClusters]; HashSet visitedPoints = new HashSet(numInstances); boolean allClustersFound = false; int numPointsSelected = 0; // initialize clusterAssignments, clusterSizes, visitedPoints, sumOfClusterInstances for (int i=0; i<numInstances; i++) { Instance inst = m_Instances.instance(i); if (m_SeedHash != null && m_SeedHash.containsKey(inst)) { clusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue(); clusterSizes[clusterAssignments[i]]++; visitedPoints.add(new Integer(i)); sumOfClusterInstances[clusterAssignments[i]] = sumWithInstance(sumOfClusterInstances[clusterAssignments[i]], inst); if (m_Verbose) { // System.out.println("Init: adding point " + i + " to cluster " + clusterAssignments[i]); } } else { clusterAssignments[i] = -1; } } // set allClustersFound allClustersFound = setAllClustersFound(clusterSizes); int totalPointsSpecified=0; for (int i=0; i<m_NumClusters; i++) { totalPointsSpecified += clusterSizes[i]; // HACK!!! 
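      // Worked example of the budget check that follows this loop (numbers
      // illustrative): with m_NumClusters = 3 and clusterSizes = {20, 15, 10},
      // the summation above gives totalPointsSpecified = 45; since
      // 45 < m_ExtraPhase1RunFraction (default 50), allClustersFound is reset
      // to false and farthest-first phase 1 continues even though every
      // cluster already has a seed.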
} System.out.println("Total points specified: " + totalPointsSpecified + ", limit: " + m_ExtraPhase1RunFraction); if (totalPointsSpecified < m_ExtraPhase1RunFraction) { allClustersFound = false; } while (numPointsSelected < numActive) { if (!allClustersFound) { // PHASE 1 System.out.println("In Phase 1"); // find next point, farthest from visited points int nextPoint = farthestFromSet(visitedPoints, null); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest); } visitedPoints.add(new Integer(nextPoint)); activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint)); // set allClustersFound // if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } allClustersFound = setAllClustersFound(clusterSizes); if (numPointsSelected >= numActive) { System.out.println("Out of queries before phase 1 extra loop. Queries so far: " + numPointsSelected); return activeLearningPoints; // go out of function } if (allClustersFound) { // Extra RUNS OF PHASE 1 int [] tempClusterSizes = new int[m_NumClusters]; // temp cluster sizes boolean tempAllClustersFound = false; HashSet points = new HashSet(numInstances); // points visited in this farthest first loop points.add(new Integer(nextPoint)); // mark only last point as visited tempClusterSizes[classLabel]++; // update temp cluster sizes for this point HashSet eliminationSet = new HashSet(numInstances); // don't include these points in farthest first search for (int i=0; i<numInstances; i++) { Instance inst = m_Instances.instance(i); if (m_SeedHash != null && m_SeedHash.containsKey(inst)) { eliminationSet.add(new Integer(i)); // add labeled data to elimination set } } Iterator iter = visitedPoints.iterator(); while(iter.hasNext()) { eliminationSet.add(iter.next()); // add already visited points to elim set } for (int i=0; i<m_ExtraPhase1RunFraction; i++) { System.out.println("Continuing Phase 1 run: " + i + " after all clusters visited"); // find next point, farthest from points, eliminating points in eliminationSet nextPoint = farthestFromSet(points, eliminationSet); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest); } visitedPoints.add(new Integer(nextPoint)); // add to total set of visited points points.add(new Integer(nextPoint)); // add to points visited in this farthest first loop activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint)); tempClusterSizes[classLabel]++; // if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } tempAllClustersFound = 
setAllClustersFound(tempClusterSizes); if (tempAllClustersFound) { // found all clusters, reset local variables System.out.println("Resetting variables for next round of farthest first"); tempClusterSizes = new int[m_NumClusters]; tempAllClustersFound = false; Iterator tempIter = points.iterator(); while(tempIter.hasNext()) { eliminationSet.add((Integer) tempIter.next()); // add already visited points to elim set } points.clear(); // clear current set points.add(new Integer(nextPoint)); // add the last point tempClusterSizes[classLabel]++; // for the last point } if (numPointsSelected >= numActive) { System.out.println("Out of queries within phase 1 extra loop. Queries so far: " + numPointsSelected); return activeLearningPoints; // go out of function } } } } else { // PHASE 2 // find smallest cluster System.out.println("In Phase 2"); int smallestSize = Integer.MAX_VALUE, smallestCluster = -1; for (int i=0; i<m_NumClusters; i++) { if (clusterSizes[i] < smallestSize) { smallestSize = clusterSizes[i]; smallestCluster = i; } } if (m_Verbose) { System.out.println("Smallest cluster now: " + smallestCluster + ", with size: " + smallestSize); } // compute centroid of smallest cluster Instance centroidOfSmallestCluster; if (isSparseInstance) { centroidOfSmallestCluster = new SparseInstance(sumOfClusterInstances[smallestCluster]); } else { centroidOfSmallestCluster = new Instance(sumOfClusterInstances[smallestCluster]); } centroidOfSmallestCluster.setDataset(m_Instances); if (!m_objFunDecreasing) { normalize(centroidOfSmallestCluster); } else { normalizeByWeight(centroidOfSmallestCluster); } // find next point, closest to centroid of smallest cluster int nextPoint = nearestFromPoint(centroidOfSmallestCluster, visitedPoints); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point selected, something went wrong!"); } visitedPoints.add(new Integer(nextPoint)); activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint)); // if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } allClustersFound = setAllClustersFound(clusterSizes); if (allClustersFound != true) { throw new Exception("Something went wrong - all clusters should be set in phase 2!!"); } } } return activeLearningPoints; } /** Returns true if all clusterSizes are non-zero */ boolean setAllClustersFound(int [] clusterSizes) { boolean found = true; for (int i=0; i<m_NumClusters; i++) { if (clusterSizes[i] == 0) { found = false; } //if (m_Verbose) { System.out.println("Cluster " + i + " has size: " + clusterSizes[i]); //} } return found; } /** Finds the sum of instance sum with instance inst */ Instance sumWithInstance(Instance sum, Instance inst) throws Exception { Instance newSum; if (sum == null) { if (isSparseInstance) { newSum = new SparseInstance(inst); newSum.setDataset(m_Instances); } else { newSum = new Instance(inst); newSum.setDataset(m_Instances); } } else { newSum = sumInstances(sum, inst); } return newSum; } /** Finds point which has max min-distance from set visitedPoints */ int farthestFromSet(HashSet visitedPoints, HashSet eliminationSet) throws Exception { // implements farthest-first 
search algorithm: /* for (each datapoint x not in visitedPoints) { distance of x to visitedPoints = min{d(x,f):f \in visitedPoints} } select the point x with maximum distance as new center; */ if (visitedPoints.size() == 0) { Random rand = new Random(m_randomSeed); int point = rand.nextInt(m_StartingIndexOfTest); // Note - no need to check for labeled data now, since we have no visitedPoints // => no labeled data System.out.println("First point selected: " + point); return point; } else { if (m_Verbose) { Iterator iter = visitedPoints.iterator(); while(iter.hasNext()) { System.out.println("In visitedPoints set: " + ((Integer) iter.next()).intValue()); } if (eliminationSet != null) { iter = eliminationSet.iterator(); while(iter.hasNext()) { System.out.println("In elimination set: " + ((Integer) iter.next()).intValue()); } } } } double minSimilaritySoFar = Double.POSITIVE_INFINITY; double maxDistanceSoFar = Double.NEGATIVE_INFINITY; ArrayList bestPoints = new ArrayList(); for (int i=0; i<m_Instances.numInstances() && i<m_StartingIndexOfTest; i++) { // point should not belong to test set if (!visitedPoints.contains(new Integer(i))) { if (eliminationSet == null || !eliminationSet.contains(new Integer(i))) { // point should not belong to visitedPoints Instance inst = m_Instances.instance(i); Iterator iter = visitedPoints.iterator(); double minDistanceFromSet = Double.POSITIVE_INFINITY; double maxSimilarityFromSet = Double.NEGATIVE_INFINITY; while (iter.hasNext()) { Instance pointInSet = m_Instances.instance(((Integer) iter.next()).intValue()); if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, pointInSet); if (sim > maxSimilarityFromSet) { maxSimilarityFromSet = sim; // if (m_Verbose) { // System.out.println("Max similarity of " + i + " from set is: " + maxSimilarityFromSet); // } } } else { double dist = m_metric.distance(inst, pointInSet); if (dist < minDistanceFromSet) { minDistanceFromSet = dist; // if (m_Verbose) { // System.out.println("Min distance of " + i + " from set is: " + minDistanceFromSet); // } } } } if (m_Verbose) { System.out.println(i + " has sim: " + maxSimilarityFromSet + ", best: " + minSimilaritySoFar); } if (!m_objFunDecreasing) { if (maxSimilarityFromSet == minSimilaritySoFar) { minSimilaritySoFar = maxSimilarityFromSet; bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Additional point added: " + i + " with similarity: " + minSimilaritySoFar); } } else if (maxSimilarityFromSet < minSimilaritySoFar) { minSimilaritySoFar = maxSimilarityFromSet; bestPoints.clear(); bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Farthest point from set is: " + i + " with similarity: " + minSimilaritySoFar); } } } else { if (minDistanceFromSet == maxDistanceSoFar) { minDistanceFromSet = maxDistanceSoFar; bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Additional point added: " + i + " with similarity: " + minSimilaritySoFar); } } else if (minDistanceFromSet > maxDistanceSoFar) { maxDistanceSoFar = minDistanceFromSet; bestPoints.clear(); bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Farthest point from set is: " + i + " with distance: " + maxDistanceSoFar); } } } } } } int bestPoint = -1; if (bestPoints.size() > 1) { // multiple points, get random from whole set Random random = new Random(m_randomSeed); bestPoint = random.nextInt(m_StartingIndexOfTest); while ((visitedPoints != null && visitedPoints.contains(new Integer(bestPoint))) || (eliminationSet != null && eliminationSet.contains(new 
Integer(bestPoint)))) { bestPoint = random.nextInt(m_StartingIndexOfTest); } System.out.println("Randomly selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } else { // only 1 point, fine bestPoint = ((Integer)bestPoints.get(0)).intValue(); System.out.println("Deterministically selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } if (m_Verbose) { if (!m_objFunDecreasing) { System.out.println("Randomly selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } else { System.out.println("Randomly selected " + bestPoint + " with similarity: " + maxDistanceSoFar); } } return bestPoint; } /** Finds point which is nearest to center. This point should not be * a test point and should not belong to visitedPoints */ int nearestFromPoint(Instance center, HashSet visitedPoints) throws Exception { double maxSimilarity = Double.NEGATIVE_INFINITY; double minDistance = Double.POSITIVE_INFINITY; int bestPoint = -1; for (int i=0; i<m_Instances.numInstances() && i<m_StartingIndexOfTest; i++) { // bestPoint should not be a test point if (!visitedPoints.contains(new Integer(i))) { // bestPoint should not belong to visitedPoints Instance inst = m_Instances.instance(i); if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, center); if (sim > maxSimilarity) { bestPoint = i; maxSimilarity = sim; if (m_Verbose) { System.out.println("Nearest point is: " + bestPoint + " with sim: " + maxSimilarity); } } } else { double dist = m_metric.distance(inst, center); if (dist < minDistance) { bestPoint = i; minDistance = dist; if (m_Verbose) { System.out.println("Nearest point is: " + bestPoint + " with dist: " + minDistance); } } } } } return bestPoint; } /** Finds sum of instances (handles sparse and non-sparse) */ protected Instance sumInstances(Instance inst1, Instance inst2) throws Exception { int numAttributes = inst1.numAttributes(); if (inst2.numAttributes() != numAttributes) { throw new Exception ("Error!! inst1 and inst2 should have same number of attributes."); } if (m_Verbose) { // System.out.println("Instance 1 is: " + inst1 + ", instance 2 is: " + inst2); } double weight1 = inst1.weight(), weight2 = inst2.weight(); double [] values = new double[numAttributes]; for (int i=0; i<numAttributes; i++) { values[i] = 0; } if (inst1 instanceof SparseInstance && inst2 instanceof SparseInstance) { for (int i=0; i<inst1.numValues(); i++) { int indexOfIndex = inst1.index(i); values[indexOfIndex] = inst1.valueSparse(i); } for (int i=0; i<inst2.numValues(); i++) { int indexOfIndex = inst2.index(i); values[indexOfIndex] += inst2.valueSparse(i); } SparseInstance newInst = new SparseInstance(weight1+weight2, values); newInst.setDataset(m_Instances); if (m_Verbose) { // System.out.println("Sum instance is: " + newInst); } return newInst; } else if (!(inst1 instanceof SparseInstance) && !(inst2 instanceof SparseInstance)){ for (int i=0; i<numAttributes; i++) { values[i] = inst1.value(i) + inst2.value(i); } } else { throw new Exception ("Error!! 
inst1 and inst2 should be both of same type -- sparse or non-sparse"); } Instance newInst = new Instance(weight1+weight2, values); newInst.setDataset(m_Instances); if (m_Verbose) { // System.out.println("Sum instance is: " + newInst); } return newInst; } /** This function divides every attribute value in an instance by * the instance weight -- useful to find the mean of a cluster in * Euclidean space * @param inst Instance passed in for normalization (destructive update) */ protected void normalizeByWeight(Instance inst) { double weight = inst.weight(); if (m_Verbose) { // System.out.println("Before weight normalization: " + inst); } if (inst instanceof SparseInstance) { for (int i=0; i<inst.numValues(); i++) { inst.setValueSparse(i, inst.valueSparse(i)/weight); } } else if (!(inst instanceof SparseInstance)) { for (int i=0; i<inst.numAttributes(); i++) { inst.setValue(i, inst.value(i)/weight); } } if (m_Verbose) { // System.out.println("After weight normalization: " + inst); } } public int[] oldBestInstancesForActiveLearning(int numActive) throws Exception{ int numInstances = m_Instances.numInstances(); double [] scores = new double [numInstances]; int numLabeledData = 0; if (m_SeedHash != null) { numLabeledData = m_SeedHash.size(); } // Remember: order of data -- labeled, then unlabeled, then test for (int i=0; i<numLabeledData; i++) { scores[i] = -1; } for (int i=numLabeledData; i<numInstances; i++) { double score = 0, normalizer = 0; Instance inst = m_Instances.instance(i); double[] prob = new double[m_NumClusters]; for (int j=0; j<m_NumClusters; j++) { if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, m_ClusterCentroids.instance(j)); prob[j] = Math.exp(sim * m_Concentration); // P(x|h) } else { double dist = m_metric.distance(inst, m_ClusterCentroids.instance(j)); prob[j] = Math.exp(-dist*dist * m_Concentration); // P(x|h) } normalizer += prob[j]; // P(x)/P(h) = Sum_h P(x|h) [uniform priors P(h)] } for (int j=0; j<m_NumClusters; j++) { prob[j] /= normalizer; // P(h|x) = P(x|h)*P(h)/P(x) score -= prob[j] * Math.log(prob[j]); } scores[i] = score * normalizer; // InfoGain = H(C|x).P(x) [with a constant factor of 1/P(h)] } System.out.println("NumInstances: "+ numInstances + ", starting index of unlabeled train: " + numLabeledData + ", starting index of test: " + m_StartingIndexOfTest); int [] indices = Utils.sort(scores); int [] mostConfused = new int [numActive]; for (int i=0,num=0; i<numInstances && num<numActive; i++) { int index = numInstances-1-i; if ((indices[index]<m_StartingIndexOfTest) && (scores[indices[index]]!=-1)) { // makes sure that labeled or test instances are not asked to be active labeled mostConfused[num] = (indices[index]); num++; } } for (int i=0; i<numActive; i++) { // System.out.println("Value: " + scores[mostConfused[i]] + ", index: " + mostConfused[i]); } return mostConfused; } /** * Checks if instance has to be normalized and classifies the * instance using the current clustering * * @param instance the instance to be assigned to a cluster * @return the number of the assigned cluster as an integer * if the class is enumerated, otherwise the predicted value * @exception Exception if instance could not be classified * successfully */ public int clusterInstance(Instance instance) throws Exception { if (m_Algorithm == ALGORITHM_SPHERICAL) { // check here, since evaluateModel calls this function on test data normalize(instance); } return assignClusterToInstance(instance); } /** * Classifies the instance using the current clustering * * @param instance 
the instance to be assigned to a cluster * @return the number of the assigned cluster as an integer * if the class is enumerated, otherwise the predicted value * @exception Exception if instance could not be classified * successfully */ public int assignClusterToInstance(Instance instance) throws Exception { int bestCluster = 0; double bestDistance = Double.POSITIVE_INFINITY; double bestSimilarity = Double.NEGATIVE_INFINITY; for (int i = 0; i < m_NumClusters; i++) { double distance = 0, similarity = 0; if (!m_objFunDecreasing) { similarity = m_metric.similarity(instance, m_ClusterCentroids.instance(i)); if (similarity > bestSimilarity) { bestSimilarity = similarity; bestCluster = i; } } else { distance = m_metric.distance(instance, m_ClusterCentroids.instance(i)); if (distance < bestDistance) { bestDistance = distance; bestCluster = i; } } } if (bestSimilarity == 0) { System.out.println("Note!! bestSimilarity is 0 for instance " + m_currIdx + ", assigned to cluster: " + bestCluster + " ... instance is: " + instance); } return bestCluster; } /** Return the number of clusters */ public int getNumClusters() { return m_NumClusters; } /** A duplicate function to conform to Clusterer abstract class. * @returns the number of clusters */ public int numberOfClusters() { return getNumClusters(); } /** Return the number of extra phase1 runs */ public double getExtraPhase1RunFraction() { return m_ExtraPhase1RunFraction; } /** Set the number of extra phase1 runs */ public void setExtraPhase1RunFraction(double w) { m_ExtraPhase1RunFraction = w; } /** Return the concentration */ public double getConcentration() { return m_Concentration; } /** Set the concentration */ public void setConcentration(double w) { m_Concentration = w; } /** Set the m_SeedHash */ public void setSeedHash(HashMap seedhash) { m_SeedHash = seedhash; } /** * Set the random number seed * @param s the seed */ public void setRandomSeed (int s) { m_randomSeed = s; } /** Return the random number seed */ public int getRandomSeed () { return m_randomSeed; } /** * Set the minimum value of the objective function difference required for convergence * @param objFunConvergenceDifference the minimum value of the objective function difference required for convergence */ public void setObjFunConvergenceDifference(double objFunConvergenceDifference) { m_ObjFunConvergenceDifference = objFunConvergenceDifference; } /** * Get the minimum value of the objective function difference required for convergence * @returns the minimum value of the objective function difference required for convergence */ public double getObjFunConvergenceDifference() { return m_ObjFunConvergenceDifference; } /** Sets training instances */ public void setInstances(Instances instances) { m_Instances = instances; } /** Return training instances */ public Instances getInstances() { return m_Instances; } /** * Set the number of clusters to generate * * @param n the number of clusters to generate */ public void setNumClusters(int n) { m_NumClusters = n; if (m_Verbose) { System.out.println("Number of clusters: " + n); } } /** * Set the distance metric * * @param s the metric */ public void setMetric (LearnableMetric m) { m_metric = m; m_metricName = m_metric.getClass().getName(); m_objFunDecreasing = m.isDistanceBased(); } /** * Get the distance metric * * @returns the distance metric used */ public Metric getMetric () { return m_metric; } /** * Get the distance metric name * * @returns the name of the distance metric used */ public String metricName () { return m_metricName; } /** * 
Set the seeding method. Values other than
 * SEEDING_CONSTRAINED or SEEDING_SEEDED will be ignored
 *
 * @param seedingMethod the seeding method to use
 */
public void setSeedingMethod (SelectedTag seedingMethod) {
  if (seedingMethod.getTags() == TAGS_SEEDING) {
    if (m_Verbose) {
      System.out.println("Seeding method: " + seedingMethod.getSelectedTag().getReadable());
    }
    m_SeedingMethod = seedingMethod.getSelectedTag().getID();
  }
}

/**
 * Get the seeding method used.
 *
 * @return the seeding method
 */
public SelectedTag getSeedingMethod () {
  return new SelectedTag(m_SeedingMethod, TAGS_SEEDING);
}

/**
 * Set the KMeans algorithm. Values other than
 * ALGORITHM_SIMPLE or ALGORITHM_SPHERICAL will be ignored
 *
 * @param algo algorithm type
 */
public void setAlgorithm (SelectedTag algo) {
  if (algo.getTags() == TAGS_ALGORITHM) {
    if (m_Verbose) {
      System.out.println("Algorithm: " + algo.getSelectedTag().getReadable());
    }
    m_Algorithm = algo.getSelectedTag().getID();
  }
}

/**
 * Get the KMeans algorithm type. Will be one of
 * ALGORITHM_SIMPLE or ALGORITHM_SPHERICAL
 *
 * @return algorithm type
 */
public SelectedTag getAlgorithm () {
  return new SelectedTag(m_Algorithm, TAGS_ALGORITHM);
}

/**
 * Set the distance metric by class name
 *
 * @param metricName the class name of the distance metric that should be used
 */
public void setMetricName (String metricName) {
  try {
    m_metricName = metricName;
    m_metric = (Metric) Class.forName(metricName).newInstance();
    m_objFunDecreasing = m_metric.isDistanceBased();
  } catch (Exception e) {
    System.err.println("Error instantiating metric " + metricName);
  }
}

/** Set default perturbation value
 * @param p perturbation fraction
 */
public void setDefaultPerturb(double p) {
  m_DefaultPerturb = p;
}

/** Get default perturbation value
 * @return perturbation fraction
 */
public double getDefaultPerturb() {
  return m_DefaultPerturb;
}

/** Turn seeding on and off
 * @param seedable should seeding be done?
 */
public void setSeedable(boolean seedable) {
  m_Seedable = seedable;
}

/** Is seeding on?
 * @return whether seeding is done
 */
public boolean getSeedable() {
  return m_Seedable;
}

/** Reads the seeds from a hashtable, where every key is an instance and
 * every value is the cluster assignment of that instance
 * @param seedHash hashtable containing the seeds
 */
public void seedClusterer(HashMap seedHash) {
  if (m_Seedable) {
    setSeedHash(seedHash);
  }
}

/**
 * Computes the clusters from the cluster assignments, for external access
 *
 * @exception Exception if clusters could not be computed successfully
 */
public ArrayList getIndexClusters() throws Exception {
  m_IndexClusters = new ArrayList();
  // clusterArray is indexed by cluster number, so m_NumClusters slots suffice
  Cluster [] clusterArray = new Cluster[m_NumClusters];
  for (int i = 0; i < m_Instances.numInstances(); i++) {
    if (m_ClusterAssignments[i] != -1) {
      if (clusterArray[m_ClusterAssignments[i]] == null) {
        clusterArray[m_ClusterAssignments[i]] = new Cluster();
      }
      clusterArray[m_ClusterAssignments[i]].add(new Integer(i), 1);
    }
  }
  // add one (possibly null) cluster per cluster number, not per instance
  for (int j = 0; j < m_NumClusters; j++) {
    m_IndexClusters.add(clusterArray[j]);
  }
  return m_IndexClusters;
}

/** Outputs the current clustering
 *
 * @exception Exception if something goes wrong
 */
public void printIndexClusters() throws Exception {
  if (m_IndexClusters == null)
    throw new Exception ("Clusters were not created");
  for (int i = 0; i < m_IndexClusters.size(); i++) {
    Cluster cluster = (Cluster) m_IndexClusters.get(i);
    if (cluster == null) {
      // System.out.println("Cluster " + i + " is null");
    } else {
      System.out.println ("Cluster " + i + " consists of " + cluster.size() + " elements");
      for (int j = 0; j < cluster.size(); j++) {
        int idx = ((Integer) cluster.get(j)).intValue();
        System.out.println("\t\t" + idx);
      }
    }
  }
}

/** Prints clusters */
public void printClusters () throws Exception {
  ArrayList clusters = getClusters();
  for (int i = 0; i < clusters.size(); i++) {
    Cluster currentCluster = (Cluster) clusters.get(i);
    // check for null before dereferencing: empty clusters are stored as null
    if (currentCluster == null) {
      System.out.println("\nCluster " + i + ": (empty)");
    } else {
      System.out.println("\nCluster " + i + ": " + currentCluster.size() + " instances");
      for (int j = 0; j < currentCluster.size(); j++) {
        Instance instance = (Instance) currentCluster.get(j);
        System.out.println("Instance: " + instance);
      }
    }
  }
}

/**
 * Computes the final clusters from the cluster assignments, for external access
 *
 * @exception Exception if clusters could not be computed successfully
 */
public ArrayList getClusters() throws Exception {
  m_FinalClusters = new ArrayList();
  Cluster [] clusterArray = new Cluster[m_NumClusters];
  for (int i = 0; i < m_Instances.numInstances(); i++) {
    Instance inst = m_Instances.instance(i);
    if (clusterArray[m_ClusterAssignments[i]] == null)
      clusterArray[m_ClusterAssignments[i]] = new Cluster();
    clusterArray[m_ClusterAssignments[i]].add(inst, 1);
  }
  for (int j = 0; j < m_NumClusters; j++)
    m_FinalClusters.add(clusterArray[j]);
  return m_FinalClusters;
}

public Enumeration listOptions () {
  return null;
}

/**
 * Gets the metric specification string, which contains the class name of
 * the metric and any options to the metric
 *
 * @return the metric specification string.
*/ protected String getMetricSpec() { if (m_metric instanceof OptionHandler) { return m_metric.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)m_metric).getOptions()); } return m_metric.getClass().getName(); } public String [] getOptions () { String[] options = new String[80]; int current = 0; options[current++] = "-N"; options[current++] = "" + getNumClusters(); options[current++] = "-R"; options[current++] = "" + getRandomSeed(); if (getSeedable()) { options[current++] = "-S"; options[current++] = "" + getSeedingMethod().getSelectedTag().getID(); } options[current++] = "-A"; options[current++] = "" + getAlgorithm().getSelectedTag().getID(); options[current++] = "-M"; options[current++] = m_metric.getClass().getName(); if (m_metric instanceof OptionHandler) { String[] metricOptions = ((OptionHandler)m_metric).getOptions(); for (int i = 0; i < metricOptions.length; i++) { options[current++] = metricOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Parses a given list of options. * @param options the list of options as an array of strings * @exception Exception if an option is not supported * **/ public void setOptions (String[] options) throws Exception { String optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } optionString = Utils.getOption('R', options); if (optionString.length() != 0) { setRandomSeed(Integer.parseInt(optionString)); } optionString = Utils.getOption('S', options); if (optionString.length() != 0) { setSeedingMethod(new SelectedTag(Integer.parseInt(optionString), TAGS_SEEDING)); } else { setSeedable(false); } optionString = Utils.getOption('A', options); if (optionString.length() != 0) { setAlgorithm(new SelectedTag(Integer.parseInt(optionString), TAGS_ALGORITHM)); } optionString = Utils.getOption('M', options); if (optionString.length() != 0) { String[] metricSpec = Utils.splitOptions(optionString); String metricName = metricSpec[0]; metricSpec[0] = ""; if (m_Verbose) { System.out.println("Metric name: " + metricName + "\nMetric parameters: " + concatStringArray(metricSpec)); } setMetric((LearnableMetric) LearnableMetric.forName(metricName, metricSpec)); } } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } /** * return a string describing this clusterer * * @return a description of the clusterer as a string */ public String toString() { StringBuffer temp = new StringBuffer(); temp.append("\nkMeans\n======\n"); temp.append("\nNumber of iterations: " + m_Iterations+"\n"); temp.append("\nCluster centroids:\n"); for (int i = 0; i < m_NumClusters; i++) { temp.append("\nCluster "+i+"\n\t"); /* temp.append(m_ClusterCentroids.instance(i)); for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) { if (m_ClusterCentroids.attribute(j).isNominal()) { temp.append(" "+m_ClusterCentroids.attribute(j). 
value((int)m_ClusterCentroids.instance(i).value(j))); } else { temp.append(" "+m_ClusterCentroids.instance(i).value(j)); } } */ } temp.append("\n"); return temp.toString(); } /** * set the verbosity level of the clusterer * @param verbose messages on(true) or off (false) */ public void setVerbose (boolean verbose) { m_Verbose = verbose; } /** * get the verbosity level of the clusterer * @return messages on(true) or off (false) */ public boolean getVerbose () { return m_Verbose; } /** * Train the clusterer using specified parameters * * @param instances Instances to be used for training */ public void trainClusterer (Instances instances) throws Exception { if (m_metric instanceof LearnableMetric) { if (((LearnableMetric)m_metric).getTrainable()) { ((LearnableMetric)m_metric).learnMetric(instances); } else { throw new Exception ("Metric is not trainable"); } } else { throw new Exception ("Metric is not trainable"); } } /** Normalizes Instance or SparseInstance * * @author Sugato Basu * @param inst Instance to be normalized */ public void normalize(Instance inst) throws Exception { if (inst instanceof SparseInstance) { normalizeSparseInstance(inst); } else { ((LearnableMetric) m_metric).normalizeInstanceWeighted(inst); } } /** Normalizes the values of a normal Instance * * @author Sugato Basu * @param inst Instance to be normalized */ public void normalizeInstance(Instance inst) throws Exception{ double norm = 0; double values [] = inst.toDoubleArray(); if (inst instanceof SparseInstance) { throw new Exception("Use normalizeSparseInstance function"); } for (int i=0; i<values.length; i++) { if (i != inst.classIndex()) { // don't normalize the class index norm += values[i] * values[i]; } } norm = Math.sqrt(norm); for (int i=0; i<values.length; i++) { if (i != inst.classIndex()) { // don't normalize the class index if (norm == 0) { values[i]= 0; } else { values[i] /= norm; } } } inst.setValueArray(values); } /** Normalizes the values of a SparseInstance * * @author Sugato Basu * @param inst SparseInstance to be normalized */ public void normalizeSparseInstance(Instance inst) throws Exception{ double norm=0; int length = inst.numValues(); if (!(inst instanceof SparseInstance)) { throw new Exception("Use normalizeInstance function"); } for (int i=0; i<length; i++) { if (inst.index(i) != inst.classIndex()) { // don't normalize the class index norm += inst.valueSparse(i) * inst.valueSparse(i); } } norm = Math.sqrt(norm); for (int i=0; i<length; i++) { // don't normalize the class index if (inst.index(i) != inst.classIndex()) { inst.setValueSparse(i, inst.valueSparse(i)/norm); } } } /** Fast version of meanOrMode - streamlined from Instances.meanOrMode for efficiency * Does not check for missing attributes, assumes numeric attributes, assumes Sparse instances */ protected double[] meanOrMode(Instances insts) { int numAttributes = insts.numAttributes(); double [] value = new double[numAttributes]; double weight = 0; for (int i=0; i<numAttributes; i++) { value[i] = 0; } for (int j=0; j<insts.numInstances(); j++) { SparseInstance inst = (SparseInstance) (insts.instance(j)); weight += inst.weight(); for (int i=0; i<inst.numValues(); i++) { int indexOfIndex = inst.index(i); value[indexOfIndex] += inst.weight() * inst.valueSparse(i); } } if (Utils.eq(weight, 0)) { for (int k=0; k<numAttributes; k++) { value[k] = 0; } } else { for (int k=0; k<numAttributes; k++) { value[k] = value[k] / weight; } } return value; } /** * Main method for testing this class. 
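 *
 * <p>Takes no command-line arguments; the <code>dataSet</code> variable
 * below switches between the hard-wired iris and newsgroup runs, and the
 * data-file paths are site-specific. A hedged sketch of an equivalent
 * programmatic run (the file name is a placeholder):
 * <pre>
 *   Instances data = new Instances(new FileReader("iris.arff"));
 *   data.setClassIndex(data.numAttributes() - 1);
 *   Instances clusterData = new Instances(data);
 *   clusterData.deleteClassAttribute();
 *   SeededKMeans km = new SeededKMeans(new WeightedEuclidean());
 *   km.setAlgorithm(new SelectedTag(ALGORITHM_SIMPLE, TAGS_ALGORITHM));
 *   km.buildClusterer(clusterData, data.numClasses());
 *   km.getIndexClusters();
 *   km.printIndexClusters();
 * </pre>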
* */ public static void main (String[] args) { try { String dataSet = new String("news"); //String dataSet = new String("iris"); if (dataSet.equals("iris")) { //////// Iris data String datafile = "/u/ml/software/weka-latest/data/iris.arff"; // set up the data FileReader reader = new FileReader (datafile); Instances data = new Instances (reader); // Make the last attribute be the class int theClass = data.numAttributes(); data.setClassIndex(theClass-1); // starts with 0 // Remove the class labels before clustering Instances clusterData = new Instances(data); clusterData.deleteClassAttribute(); // #clusters = #classes int num_clusters = data.numClasses(); // cluster with seeding Instances seeds = new Instances(data,0,5); seeds.add(data.instance(50)); seeds.add(data.instance(51)); seeds.add(data.instance(52)); seeds.add(data.instance(53)); seeds.add(data.instance(54)); seeds.add(data.instance(100)); seeds.add(data.instance(101)); seeds.add(data.instance(102)); seeds.add(data.instance(103)); seeds.add(data.instance(104)); data.delete(104); data.delete(103); data.delete(102); data.delete(101); data.delete(100); data.delete(54); data.delete(53); data.delete(52); data.delete(51); data.delete(50); data.delete(4); data.delete(3); data.delete(2); data.delete(1); data.delete(0); System.out.println("\nClustering the iris data with seeding, using seeded KMeans...\n"); WeightedEuclidean euclidean = new WeightedEuclidean(); SeededKMeans kmeans = new SeededKMeans (euclidean); kmeans.resetClusterer(); kmeans.setVerbose(false); kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING)); kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SIMPLE, TAGS_ALGORITHM)); euclidean.setExternal(false); euclidean.setTrainable(false); // phase 1 test kmeans.setSeedable(false); kmeans.buildClusterer(null, clusterData, theClass, data, 150); // phase 2 test //kmeans.setSeedable(true); //kmeans.buildClusterer(seeds, clusterData, theClass, data, 150); kmeans.getIndexClusters(); kmeans.printIndexClusters(); // kmeans.setVerbose(true); kmeans.bestInstancesForActiveLearning(50); } else if (dataSet.equals("news")) { //////// Text data - 3000 documents String datafile = "/u/ml/data/CCSfiles/arffFromCCS/cmu-newsgroup-clean-1000_fromCCS.arff"; System.out.println("\nClustering complete newsgroup data with seeding, using constrained KMeans...\n"); // set up the data FileReader reader = new FileReader (datafile); Instances data = new Instances (reader); System.out.println("Initial data has size: " + data.numInstances()); // Make the last attribute be the class int theClass = data.numAttributes(); data.setClassIndex(theClass-1); // starts with 0 int num_clusters = data.numClasses(); // cluster with seeding Instances seeds = new Instances(data, 0); /* seeds.add(data.instance(994)); seeds.add(data.instance(1431)); seeds.add(data.instance(1612)); seeds.add(data.instance(1747)); seeds.add(data.instance(2205)); seeds.add(data.instance(2736)); data.delete(2736); data.delete(2205); data.delete(1747); data.delete(1612); data.delete(1431); data.delete(994); seeds.add(data.instance(1000)); seeds.add(data.instance(1001)); seeds.add(data.instance(1002)); seeds.add(data.instance(1003)); seeds.add(data.instance(1004)); seeds.add(data.instance(2000)); seeds.add(data.instance(2001)); seeds.add(data.instance(2002)); seeds.add(data.instance(2003)); seeds.add(data.instance(2004)); // System.out.println("Labeled data has size: " + seeds.numInstances() + ", number of attributes: " + data.numAttributes()); data.delete(2004); data.delete(2003); 
data.delete(2002); data.delete(2001); data.delete(2000); data.delete(1004); data.delete(1003); data.delete(1002); data.delete(1001); data.delete(1000); data.delete(4); data.delete(3); data.delete(2); data.delete(1); data.delete(0); */ System.out.println("Unlabeled data has size: " + data.numInstances()); // Remove the class labels before clustering Instances clusterData = new Instances(data); clusterData.deleteClassAttribute(); WeightedDotP dotp = new WeightedDotP(); dotp.setExternal(false); dotp.setTrainable(false); dotp.setLengthNormalized(false); SeededKMeans kmeans = new SeededKMeans(dotp); kmeans.setVerbose(false); kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING)); kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SPHERICAL, TAGS_ALGORITHM)); kmeans.setNumClusters(3); // phase 1 test kmeans.setSeedable(false); kmeans.buildClusterer(null, clusterData, theClass, data, data.numInstances()); // phase 2 test //kmeans.setSeedable(true); //kmeans.buildClusterer(seeds, clusterData, theClass, data, 3000); kmeans.getIndexClusters(); kmeans.printIndexClusters(); // kmeans.setVerbose(true); //kmeans.bestInstancesForActiveLearning(50); // // cluster with seeding for small newsgroup // seeds = new Instances(data, 0, 3); // seeds.add(data.instance(100)); // seeds.add(data.instance(101)); // seeds.add(data.instance(102)); // seeds.add(data.instance(200)); // seeds.add(data.instance(201)); // seeds.add(data.instance(202)); // seeds.add(data.instance(300)); // seeds.add(data.instance(301)); // seeds.add(data.instance(302)); // seeds.add(data.instance(400)); // seeds.add(data.instance(401)); // seeds.add(data.instance(402)); // seeds.add(data.instance(500)); // seeds.add(data.instance(501)); // seeds.add(data.instance(502)); // seeds.add(data.instance(600)); // seeds.add(data.instance(601)); // seeds.add(data.instance(602)); // seeds.add(data.instance(700)); // seeds.add(data.instance(701)); // seeds.add(data.instance(702)); // seeds.add(data.instance(800)); // seeds.add(data.instance(801)); // seeds.add(data.instance(802)); // seeds.add(data.instance(900)); // seeds.add(data.instance(901)); // seeds.add(data.instance(902)); // seeds.add(data.instance(1000)); // seeds.add(data.instance(1001)); // seeds.add(data.instance(1002)); // seeds.add(data.instance(1100)); // seeds.add(data.instance(1101)); // seeds.add(data.instance(1102)); // seeds.add(data.instance(1200)); // seeds.add(data.instance(1201)); // seeds.add(data.instance(1202)); // seeds.add(data.instance(1300)); // seeds.add(data.instance(1301)); // seeds.add(data.instance(1302)); // seeds.add(data.instance(1400)); // seeds.add(data.instance(1401)); // seeds.add(data.instance(1402)); // seeds.add(data.instance(1500)); // seeds.add(data.instance(1501)); // seeds.add(data.instance(1502)); // seeds.add(data.instance(1600)); // seeds.add(data.instance(1601)); // seeds.add(data.instance(1602)); // seeds.add(data.instance(1700)); // seeds.add(data.instance(1701)); // seeds.add(data.instance(1702)); // seeds.add(data.instance(1800)); // seeds.add(data.instance(1801)); // seeds.add(data.instance(1802)); // seeds.add(data.instance(1900)); // seeds.add(data.instance(1901)); // seeds.add(data.instance(1902)); // System.out.println("Labeled data has size: " + seeds.numInstances() + ", number of attributes: " + data.numAttributes()); // data.delete(1902); // data.delete(1901); // data.delete(1900); // data.delete(1802); // data.delete(1801); // data.delete(1800); // data.delete(1702); // data.delete(1701); // data.delete(1700); // 
data.delete(1602); // data.delete(1601); // data.delete(1600); // data.delete(1502); // data.delete(1501); // data.delete(1500); // data.delete(1402); // data.delete(1401); // data.delete(1400); // data.delete(1302); // data.delete(1301); // data.delete(1300); // data.delete(1202); // data.delete(1201); // data.delete(1200); // data.delete(1102); // data.delete(1101); // data.delete(1100); // data.delete(1002); // data.delete(1001); // data.delete(1000); // data.delete(902); // data.delete(901); // data.delete(900); // data.delete(802); // data.delete(801); // data.delete(800); // data.delete(702); // data.delete(701); // data.delete(700); // data.delete(602); // data.delete(601); // data.delete(600); // data.delete(502); // data.delete(501); // data.delete(500); // data.delete(402); // data.delete(401); // data.delete(400); // data.delete(302); // data.delete(301); // data.delete(300); // data.delete(202); // data.delete(201); // data.delete(200); // data.delete(102); // data.delete(101); // data.delete(100); // data.delete(2); // data.delete(1); // data.delete(0); } } catch (Exception e) { e.printStackTrace(); } } }
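
// A hedged sketch of the option-string interface (main() above ignores its
// command-line arguments, so the -N/-R/-S/-A/-M options parsed in
// setOptions() are exercised programmatically here; the metric class name
// assumes this package's WeightedDotP, and -S 2 / -A 2 select the
// SEEDING_SEEDED and ALGORITHM_SPHERICAL tag IDs defined above):
//
//   SeededKMeans km = new SeededKMeans();
//   km.setOptions(new String[] {"-N", "3", "-R", "1", "-S", "2", "-A", "2",
//                               "-M", "weka.core.metrics.WeightedDotP"});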