package edu.cmu.minorthird.ui; import java.io.IOException; import edu.cmu.minorthird.classify.multi.MultiClassifier; import edu.cmu.minorthird.classify.multi.MultiDataset; import edu.cmu.minorthird.classify.multi.MultiDatasetClassifierTeacher; import edu.cmu.minorthird.classify.transform.AbstractInstanceTransform; import edu.cmu.minorthird.classify.transform.PredictedClassTransform; import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier; import edu.cmu.minorthird.text.learn.MultiClassifierAnnotator; import edu.cmu.minorthird.util.CommandLineProcessor; import edu.cmu.minorthird.util.IOUtil; import edu.cmu.minorthird.util.JointCommandLineProcessor; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; /** * Train a text classifier. * * @author William Cohen */ public class TrainMultiClassifier extends UIMain{ // private data needed to train a classifier private CommandLineUtil.SaveParams save=new CommandLineUtil.SaveParams(); private CommandLineUtil.MultiClassificationSignalParams signal= new CommandLineUtil.MultiClassificationSignalParams(base); private CommandLineUtil.TrainClassifierParams train= new CommandLineUtil.TrainClassifierParams(); private MultiClassifier classifier=null; public CommandLineUtil.SaveParams getSaveParameters(){ return save; } public void setSaveParameters(CommandLineUtil.SaveParams p){ save=p; } public CommandLineUtil.MultiClassificationSignalParams getSignalParameters(){ return signal; } public void setSignalParameters( CommandLineUtil.MultiClassificationSignalParams p){ signal=p; } public CommandLineUtil.TrainClassifierParams getAdditionalParameters(){ return train; } public void setAdditionalParameters(CommandLineUtil.TrainClassifierParams p){ train=p; } @Override public CommandLineProcessor getCLP(){ return new JointCommandLineProcessor(new CommandLineProcessor[]{gui,base, save,signal,train}); } // // do the experiment // @Override public void doMain(){ // check that inputs are valid if(train.learner==null) throw new IllegalArgumentException("-learner must be specified"); if(signal.multiSpanProp==null) throw new IllegalArgumentException("-multiSpanProp must be specified"); // construct the dataset MultiDataset d= CommandLineUtil.toMultiDataset(base.labels,train.fe, signal.multiSpanProp); if(signal.cross) d=d.annotateData(); if(train.showData){ System.out.println("Trying to show the Dataset"); new ViewerFrame("Dataset",d.toGUI()); } // train the classifier classifier=new MultiDatasetClassifierTeacher(d).train(train.learner); // create a transforming multiClassifier if cross if(signal.cross){ AbstractInstanceTransform transformer= new PredictedClassTransform(classifier); classifier=new TransformingMultiClassifier(classifier,transformer); } if(base.showResult){ Viewer cv=new SmartVanillaViewer(); if(classifier instanceof TransformingMultiClassifier) cv.setContent(classifier); else cv.setContent(classifier); new ViewerFrame("Classifier",cv); } MultiClassifierAnnotator ann= new MultiClassifierAnnotator(train.fe,classifier,signal.multiSpanProp); if(save.saveAs!=null){ try{ IOUtil.saveSerialized(ann,save.saveAs); }catch(IOException e){ throw new IllegalArgumentException("can't save to "+save.saveAs+": "+e); } } } @Override public Object getMainResult(){ return classifier; } public static void main(String args[]){ new TrainMultiClassifier().callMain(args); } }