/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */

package cc.mallet.pipe.iterator;

import java.net.URI;
import java.util.Iterator;
import java.util.logging.*;

import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.Multinomial;
import cc.mallet.types.TokenSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;

/**
 * Iterator that generates random {@link Instance}s.  Each instance carries a
 * {@link TokenSequence} sampled from a per-class {@link Dirichlet} (the
 * "class centroid") as its data, the class name as its target, and a
 * <code>random:&lt;className&gt;/&lt;instanceIndex&gt;</code> URI as its name.
 *
 * <p>Classes are visited from the last entry of <code>classNames</code> down to
 * the first, and instance indices count down to 0 within each class.  The
 * number of instances per class is redrawn from a Poisson distribution on
 * every call to {@link #reset()}.
 */
public class RandomTokenSequenceIterator implements Iterator<Instance>
{
	private static final Logger logger =
		MalletLogger.getLogger(RandomTokenSequenceIterator.class.getName());

	Randoms r;
	Dirichlet classCentroidDistribution;
	double classCentroidAvergeAlphaMean;
	double classCentroidAvergeAlphaVariance;
	double featureVectorSizePoissonLambda;
	double classInstanceCountPoissonLamba;
	String[] classNames;
	int[] numInstancesPerClass;		// indexed over classes
	Dirichlet[] classCentroid;		// indexed over classes
	int currentClassIndex;			// class currently being emitted; counts down
	int currentInstanceIndex;		// next instance index within that class; -1 when the class is exhausted

	/**
	 * Builds one random Dirichlet per class and draws the per-class instance counts.
	 *
	 * @param r the generator of all random-ness used here
	 * @param classCentroidDistribution Dirichlet (including its Alphabet) from which each class centroid is drawn
	 * @param classCentroidAvergeAlphaMean Gaussian mean on the sum of alphas
	 * @param classCentroidAvergeAlphaVariance Gaussian variance on the sum of alphas
	 * @param featureVectorSizePoissonLambda truncated to an <code>int</code> and used as the
	 *        size argument for every generated token sequence (the Poisson draw is currently disabled)
	 * @param classInstanceCountPoissonLamba Poisson parameter for the number of instances in each class
	 * @param classNames one name per class; used as the target of the generated instances
	 */
	public RandomTokenSequenceIterator (Randoms r,
										Dirichlet classCentroidDistribution,
										double classCentroidAvergeAlphaMean,
										double classCentroidAvergeAlphaVariance,
										double featureVectorSizePoissonLambda,
										double classInstanceCountPoissonLamba,
										String[] classNames)
	{
		this.r = r;
		this.classCentroidDistribution = classCentroidDistribution;
		assert (classCentroidDistribution.getAlphabet() instanceof Alphabet);
		this.classCentroidAvergeAlphaMean = classCentroidAvergeAlphaMean;
		this.classCentroidAvergeAlphaVariance = classCentroidAvergeAlphaVariance;
		this.featureVectorSizePoissonLambda = featureVectorSizePoissonLambda;
		this.classInstanceCountPoissonLamba = classInstanceCountPoissonLamba;
		this.classNames = classNames;
		this.numInstancesPerClass = new int[classNames.length];
		this.classCentroid = new Dirichlet[classNames.length];
		for (int i = 0; i < classNames.length; i++) {
			logger.fine ("classCentroidAvergeAlphaMean = "+classCentroidAvergeAlphaMean);
			double aveAlpha = r.nextGaussian (classCentroidAvergeAlphaMean,
											  classCentroidAvergeAlphaVariance);
			logger.fine ("aveAlpha = "+aveAlpha);
			classCentroid[i] = classCentroidDistribution.randomDirichlet (r, aveAlpha);
			//logger.fine ("Dirichlet for class "+classNames[i]); classCentroid[i].print();
		}
		reset ();
	}

	/**
	 * Convenience constructor: symmetric Dirichlet with alpha 2.0 over <code>vocab</code>,
	 * alpha mean 30, alpha variance 0, sequence size 10 and instance-count lambda 20.
	 */
	public RandomTokenSequenceIterator (Randoms r, Alphabet vocab, String[] classnames)
	{
		this (r, new Dirichlet(vocab, 2.0), 30, 0, 10, 20, classnames);
	}

