package edu.cmu.minorthird.ui;
import java.io.IOException;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.FixedTestSetSplitter;
import edu.cmu.minorthird.classify.experiments.Tester;
import edu.cmu.minorthird.ui.CommandLineUtil.ClassificationSignalParams;
import edu.cmu.minorthird.ui.CommandLineUtil.SaveParams;
import edu.cmu.minorthird.ui.CommandLineUtil.SplitterParams;
import edu.cmu.minorthird.ui.CommandLineUtil.TrainClassifierParams;
import edu.cmu.minorthird.util.CommandLineProcessor;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.JointCommandLineProcessor;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
// TODO:
// - "show labels" should use a better viewer
// - -baseType type
/**
 * Do a train/test experiment on a text classifier.
 *
 * <p>The experiment is driven by four parameter groups (save, signal, train
 * and splitter parameters) in addition to the {@code gui} and {@code base}
 * groups inherited from {@link UIMain}. {@link #doMain()} builds a dataset
 * from the base labels, evaluates the configured learner with the configured
 * splitter, optionally shows and/or saves the outcome, and finally
 * summarizes the evaluation.
 *
 * @author William Cohen
 */
public class TrainTestClassifier extends UIMain{

  protected static Logger log=Logger.getLogger(TrainTestClassifier.class);

  // data needed to train and test a classifier

  /** Where (if anywhere) the evaluation is serialized; see {@code saveAs}. */
  protected SaveParams save=new SaveParams();

  /** Selects the span property or span type that defines the class signal. */
  protected ClassificationSignalParams signal=new ClassificationSignalParams(base);

  /** Learner, feature extractor and "show data" flag used for training. */
  protected TrainClassifierParams train=new TrainClassifierParams();

  /** Splitter (or explicit test labels) used to divide data into train/test. */
  protected SplitterParams trainTest=new SplitterParams();

  /** Outcome of the last {@link #doMain()} run; {@code null} until then. */
  protected Object result=null;

  // accessors for GUI

  /** @return the parameters controlling where results are saved */
  public SaveParams getSaveParameters(){
    return save;
  }

  /** @param save the parameters controlling where results are saved */
  public void setSaveParameters(SaveParams save){
    this.save=save;
  }

  /** @return the parameters selecting the classification signal */
  public ClassificationSignalParams getSignalParameters(){
    return signal;
  }

  /** @param signal the parameters selecting the classification signal */
  public void setSignalParameters(ClassificationSignalParams signal){
    this.signal=signal;
  }

  /** @return the parameters controlling classifier training */
  public TrainClassifierParams getTrainingParameters(){
    return train;
  }

  /** @param train the parameters controlling classifier training */
  public void setTrainingParameters(TrainClassifierParams train){
    this.train=train;
  }

  /** @return the parameters controlling the train/test split */
  public SplitterParams getSplitterParameters(){
    return trainTest;
  }

  /** @param trainTest the parameters controlling the train/test split */
  public void setSplitterParameters(CommandLineUtil.SplitterParams trainTest){
    this.trainTest=trainTest;
  }

  /**
   * Builds the command-line processor for this program by joining the
   * inherited {@code gui} and {@code base} processors with the save, signal,
   * training and splitter parameter groups, in that order.
   *
   * @return a joint processor over all parameter groups
   */
  @Override
  public CommandLineProcessor getCLP(){
    return new JointCommandLineProcessor(new CommandLineProcessor[]{gui,base,
        save,signal,train,trainTest});
  }

  /**
   * @return an HTML anchor pointing at the TrainTestClassifier tutorial
   */
  public String getTrainTestClassifierHelp(){
    // NOTE(review): the string ends with </html> but has no opening <html>
    // tag; presumably the consumer prepends one -- confirm against caller.
    return "<A HREF=\"http://minorthird.sourceforge.net/tutorials/TrainTestClassifier%20Tutorial.htm\">TrainTestClassifier Tutorial</A></html>";
  }

  /**
   * Runs the train/test experiment.
   *
   * <p>After the run, {@link #getMainResult()} returns either the
   * {@link CrossValidatedDataset} (when test details were requested) or the
   * {@link Evaluation}.
   *
   * @throws IllegalArgumentException if no learner is specified, if not
   *         exactly one of {@code -spanProp}/{@code -spanType} is specified,
   *         or if the evaluation cannot be written to {@code save.saveAs}
   *         (in which case the underlying {@link IOException} is the cause)
   */
  @Override
  public void doMain(){
    // check that inputs are valid
    if(train.learner==null){
      throw new IllegalArgumentException("-learner must be specified");
    }
    if(signal.spanProp==null&&signal.spanType==null){
      throw new IllegalArgumentException(
          "one of -spanProp or -spanType must be specified");
    }
    if(signal.spanProp!=null&&signal.spanType!=null){
      throw new IllegalArgumentException(
          "only one of -spanProp or -spanType can be specified");
    }

    // construct the dataset
    Dataset d=
        CommandLineUtil.toDataset(base.labels,train.fe,signal.spanProp,
            signal.spanType,signal.candidateType);

    // show the data if necessary
    if(train.showData){
      new ViewerFrame("Dataset",d.toGUI());
    }

    // construct the splitter, if necessary: explicit test labels override
    // whatever splitter was configured, by fixing the test set
    if(trainTest.labels!=null){
      if(signal.spanPropString!=null){
        CommandLineUtil.createSpanProp(signal.spanPropString,trainTest.labels);
      }
      Dataset testData=
          CommandLineUtil.toDataset(trainTest.labels,train.fe,signal.spanProp,
              signal.spanType,signal.candidateType);
      trainTest.splitter=new FixedTestSetSplitter<Example>(testData.iterator());
    }

    // do the experiment
    final Evaluation evaluation;
    if(trainTest.showTestDetails){
      // keep the whole cross-validated dataset as the result so that the
      // per-test details remain available to the viewer
      CrossValidatedDataset cvd=
          new CrossValidatedDataset(train.learner,d,trainTest.splitter);
      evaluation=cvd.getEvaluation();
      result=cvd;
    }else{
      evaluation=Tester.evaluate(train.learner,d,trainTest.splitter);
      result=evaluation;
    }

    if(base.showResult){
      new ViewerFrame("Result",new SmartVanillaViewer(result));
    }

    // only the evaluation is saved, even when the result is the
    // cross-validated dataset
    if(save.saveAs!=null){
      try{
        IOUtil.saveSerialized(evaluation,save.saveAs);
      }catch(IOException e){
        // keep the original exception as the cause so its stack trace
        // is not lost
        throw new IllegalArgumentException("can't save to "+save.saveAs+": "+e,e);
      }
    }

    evaluation.summarize();
  }

  /**
   * @return the outcome of the last {@link #doMain()} run, or {@code null}
   *         if it has not completed an experiment yet
   */
  @Override
  public Object getMainResult(){
    return result;
  }

  /**
   * Command-line entry point; delegates to {@code callMain}.
   *
   * @param args command-line arguments
   */
  public static void main(String[] args){
    new TrainTestClassifier().callMain(args);
  }
}