package edu.cmu.minorthird.ui; import java.io.IOException; import java.io.Serializable; import edu.cmu.minorthird.classify.multi.MultiClassifiedDataset; import edu.cmu.minorthird.classify.multi.MultiClassifier; import edu.cmu.minorthird.classify.multi.MultiDataset; import edu.cmu.minorthird.classify.multi.MultiEvaluation; import edu.cmu.minorthird.classify.transform.AbstractInstanceTransform; import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier; import edu.cmu.minorthird.text.learn.MultiClassifierAnnotator; import edu.cmu.minorthird.util.CommandLineProcessor; import edu.cmu.minorthird.util.IOUtil; import edu.cmu.minorthird.util.JointCommandLineProcessor; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.ViewerFrame; /** * Test an existing text classifier for multiple labels. * * @author Cameron Williams */ public class TestMultiClassifier extends UIMain{ // private data needed to test a classifier private CommandLineUtil.SaveParams save=new CommandLineUtil.SaveParams(); private CommandLineUtil.MultiClassificationSignalParams signal= new CommandLineUtil.MultiClassificationSignalParams(base); private CommandLineUtil.TestClassifierParams test= new CommandLineUtil.TestClassifierParams(); private Object result=null; // for gui public CommandLineUtil.SaveParams getSaveParameters(){ return save; } public void setSaveParameters(CommandLineUtil.SaveParams p){ save=p; } public CommandLineUtil.MultiClassificationSignalParams getSignalParameters(){ return signal; } public void setSignalParameters( CommandLineUtil.MultiClassificationSignalParams p){ signal=p; } public CommandLineUtil.TestClassifierParams getAdditionalParameters(){ return test; } public void setAdditionalParameters(CommandLineUtil.TestClassifierParams p){ test=p; } @Override public CommandLineProcessor getCLP(){ return new JointCommandLineProcessor(new CommandLineProcessor[]{gui,base, save,signal,test}); } // // load and test a classifier // @Override public void doMain(){ // check that inputs are valid if(test.loadFrom==null) throw new IllegalArgumentException("-loadFrom must be specified"); // load the classifier MultiClassifierAnnotator ann=null; try{ ann=(MultiClassifierAnnotator)IOUtil.loadSerialized(test.loadFrom); }catch(IOException ex){ throw new IllegalArgumentException("can't load annotator from "+ test.loadFrom+": "+ex); } // do the testing and show the result MultiDataset d= CommandLineUtil.toMultiDataset(base.labels,ann.getFE(), signal.multiSpanProp); MultiClassifier multiClassifier=ann.getMultiClassifier(); if(signal.cross){ // d=d.annotateData(multiClassifier); if(multiClassifier instanceof TransformingMultiClassifier){ AbstractInstanceTransform transformer= ((TransformingMultiClassifier)multiClassifier).getTransform(); d=transformer.transform(d); }else{ throw new IllegalArgumentException( "Must be a TransformingMultiClassifier to use cross dimensions"); } } MultiEvaluation evaluation=null; if(test.showData) new ViewerFrame("Dataset",d.toGUI()); if(test.showClassifier) new ViewerFrame("Classifier",new SmartVanillaViewer(multiClassifier)); evaluation=new MultiEvaluation(d.getMultiSchema()); evaluation.extend(multiClassifier,d); if(test.showTestDetails){ result=new MultiClassifiedDataset(multiClassifier,d); }else{ result=evaluation; } if(base.showResult){ new ViewerFrame("Result",new SmartVanillaViewer(result)); } if(save.saveAs!=null){ try{ IOUtil.saveSerialized((Serializable)evaluation,save.saveAs); }catch(IOException e){ throw new IllegalArgumentException("can't save to "+save.saveAs+": "+e); } } evaluation.summarize(); } @Override public Object getMainResult(){ return result; } public static void main(String args[]){ new TestMultiClassifier().callMain(args); } }