	/** Returns the alphabet of the Dirichlet from which class centroids are drawn. */
	public Alphabet getAlphabet ()
	{
		return classCentroidDistribution.getAlphabet();
	}

	/** Creates an alphabet with entries "feature0" .. "feature(size-1)". */
	private static Alphabet dictOfSize (int size)
	{
		Alphabet ret = new Alphabet ();
		for (int i = 0; i < size; i++)
			ret.lookupIndex ("feature"+i);
		return ret;
	}

	/** Creates the class names "class0" .. "class(size-1)". */
	private static String[] classNamesOfSize (int size)
	{
		String[] ret = new String[size];
		for (int i = 0; i < size; i++)
			ret[i] = "class"+i;
		return ret;
	}

	/**
	 * Convenience constructor that synthesizes both the vocabulary
	 * ("feature<i>i</i>") and the class names ("class<i>i</i>"), using the same
	 * default parameters as the (Randoms, Alphabet, String[]) constructor.
	 */
	public RandomTokenSequenceIterator (Randoms r, int vocabSize, int numClasses)
	{
		this (r, new Dirichlet(dictOfSize(vocabSize), 2.0), 30, 0, 10, 20,
			  classNamesOfSize(numClasses));
	}

	/**
	 * Redraws the number of instances for every class and rewinds the iterator
	 * to the last instance of the last class.
	 */
	public void reset ()
	{
		for (int i = 0; i < classNames.length; i++) {
			this.numInstancesPerClass[i] = r.nextPoisson (classInstanceCountPoissonLamba);
			logger.fine ("Class "+classNames[i]+" will have "
						 +numInstancesPerClass[i]+" instances.");
		}
		this.currentClassIndex = classNames.length - 1;
		// With no classes at all there is nothing to index; mark the iterator as exhausted.
		this.currentInstanceIndex = (currentClassIndex >= 0)
			? numInstancesPerClass[currentClassIndex] - 1
			: -1;
	}

	/**
	 * Generates the next random instance.
	 *
	 * @return an Instance whose data is a random TokenSequence, whose target is
	 *         the current class name and whose name is a <code>random:</code> URI
	 * @throws IllegalStateException if no instances remain, or if the class name
	 *         cannot be turned into a valid URI
	 */
	public Instance next ()
	{
		// Step back over exhausted classes, including classes that drew zero instances.
		while (currentInstanceIndex < 0) {
			if (currentClassIndex <= 0)
				throw new IllegalStateException ("No next TokenSequence.");
			currentClassIndex--;
			currentInstanceIndex = numInstancesPerClass[currentClassIndex] - 1;
		}
		URI uri;
		try {
			uri = new URI ("random:" + classNames[currentClassIndex] + "/" + currentInstanceIndex);
		} catch (java.net.URISyntaxException e) {
			throw new IllegalStateException (
				"Cannot build URI for class \"" + classNames[currentClassIndex]
				+ "\", instance " + currentInstanceIndex, e);
		}
		//xxx Producing small numbers? int randomSize = r.nextPoisson (featureVectorSizePoissonLambda);
		int randomSize = (int)featureVectorSizePoissonLambda;
		TokenSequence ts = classCentroid[currentClassIndex].randomTokenSequence (r, randomSize);
		//logger.fine ("FeatureVector "+currentClassIndex+" "+currentInstanceIndex); fv.print();
		currentInstanceIndex--;
		return new Instance (ts, classNames[currentClassIndex], uri, null);
	}

	/**
	 * Reports whether {@link #next()} can produce another instance.  Does not
	 * modify the iterator state.
	 */
	public boolean hasNext ()
	{
		// Index 0 is still a pending instance, so only a negative index means
		// the current class is exhausted.
		if (currentInstanceIndex >= 0)
			return true;
		// Otherwise look for any earlier class that actually has instances.
		for (int c = currentClassIndex - 1; c >= 0; c--) {
			if (numInstancesPerClass[c] > 0)
				return true;
		}
		return false;
	}

	/** Not supported; always throws IllegalStateException. */
	public void remove ()
	{
		throw new IllegalStateException ("This Iterator<Instance> does not support remove().");
	}
}