/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ /** A clustering of a set of points (instances). @author Jerod Weinman <A HREF="mailto:weinman@cs.umass.edu">weinman@cs.umass.edu</A> */ package cc.mallet.cluster; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.Arrays; import cc.mallet.types.InstanceList; public class Clustering implements Serializable { protected int numLabels; protected int labels[]; protected InstanceList instances; /** Clustering constructor. * * @param instances Instances that are clustered * @param numLabels Number of clusters * @param labels Assignment of instances to clusters; many-to-one with * range [0,numLabels). */ public Clustering (InstanceList instances, int numLabels, int[] labels) { if (instances.size() != labels.length) throw new IllegalArgumentException("Instance list length does not match cluster labeling"); if (numLabels < 1) throw new IllegalArgumentException("Number of labels must be strictly positive."); for (int i = 0 ; i < labels.length ; i++) if (labels[i] < 0 || labels[i] >= numLabels) throw new IllegalArgumentException("Label mapping must have range [0,numLabels)."); this.instances = instances; this.numLabels = numLabels; this.labels = labels; } // GETTERS public InstanceList getInstances () { return this.instances; } /** Return an list of instances with a particular label. */ public InstanceList getCluster(int label) { InstanceList cluster = new InstanceList(instances.getPipe()); for (int n=0 ; n<instances.size() ; n++) if (labels[n] == label) cluster.add(instances.get(n)); return cluster; } /** Returns an array of instance lists corresponding to clusters. */ public InstanceList[] getClusters() { InstanceList[] clusters = new InstanceList[numLabels]; for (int c= 0 ; c<numLabels ; c++) clusters[c] = getCluster(c); return clusters; } /** Get the cluster label for a particular instance. */ public int getLabel(int index) { return labels[index]; } public int[] getLabels() { return labels; } public int getNumClusters() { return numLabels; } public int getNumInstances() { return instances.size(); } public int size (int label) { int size = 0; for (int i = 0; i < labels.length; i++) if (labels[i] == label) size++; return size; } public int[] getIndicesWithLabel (int label) { int[] indices = new int[size(label)]; int count = 0; for (int i = 0; i < labels.length; i++) if (labels[i] == label) indices[count++] = i; return indices; } public boolean equals (Object o) { Clustering c = (Clustering) o; return Arrays.equals(c.getLabels(), labels); } public String toString () { String result=""; result+="#Clusters: "+getNumClusters()+"\n"; for(int i=0;i<getNumClusters();i++) { result+="\n--CLUSTER "+i+"--"; int[] cluster=getIndicesWithLabel(i); for(int k=0;k<cluster.length;k++) { result+="\n\t"+instances.get(cluster[k]).getData().toString(); } } return result; } public Clustering shallowCopy () { int[] newLabels = new int[labels.length]; System.arraycopy(labels, 0, newLabels, 0, labels.length); Clustering c = new Clustering(instances, numLabels, newLabels); return c; } // SETTERS /** Set the cluster label for a particular instance. */ public void setLabel(int index, int label) { labels[index] = label; } /** Set the number of clusters */ public void setNumLabels(int n) { numLabels = n; } // SERIALIZATION private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; private void writeObject (ObjectOutputStream out) throws IOException { out.defaultWriteObject (); out.writeInt (CURRENT_SERIAL_VERSION); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject (); int version = in.readInt (); } }