package cc.mallet.cluster.tui; import gnu.trove.TIntHashSet; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.logging.Logger; import cc.mallet.cluster.Clustering; import cc.mallet.cluster.Clusterings; import cc.mallet.cluster.util.ClusterUtils; import cc.mallet.pipe.Noop; import cc.mallet.pipe.Pipe; import cc.mallet.types.Alphabet; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.LabelAlphabet; import cc.mallet.util.CommandOption; import cc.mallet.util.MalletLogger; import cc.mallet.util.Randoms; // In progress public class Clusterings2Clusterings { private static Logger logger = MalletLogger.getLogger(Clusterings2Clusterings.class.getName()); public static void main (String[] args) { CommandOption .setSummary(Clusterings2Clusterings.class, "A tool to manipulate Clusterings."); CommandOption.process(Clusterings2Clusterings.class, args); Clusterings clusterings = null; try { ObjectInputStream iis = new ObjectInputStream(new FileInputStream(inputFile.value)); clusterings = (Clusterings) iis.readObject(); } catch (Exception e) { System.err.println("Exception reading clusterings from " + inputFile.value + " " + e); e.printStackTrace(); } logger.info("number clusterings=" + clusterings.size()); // Prune clusters based on size. if (minClusterSize.value > 1) { for (int i = 0; i < clusterings.size(); i++) { Clustering clustering = clusterings.get(i); InstanceList oldInstances = clustering.getInstances(); Alphabet alph = oldInstances.getDataAlphabet(); LabelAlphabet lalph = (LabelAlphabet) oldInstances.getTargetAlphabet(); if (alph == null) alph = new Alphabet(); if (lalph == null) lalph = new LabelAlphabet(); Pipe noop = new Noop(alph, lalph); InstanceList newInstances = new InstanceList(noop); for (int j = 0; j < oldInstances.size(); j++) { int label = clustering.getLabel(j); Instance instance = oldInstances.get(j); if (clustering.size(label) >= minClusterSize.value) newInstances.add(noop.pipe(new Instance(instance.getData(), lalph.lookupLabel(new Integer(label)), instance.getName(), instance.getSource()))); } clusterings.set(i, createSmallerClustering(newInstances)); } if (outputPrefixFile.value != null) { try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPrefixFile.value)); oos.writeObject(clusterings); oos.close(); } catch (Exception e) { logger.warning("Exception writing clustering to file " + outputPrefixFile.value + " " + e); e.printStackTrace(); } } } // Split into training/testing if (trainingProportion.value > 0) { if (clusterings.size() > 1) throw new IllegalArgumentException("Expect one clustering to do train/test split, not " + clusterings.size()); Clustering clustering = clusterings.get(0); int targetTrainSize = (int)(trainingProportion.value * clustering.getNumInstances()); TIntHashSet clustersSampled = new TIntHashSet(); Randoms random = new Randoms(123); LabelAlphabet lalph = new LabelAlphabet(); InstanceList trainingInstances = new InstanceList(new Noop(null, lalph)); while (trainingInstances.size() < targetTrainSize) { int cluster = random.nextInt(clustering.getNumClusters()); if (!clustersSampled.contains(cluster)) { clustersSampled.add(cluster); InstanceList instances = clustering.getCluster(cluster); for (int i = 0; i < instances.size(); i++) { Instance inst = instances.get(i); trainingInstances.add(new Instance(inst.getData(), lalph.lookupLabel(new Integer(cluster)), inst.getName(), inst.getSource())); } } } trainingInstances.shuffle(random); Clustering trainingClustering = createSmallerClustering(trainingInstances); InstanceList testingInstances = new InstanceList(null, lalph); for (int i = 0; i < clustering.getNumClusters(); i++) { if (!clustersSampled.contains(i)) { InstanceList instances = clustering.getCluster(i); for (int j = 0; j < instances.size(); j++) { Instance inst = instances.get(j); testingInstances.add(new Instance(inst.getData(), lalph.lookupLabel(new Integer(i)), inst.getName(), inst.getSource())); } } } testingInstances.shuffle(random); Clustering testingClustering = createSmallerClustering(testingInstances); logger.info(outputPrefixFile.value + ".train : " + trainingClustering.getNumClusters() + " objects"); logger.info(outputPrefixFile.value + ".test : " + testingClustering.getNumClusters() + " objects"); if (outputPrefixFile.value != null) { try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(outputPrefixFile.value + ".train"))); oos.writeObject(new Clusterings(new Clustering[]{trainingClustering})); oos.close(); oos = new ObjectOutputStream(new FileOutputStream(new File(outputPrefixFile.value + ".test"))); oos.writeObject(new Clusterings(new Clustering[]{testingClustering})); oos.close(); } catch (Exception e) { logger.warning("Exception writing clustering to file " + outputPrefixFile.value + " " + e); e.printStackTrace(); } } } } private static Clustering createSmallerClustering (InstanceList instances) { Clustering c = ClusterUtils.createSingletonClustering(instances); return ClusterUtils.mergeInstancesWithSameLabel(c); } static CommandOption.String inputFile = new CommandOption.String( Clusterings2Clusterings.class, "input", "FILENAME", true, "text.clusterings", "The filename from which to read the list of instances.", null); static CommandOption.String outputPrefixFile = new CommandOption.String( Clusterings2Clusterings.class, "output-prefix", "FILENAME", false, "text.clusterings", "The filename prefix to write output. Suffices 'train' and 'test' appended.", null); static CommandOption.Integer minClusterSize = new CommandOption.Integer(Clusterings2Clusterings.class, "min-cluster-size", "INTEGER", false, 1, "Remove clusters with fewer than this many Instances.", null); static CommandOption.Double trainingProportion = new CommandOption.Double(Clusterings2Clusterings.class, "training-proportion", "DOUBLE", false, 0.0, "Split into training and testing, with this percentage of instances reserved for training.", null); }