package edu.cmu.minorthird.classify.sequential;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
/**
 * Result of a train/test experiment on sequential data, packaged for viewing.
 *
 * <p>The constructor splits a {@link SequenceDataset} with a {@link Splitter},
 * trains one {@link SequenceClassifier} per partition on that partition's
 * training data, applies it to the partition's test data, and folds every
 * test partition into a single {@link Evaluation}.
 *
 * @author William Cohen
 */
public class CrossValidatedSequenceDataset implements Visible
{
  private static final Logger log = Logger.getLogger(CrossValidatedSequenceDataset.class);

  /** Classified test data, one entry per partition of the split. */
  private final ClassifiedSequenceDataset[] cds;

  /**
   * Classified training data, one entry per partition of the split, or
   * <code>null</code> when training partitions were not requested.
   */
  private final ClassifiedSequenceDataset[] trainCds;

  /** Evaluation extended with the test data of every partition. */
  private final Evaluation v;

  /**
   * Runs the experiment without keeping the classified training partitions.
   *
   * @param learner learner used to train one classifier per partition
   * @param d the sequential dataset to split
   * @param splitter defines the train/test partitions
   */
  public CrossValidatedSequenceDataset(
    SequenceClassifierLearner learner,SequenceDataset d,Splitter<Example[]> splitter)
  {
    this(learner,d,splitter,false);
  }

  /**
   * Runs the experiment: for every partition produced by the splitter, trains
   * a classifier on the training part, classifies the test part, and extends
   * the overall evaluation with that test part.
   *
   * @param learner learner used to train one classifier per partition
   * @param d the sequential dataset to split
   * @param splitter defines the train/test partitions
   * @param saveTrainPartitions if true, the training data of each partition is
   *        also stored, classified by the classifier trained on it, and shown
   *        by {@link #toGUI()}
   */
  public CrossValidatedSequenceDataset(
    SequenceClassifierLearner learner,SequenceDataset d,Splitter<Example[]> splitter,boolean saveTrainPartitions)
  {
    Dataset.Split s = d.splitSequence(splitter);
    // Assumes the number of partitions is fixed once the split exists.
    final int numPartitions = s.getNumPartitions();
    cds = new ClassifiedSequenceDataset[numPartitions];
    trainCds = saveTrainPartitions ? new ClassifiedSequenceDataset[numPartitions] : null;
    v = new Evaluation(d.getSchema());
    ProgressCounter pc = new ProgressCounter("train/test","fold",numPartitions);
    for (int k=0; k<numPartitions; k++) {
      SequenceDataset trainData = (SequenceDataset)s.getTrain(k);
      SequenceDataset testData = (SequenceDataset)s.getTest(k);
      log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
               +" and test on "+testData.size());
      SequenceClassifier c = new DatasetSequenceClassifierTeacher(trainData).train(learner);
      cds[k] = new ClassifiedSequenceDataset(c, testData);
      if (trainCds!=null) {
        trainCds[k] = new ClassifiedSequenceDataset(c, trainData);
      }
      v.extend( cds[k].getClassifier(), testData, Evaluation.DEFAULT_PARTITION_ID );
      log.info("splitting with "+splitter+", stored classified dataset");
      pc.progress();
    }
    pc.finished();
  }

  /**
   * @return the evaluation accumulated over the test data of all partitions
   */
  public Evaluation getEvaluation()
  {
    return v;
  }

  /**
   * Builds a viewer with one sub-view per test partition, one sub-view per
   * training partition (only when training partitions were saved), and a
   * final sub-view for the overall evaluation.
   *
   * @return the assembled viewer, with this object set as its content
   */
  @Override
  public Viewer toGUI()
  {
    ParallelViewer main = new ParallelViewer();
    for (int i=0; i<cds.length; i++) {
      main.addSubView("Test Partition "+(i+1), partitionViewer(cds,i));
    }
    if (trainCds!=null) {
      for (int i=0; i<trainCds.length; i++) {
        main.addSubView("Train Partition "+(i+1), partitionViewer(trainCds,i));
      }
    }
    main.addSubView(
      "Overall Evaluation",
      new TransformedViewer(v.toGUI()) {
        static final long serialVersionUID=20080207L;
        @Override
        public Object transform(Object o) {
          CrossValidatedSequenceDataset cvd = (CrossValidatedSequenceDataset)o;
          return cvd.v;
        }
      });
    main.setContent(this);
    return main;
  }

  /**
   * Creates the sub-view for one classified partition.
   *
   * <p>The wrapped viewer is always built from the first test partition,
   * exactly as before this helper existed; the dataset actually supplied is
   * the one returned by <code>transform</code>.
   * NOTE(review): presumably TransformedViewer hands the transformed object
   * to the wrapped viewer - confirm against TransformedViewer.
   *
   * @param partitions the array (test or train) holding the partition to show
   * @param k index of the partition within <code>partitions</code>
   * @return a viewer whose transform yields <code>partitions[k]</code>
   */
  private TransformedViewer partitionViewer(final ClassifiedSequenceDataset[] partitions,final int k)
  {
    return new TransformedViewer(cds[0].toGUI()) {
      static final long serialVersionUID=20080207L;
      @Override
      public Object transform(Object o) {
        // The incoming content object is ignored; the partition is fixed.
        return partitions[k];
      }
    };
  }

  /**
   * Intentionally empty; retained so the class keeps its existing entry point.
   *
   * @param args ignored
   */
  public static void main(String[] args)
  {
  }
}