package cc.mallet.cluster.iterator;

import java.util.ArrayList;
import java.util.Iterator;

import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.types.Instance;
import cc.mallet.util.Randoms;

/**
 * Sample clusters of Instances.
 *
 * <p>Each call to {@link #next()} produces an {@link Instance} wrapping an
 * {@link AgglomerativeNeighbor}. A positive sample is built from two disjoint
 * groups of instances drawn from a single non-singleton cluster; a negative
 * sample is built from one group drawn from each of two different clusters.
 *
 * @author "Aron Culotta" &lt;culotta@degas.cs.umass.edu&gt;
 * @version 1.0
 * @since 1.0
 * @see PairSampleIterator
 * @see NeighborIterator
 */
public class ClusterSampleIterator extends PairSampleIterator {

    /**
     * Creates a sampler over the given clustering. All arguments are passed
     * unchanged to the {@link PairSampleIterator} constructor.
     *
     * @param clustering True clustering.
     * @param random Source of randomness.
     * @param positiveProportion Proportion of Instances that should be positive examples.
     * @param numberSamples Total number of samples to generate.
     */
    public ClusterSampleIterator (Clustering clustering,
                                  Randoms random,
                                  double positiveProportion,
                                  int numberSamples) {
        super(clustering, random, positiveProportion, numberSamples);
    }

    /**
     * Generates the next sample and increments {@code totalCount}.
     *
     * <p>A positive sample is generated when at least one non-singleton
     * cluster exists and either the positive target has not been reached yet
     * or the clustering has exactly one cluster; {@code positiveCount} is
     * incremented in that case. Otherwise a negative sample is generated from
     * two distinct, uniformly chosen clusters.
     *
     * @return an Instance constructed with the sampled
     *         {@link AgglomerativeNeighbor} as its first argument and
     *         {@code null} for the remaining three
     * @throws IllegalStateException if a negative sample is required but the
     *         clustering has only one cluster (previously this case never
     *         terminated)
     */
    public Instance next () {
        AgglomerativeNeighbor neighbor = null;
        int numClusters = clustering.getNumClusters();

        if ((positiveCount < positiveTarget || numClusters == 1)
            && nonsingletonClusters.length > 0) {
            // Positive sample: split one non-singleton cluster into two groups.
            positiveCount++;
            int label = nonsingletonClusters[random.nextInt(nonsingletonClusters.length)];
            int[] instances = clustering.getIndicesWithLabel(label);
            int[][] clusters = sampleSplitFromArray(instances, random, 2);
            neighbor = new AgglomerativeNeighbor(clustering, clustering, clusters);
        } else {
            // Negative sample: needs two different clusters. With a single
            // cluster the rejection loop below could never find labelj != labeli.
            if (numClusters == 1) {
                throw new IllegalStateException(
                    "Cannot sample a negative example from a clustering with a single cluster");
            }
            int labeli = random.nextInt(numClusters);
            int labelj = random.nextInt(numClusters);
            while (labeli == labelj) {
                labelj = random.nextInt(numClusters);
            }
            neighbor = new AgglomerativeNeighbor(
                clustering,
                ClusterUtils.copyAndMergeClusters(clustering, labeli, labelj),
                sampleFromArray(clustering.getIndicesWithLabel(labeli), random, 1),
                sampleFromArray(clustering.getIndicesWithLabel(labelj), random, 1));
        }

        totalCount++;
        return new Instance(neighbor, null, null, null);
    }

    /**
     * Samples a subset of elements from this array. The relative order of the
     * retained elements is the same as in {@code a}.
     *
     * @param a array to sample from; must not be empty
     * @param random source of randomness
     * @param minSize lower bound on the requested subset size; the result
     *        never has more than {@code a.length} elements
     * @return a new array holding the sampled elements of {@code a}
     * @throws IllegalArgumentException if {@code a} is empty
     */
    protected int[] sampleFromArray (int[] a, Randoms random, int minSize) {
        if (a.length == 0) {
            throw new IllegalArgumentException("Cannot sample from an empty array");
        }
        ArrayList<Integer> kept = samplePositions(a.length, random, minSize);
        int[] ret = new int[kept.size()];
        int i = 0;
        for (Integer position : kept) {
            ret[i++] = a[position.intValue()];
        }
        return ret;
    }

    /**
     * Samples two disjoint, non-empty subsets of elements from this array.
     *
     * <p>A subset of at least two elements is sampled first and then cut into
     * a leading part ({@code ret[0]}) and a trailing part ({@code ret[1]}).
     * Both parts contain elements of {@code a}, not positions within it.
     *
     * @param a array to sample from; must contain at least two elements
     * @param random source of randomness
     * @param minSize lower bound on the combined size of the two subsets;
     *        values below 2 are treated as 2 because each subset needs at
     *        least one element
     * @return a two-row array whose rows are the two disjoint subsets
     * @throws IllegalArgumentException if {@code a} has fewer than two elements
     */
    protected int[][] sampleSplitFromArray (int[] a, Randoms random, int minSize) {
        if (a.length < 2) {
            throw new IllegalArgumentException(
                "Need at least two elements to split, got " + a.length);
        }
        // Never keep fewer than two positions: the split-point draw below
        // uses bound (kept.size() - 1), which must be positive.
        ArrayList<Integer> kept = samplePositions(a.length, random, Math.max(minSize, 2));

        int[][] ret = new int[2][];
        int size1 = Math.max(random.nextInt(kept.size() - 1), 1);
        ret[0] = new int[size1];
        ret[1] = new int[kept.size() - size1];
        for (int i = 0; i < size1; i++) {
            // Map the sampled position back to the array element.
            ret[0][i] = a[kept.get(i).intValue()];
        }
        int nadded = 0;
        for (int i = size1; i < kept.size(); i++) {
            ret[1][nadded++] = a[kept.get(i).intValue()];
        }
        return ret;
    }

    /**
     * Chooses which positions of an array of the given length to keep.
     *
     * <p>The target count is {@code max(random.nextInt(length) + 1, minSize)};
     * randomly chosen positions are discarded until no more than that many
     * remain. Surviving positions stay in ascending order.
     *
     * @param length length of the array being sampled; must be positive
     * @param random source of randomness
     * @param minSize lower bound on the target count
     * @return the retained positions, each in {@code [0, length)}
     */
    private static ArrayList<Integer> samplePositions (int length, Randoms random, int minSize) {
        int size = Math.max(random.nextInt(length) + 1, minSize);
        ArrayList<Integer> kept = new ArrayList<Integer>(length);
        for (int i = 0; i < length; i++) {
            kept.add(Integer.valueOf(i));
        }
        while (kept.size() > size) {
            // The int argument selects remove(int index), not remove(Object).
            kept.remove(random.nextInt(kept.size()));
        }
        return kept;
    }
}