package edu.cmu.minorthird.ui; import java.io.IOException; import java.io.Serializable; import edu.cmu.minorthird.classify.experiments.FixedTestSetSplitter; import edu.cmu.minorthird.classify.experiments.Tester; import edu.cmu.minorthird.classify.multi.MultiCrossValidatedDataset; import edu.cmu.minorthird.classify.multi.MultiDataset; import edu.cmu.minorthird.classify.multi.MultiEvaluation; import edu.cmu.minorthird.classify.multi.MultiExample; import edu.cmu.minorthird.text.MonotonicTextLabels; import edu.cmu.minorthird.util.CommandLineProcessor; import edu.cmu.minorthird.util.IOUtil; import edu.cmu.minorthird.util.JointCommandLineProcessor; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.ViewerFrame; // to do: // show labels should be a better viewer // -baseType type /** * Do a train/test experiment on a text classifier for data with multiple labels. * * @author Cameron Williams */ public class TrainTestMultiClassifier extends UIMain{ // private data needed to train a classifier protected CommandLineUtil.SaveParams save=new CommandLineUtil.SaveParams(); protected CommandLineUtil.MultiClassificationSignalParams signal= new CommandLineUtil.MultiClassificationSignalParams(base); protected CommandLineUtil.TrainClassifierParams train= new CommandLineUtil.TrainClassifierParams(); protected CommandLineUtil.SplitterParams trainTest= new CommandLineUtil.SplitterParams(); protected Object result=null; // for GUI public CommandLineUtil.SaveParams getSaveParameters(){ return save; } public void setSaveParameters(CommandLineUtil.SaveParams base){ this.save=base; } public CommandLineUtil.MultiClassificationSignalParams getSignalParameters(){ return signal; } public void setSignalParameters( CommandLineUtil.MultiClassificationSignalParams base){ this.signal=base; } public CommandLineUtil.TrainClassifierParams getTrainingParameters(){ return train; } public void setTrainingParameters(CommandLineUtil.TrainClassifierParams train){ this.train=train; } public CommandLineUtil.SplitterParams getSplitterParameters(){ return trainTest; } public void setSplitterParameters(CommandLineUtil.SplitterParams trainTest){ this.trainTest=trainTest; } @Override public CommandLineProcessor getCLP(){ return new JointCommandLineProcessor(new CommandLineProcessor[]{gui,base, save,signal,train,trainTest}); } // // do the experiment // @Override public void doMain(){ // check that inputs are valid if(train.learner==null) throw new IllegalArgumentException("-learner must be specified"); if(signal.multiSpanProp==null) throw new IllegalArgumentException("-multiSpanProp must be specified"); // construct the dataset MultiDataset d= CommandLineUtil.toMultiDataset(base.labels,train.fe, signal.multiSpanProp); // show the data if necessary if(train.showData) new ViewerFrame("Dataset",d.toGUI()); // construct the splitter, if necessary if(trainTest.labels!=null){ MultiDataset testData= CommandLineUtil.toMultiDataset(trainTest.labels, train.fe,signal.multiSpanProp); trainTest.splitter=new FixedTestSetSplitter<MultiExample>(testData.multiIterator()); } // do the experiment MultiCrossValidatedDataset cvd=null; MultiEvaluation evaluation=null; if(trainTest.showTestDetails){ cvd= new MultiCrossValidatedDataset(train.learner,d,trainTest.splitter, false,signal.cross); evaluation=cvd.getEvaluation(); result=cvd; }else{ cvd=null; evaluation= Tester.multiEvaluate(train.learner,d,trainTest.splitter,signal.cross); result=evaluation; } if(base.showResult){ new ViewerFrame("Result",new SmartVanillaViewer(result)); } if(save.saveAs!=null){ try{ IOUtil.saveSerialized((Serializable)evaluation,save.saveAs); }catch(IOException e){ throw new IllegalArgumentException("can't save to "+save.saveAs+": "+e); } } evaluation.summarize(); } @Override public Object getMainResult(){ return result; } public static void main(String args[]){ new TrainTestMultiClassifier().callMain(args); } }