package edu.usc.cssl.tacit.classify.naivebayes.services; /* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.NoSuchElementException; import cc.mallet.types.InstanceList; import cc.mallet.types.MultiInstanceList; /** * An iterator which splits an {@link InstanceList} into n-folds and iterates * over the folds for use in n-fold cross-validation. For each iteration, * list[0] contains a {@link InstanceList} with n-1 folds typically used for * training and list[1] contains an {@link InstanceList} with 1 fold typically * used for validation. * * This class uses {@link MultiInstanceList} to avoid creating a new * {@link InstanceList} each iteration. * * TODO - currently the distribution is completely random, an improvement would * be to provide a stratified random distribution. * * @see MultiInstanceList * @see InstanceList * * @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a> */ public class CrossValidationIterator implements java.util.Iterator<InstanceList[]>, Serializable { private static final long serialVersionUID = 234516468015114991L; private final int nfolds; private final InstanceList[] folds; private int index; /** * Constructs a new n-fold cross-validation iterator * * @param ilist instance list to split into folds and iterate over * @param nfolds number of folds to split InstanceList into * @param r The source of randomness to use in shuffling. */ public CrossValidationIterator (InstanceList ilist, int nfolds, java.util.Random r) { assert (nfolds > 0) : "nfolds: " + this.nfolds; this.nfolds = nfolds; this.index = 0; double fraction = (double) 1 / nfolds; double[] proportions = new double[nfolds]; for (int i=0; i < nfolds; i++) { proportions[i] = fraction; } this.folds = ilist.split (r, proportions); } /** * Constructs a new n-fold cross-validation iterator * * @param ilist instance list to split into folds and iterate over * @param _nfolds number of folds to split InstanceList into */ public CrossValidationIterator (InstanceList ilist, int _nfolds) { this (ilist, _nfolds, new java.util.Random (System.currentTimeMillis ())); } /** * Calls clear on each fold. It is recommended that this be always be called * when the iterator is no longer needed so that implementations of * InstanceList such as PagedInstanceList can clean up any temporary data * they may have outside the JVM. */ public void clear () { for (InstanceList list : this.folds) { list.clear(); } } @Override public boolean hasNext () { return this.index < this.nfolds; } /** * Returns the next training/testing split. * * @return A two element array of {@link InstanceList}, where * <code>InstanceList[0]</code> contains n-1 folds for training and * <code>InstanceList[1]</code> contains 1 fold for testing. */ public InstanceList[] nextSplit () { if (!hasNext()) { throw new NoSuchElementException(); } InstanceList[] ret = new InstanceList[2]; if (this.folds.length == 1) { ret[0] = this.folds[0]; ret[1] = this.folds[0]; } else { InstanceList[] training = new InstanceList[this.folds.length - 1]; int j = 0; for (int i = 0; i < this.folds.length; i++) { if (i == this.index) { continue; } training[j++] = this.folds[i]; } ret[0] = new MultiInstanceList (training); ret[1] = this.folds[this.index]; } this.index++; return ret; } /** * Returns the next training/testing split. * * @return A two element array of {@link InstanceList}, where * <code>InstanceList[0]</code> contains <code>numTrainingFolds</code> * folds for training and <code>InstanceList[1]</code> contains * n - <code>numTrainingFolds</code> folds for testing. */ public InstanceList[] nextSplit (int numTrainFolds) { if (!hasNext()) { throw new NoSuchElementException (); } List<InstanceList> trainingSet = new ArrayList<InstanceList> (); List<InstanceList> testSet = new ArrayList<InstanceList> (); // train on folds [index, index+numTrainFolds), test on rest for (int i = 0; i < this.folds.length; i++) { int foldno = (this.index + i) % this.folds.length; if (i < numTrainFolds) { trainingSet.add (this.folds[foldno]); } else { testSet.add (this.folds[foldno]); } } InstanceList[] ret = new InstanceList[2]; ret[0] = new MultiInstanceList (trainingSet); ret[1] = new MultiInstanceList (testSet); this.index++; return ret; } /** * Returns the next training/testing split. * * @see java.util.Iterator#next() * @return A two element array of {@link InstanceList}, where * <code>InstanceList[0]</code> contains n-1 folds for training and * <code>InstanceList[1]</code> contains 1 fold for testing. */ @Override public InstanceList[] next () { return nextSplit(); } @Override public void remove () { throw new UnsupportedOperationException (); } }