package cc.mallet.cluster.tui; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileWriter; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.logging.Logger; import cc.mallet.classify.Classifier; import cc.mallet.classify.MaxEntTrainer; import cc.mallet.classify.Trial; import cc.mallet.cluster.Clusterer; import cc.mallet.cluster.Clustering; import cc.mallet.cluster.Clusterings; import cc.mallet.cluster.GreedyAgglomerativeByDensity; import cc.mallet.cluster.Record; import cc.mallet.cluster.evaluate.AccuracyEvaluator; import cc.mallet.cluster.evaluate.BCubedEvaluator; import cc.mallet.cluster.evaluate.ClusteringEvaluator; import cc.mallet.cluster.evaluate.ClusteringEvaluators; import cc.mallet.cluster.evaluate.MUCEvaluator; import cc.mallet.cluster.evaluate.PairF1Evaluator; import cc.mallet.cluster.iterator.PairSampleIterator; import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor; import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator; import cc.mallet.cluster.neighbor_evaluator.PairwiseEvaluator; import cc.mallet.pipe.Pipe; import cc.mallet.types.Alphabet; import cc.mallet.types.FeatureVector; import cc.mallet.types.InfoGain; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.LabelAlphabet; import cc.mallet.util.CommandOption; import cc.mallet.util.MalletLogger; import cc.mallet.util.PropertyList; import cc.mallet.util.Randoms; import cc.mallet.util.Strings; //In progress public class Clusterings2Clusterer { private static Logger logger = MalletLogger.getLogger(Clusterings2Clusterer.class.getName()); public static void main(String[] args) throws Exception { CommandOption.setSummary(Clusterings2Clusterer.class, "A tool to train and test a Clusterer."); CommandOption.process(Clusterings2Clusterer.class, args); // TRAIN Randoms random = new Randoms(123); Clusterer clusterer = null; if (!loadClusterer.value.exists()) { Clusterings training = readClusterings(trainingFile.value); Alphabet fieldAlphabet = ((Record) training.get(0).getInstances() .get(0).getData()).fieldAlphabet(); Pipe pipe = new ClusteringPipe(string2ints(exactMatchFields.value, fieldAlphabet), string2ints(approxMatchFields.value, fieldAlphabet), string2ints(substringMatchFields.value, fieldAlphabet)); InstanceList trainingInstances = new InstanceList(pipe); for (int i = 0; i < training.size(); i++) { PairSampleIterator iterator = new PairSampleIterator(training .get(i), random, 0.5, training.get(i).getNumInstances()); while(iterator.hasNext()) { Instance inst = iterator.next(); trainingInstances.add(pipe.pipe(inst)); } } logger.info("generated " + trainingInstances.size() + " training instances"); Classifier classifier = new MaxEntTrainer().train(trainingInstances); logger.info("InfoGain:\n"); new InfoGain(trainingInstances).printByRank(System.out); logger.info("pairwise training accuracy=" + new Trial(classifier, trainingInstances).getAccuracy()); NeighborEvaluator neval = new PairwiseEvaluator(classifier, "YES", new PairwiseEvaluator.Average(), true); clusterer = new GreedyAgglomerativeByDensity( training.get(0).getInstances().getPipe(), neval, 0.5, false, random); training = null; trainingInstances = null; } else { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(loadClusterer.value)); clusterer = (Clusterer) ois.readObject(); } // TEST Clusterings testing = readClusterings(testingFile.value); ClusteringEvaluator evaluator = (ClusteringEvaluator) clusteringEvaluatorOption.value; if (evaluator == null) evaluator = new ClusteringEvaluators( new ClusteringEvaluator[] { new BCubedEvaluator(), new PairF1Evaluator(), new MUCEvaluator(), new AccuracyEvaluator() }); ArrayList<Clustering> predictions = new ArrayList<Clustering>(); for (int i = 0; i < testing.size(); i++) { Clustering clustering = testing.get(i); Clustering predicted = clusterer.cluster(clustering.getInstances()); predictions.add(predicted); logger.info(evaluator.evaluate(clustering, predicted)); } logger.info(evaluator.evaluateTotals()); // WRITE OUTPUT ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(saveClusterer.value)); oos.writeObject(clusterer); oos.close(); if (outputClusterings.value != null) { BufferedWriter writer = new BufferedWriter(new FileWriter(new File(outputClusterings.value))); writer.write(predictions.toString()); writer.flush(); writer.close(); } } public static int[] string2ints(String[] ss, Alphabet alph) { int[] ret = new int[ss.length]; for (int i = 0; i < ss.length; i++) ret[i] = alph.lookupIndex(ss[i]); return ret; } public static Clusterings readClusterings(String f) throws Exception { ObjectInputStream ois = new ObjectInputStream(new FileInputStream( new File(f))); return (Clusterings) ois.readObject(); } static CommandOption.File loadClusterer = new CommandOption.File( Clusterings2Clusterer.class, "load-clusterer", "FILE", false, null, "The file from which to read the clusterer.", null); static CommandOption.File saveClusterer = new CommandOption.File( Clusterings2Clusterer.class, "save-clusterer", "FILE", false, new File("clusterer.mallet"), "The filename in which to write the clusterer after it has been trained.", null); static CommandOption.String outputClusterings = new CommandOption.String( Clusterings2Clusterer.class, "output-clusterings", "FILENAME", false, "predictions", "The filename in which to write the predicted clusterings.", null); static CommandOption.String trainingFile = new CommandOption.String( Clusterings2Clusterer.class, "train", "FILENAME", false, "text.clusterings.train", "Read the training set Clusterings from this file. " + "If this is specified, the input file parameter is ignored", null); static CommandOption.String testingFile = new CommandOption.String( Clusterings2Clusterer.class, "test", "FILENAME", false, "text.clusterings.test", "Read the test set Clusterings from this file. " + "If this option is specified, the training-file parameter must be specified and " + " the input-file parameter is ignored", null); static CommandOption.Object clusteringEvaluatorOption = new CommandOption.Object( Clusterings2Clusterer.class, "clustering-evaluator", "CONSTRUCTOR", true, null, "Java code for constructing a ClusteringEvaluator object", null); static CommandOption.SpacedStrings exactMatchFields = new CommandOption.SpacedStrings( Clusterings2Clusterer.class, "exact-match-fields", "STRING...", false, null, "The field names to be checked for exactly matching values", null); static CommandOption.SpacedStrings approxMatchFields = new CommandOption.SpacedStrings( Clusterings2Clusterer.class, "approx-match-fields", "STRING...", false, null, "The field names to be checked for approx matching values", null); static CommandOption.SpacedStrings substringMatchFields = new CommandOption.SpacedStrings( Clusterings2Clusterer.class, "substring-match-fields", "STRING...", false, null, "The field names to be checked for substring matching values. Note that values fewer than 3 characters are ignored.", null); public static class ClusteringPipe extends Pipe { private static final long serialVersionUID = 1L; int[] exactMatchFields; int[] approxMatchFields; int[] substringMatchFields; double approxMatchThreshold; public ClusteringPipe(int[] exactMatchFields, int[] approxMatchFields, int[] substringMatchFields) { super(new Alphabet(), new LabelAlphabet()); this.exactMatchFields = exactMatchFields; this.approxMatchFields = approxMatchFields; this.substringMatchFields = substringMatchFields; } private Record[] array2Records(int[] a, InstanceList list) { ArrayList<Record> records = new ArrayList<Record>(); for (int i = 0; i < a.length; i++) records.add((Record) list.get(a[i]).getData()); return (Record[]) records.toArray(new Record[] {}); } public Instance pipe(Instance carrier) { AgglomerativeNeighbor neighbor = (AgglomerativeNeighbor) carrier .getData(); Clustering original = neighbor.getOriginal(); int[] cluster1 = neighbor.getOldClusters()[0]; int[] cluster2 = neighbor.getOldClusters()[1]; InstanceList list = original.getInstances(); int[] mergedIndices = neighbor.getNewCluster(); Record[] records = array2Records(mergedIndices, list); Alphabet fieldAlph = records[0].fieldAlphabet(); Alphabet valueAlph = records[0].valueAlphabet(); PropertyList features = null; features = addExactMatch(records, fieldAlph, valueAlph, features); features = addApproxMatch(records, fieldAlph, valueAlph, features); features = addSubstringMatch(records, fieldAlph, valueAlph, features); carrier .setData(new FeatureVector(getDataAlphabet(), features, true)); LabelAlphabet ldict = (LabelAlphabet) getTargetAlphabet(); String label = (original.getLabel(cluster1[0]) == original .getLabel(cluster2[0])) ? "YES" : "NO"; carrier.setTarget(ldict.lookupLabel(label)); return carrier; } private PropertyList addExactMatch(Record[] records, Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) { for (int fi = 0; fi < exactMatchFields.length; fi++) { int matches = 0; int comparisons = 0; for (int i = 0; i < records.length && exactMatchFields.length > 0; i++) { FeatureVector valsi = records[i] .values(exactMatchFields[fi]); for (int j = i + 1; j < records.length && valsi != null; j++) { FeatureVector valsj = records[j] .values(exactMatchFields[fi]); if (valsj != null) { comparisons++; for (int ii = 0; ii < valsi.numLocations(); ii++) { if (valsj.contains(valueAlph.lookupObject(valsi .indexAtLocation(ii)))) { matches++; break; } } } } if (matches == comparisons && comparisons > 1) features = PropertyList.add(fieldAlph .lookupObject(exactMatchFields[fi]) + "_all_match", 1.0, features); if (matches > 0) features = PropertyList.add(fieldAlph .lookupObject(exactMatchFields[fi]) + "_exists_match", 1.0, features); } } return features; } private PropertyList addApproxMatch(Record[] records, Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) { for (int fi = 0; fi < approxMatchFields.length; fi++) { int matches = 0; int comparisons = 0; for (int i = 0; i < records.length && approxMatchFields.length > 0; i++) { FeatureVector valsi = records[i] .values(approxMatchFields[fi]); for (int j = i + 1; j < records.length && valsi != null; j++) { FeatureVector valsj = records[j] .values(approxMatchFields[fi]); if (valsj != null) { comparisons++; for (int ii = 0; ii < valsi.numLocations(); ii++) { String si = (String) valueAlph .lookupObject(valsi.indexAtLocation(ii)); for (int jj = 0; jj < valsj.numLocations(); jj++) { String sj = (String) valueAlph .lookupObject(valsj .indexAtLocation(jj)); if (Strings.levenshteinDistance(si, sj) < approxMatchThreshold) { matches++; break; } } } } } if (matches == comparisons && comparisons > 1) features = PropertyList.add(fieldAlph .lookupObject(approxMatchFields[fi]) + "_all_approx_match", 1.0, features); if (matches > 0) features = PropertyList.add(fieldAlph .lookupObject(approxMatchFields[fi]) + "_exists_approx_match", 1.0, features); } } return features; } private PropertyList addSubstringMatch(Record[] records, Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) { for (int fi = 0; fi < substringMatchFields.length; fi++) { int matches = 0; int comparisons = 0; for (int i = 0; i < records.length && substringMatchFields.length > 0; i++) { FeatureVector valsi = records[i] .values(substringMatchFields[fi]); for (int j = i + 1; j < records.length && valsi != null; j++) { FeatureVector valsj = records[j] .values(substringMatchFields[fi]); if (valsj != null) { comparisons++; for (int ii = 0; ii < valsi.numLocations(); ii++) { String si = (String) valueAlph .lookupObject(valsi.indexAtLocation(ii)); if (si.length() < 2) break; for (int jj = 0; jj < valsj.numLocations(); jj++) { String sj = (String) valueAlph .lookupObject(valsj .indexAtLocation(jj)); if (sj.length() > 2 && (si.contains(si) || sj.contains(si))) { matches++; break; } } } } } if (matches == comparisons && comparisons > 1) features = PropertyList.add(fieldAlph .lookupObject(exactMatchFields[fi]) + "_all_substring_match", 1.0, features); if (matches > 0) features = PropertyList.add(fieldAlph .lookupObject(exactMatchFields[fi]) + "_exists_substring_match", 1.0, features); } } return features; } } }