/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify; import java.awt.GridBagConstraints; import java.awt.GridBagLayout; import java.awt.event.ActionEvent; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; import java.util.Iterator; import javax.swing.AbstractAction; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JProgressBar; import javax.swing.border.TitledBorder; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner; import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner; import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes; import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron; import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner; import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost; import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner; import edu.cmu.minorthird.classify.experiments.ClassifiedDataset; import edu.cmu.minorthird.classify.experiments.CrossValSplitter; import edu.cmu.minorthird.classify.experiments.Evaluation; import edu.cmu.minorthird.classify.experiments.RandomSplitter; import edu.cmu.minorthird.classify.experiments.StratifiedCrossValSplitter; import edu.cmu.minorthird.classify.multi.InstanceFromPrediction; import edu.cmu.minorthird.classify.multi.MultiClassLabel; import edu.cmu.minorthird.classify.multi.MultiClassifiedDataset; import edu.cmu.minorthird.classify.multi.MultiClassifier; import edu.cmu.minorthird.classify.multi.MultiDataset; import edu.cmu.minorthird.classify.multi.MultiEvaluation; import edu.cmu.minorthird.classify.multi.MultiExample; import edu.cmu.minorthird.classify.sequential.ClassifiedSequenceDataset; import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner; import edu.cmu.minorthird.classify.sequential.GenericCollinsLearner; import 
edu.cmu.minorthird.classify.sequential.SequenceClassifier; import edu.cmu.minorthird.classify.sequential.SequenceDataset; import edu.cmu.minorthird.classify.transform.FrequencyBasedTransformLearner; import edu.cmu.minorthird.classify.transform.InfoGainTransformLearner2; import edu.cmu.minorthird.classify.transform.T1InstanceTransformLearner; import edu.cmu.minorthird.classify.transform.TFIDFTransformLearner; import edu.cmu.minorthird.classify.transform.TransformingBatchLearner; import edu.cmu.minorthird.util.BasicCommandLineProcessor; import edu.cmu.minorthird.util.CommandLineProcessor; import edu.cmu.minorthird.util.IOUtil; import edu.cmu.minorthird.util.JointCommandLineProcessor; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.StringUtil; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.Console; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.TypeSelector; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; /** * Main UI program for the 'classify' package. 
*
 * <p>This driver loads a previously trained, serialized classifier and evaluates it on a test
 * dataset. It runs either directly from the command line or through a Swing front end when the
 * {@code -gui} option is given.
 *
 * @author William Cohen
 */
public class Test {

  private static final Logger log = Logger.getLogger(Test.class);

  /** Types offered by the GUI's type selector when editing parameters. */
  private static final Class<?>[] SELECTABLE_TYPES = new Class[] {
      DataClassificationTask.class,
      ClassifyCommandLineUtil.MultiTrainParams.class,
      ClassifyCommandLineUtil.SeqTrainParams.class,
      ClassifyCommandLineUtil.SimpleTestParams.class,
      ClassifyCommandLineUtil.MultiTestParams.class,
      ClassifyCommandLineUtil.SeqTestParams.class,
      ClassifyCommandLineUtil.Learner.SequentialLearner.class,
      ClassifyCommandLineUtil.Learner.ClassifierLearner.class,
      KnnLearner.class,
      NaiveBayes.class,
      VotedPerceptron.class,
      SVMLearner.class,
      DecisionTreeLearner.class,
      AdaBoost.class,
      BatchVersion.class,
      TransformingBatchLearner.class,
      MaxEntLearner.class,
      // transformations
      FrequencyBasedTransformLearner.class,
      InfoGainTransformLearner2.class,
      T1InstanceTransformLearner.class,
      TFIDFTransformLearner.class,
      // sequential learners
      CollinsPerceptronLearner.class,
      GenericCollinsLearner.class,
      // splitters
      CrossValSplitter.class,
      RandomSplitter.class,
      StratifiedCrossValSplitter.class,
  };

  /**
   * Command-line / GUI task that loads a serialized classifier (simple, multi-label or
   * sequential, chosen by {@code testParams.typeString}) and evaluates it on test data.
   */
  public static class DataClassificationTask
      implements CommandLineProcessor.Configurable, Console.Task {

    private ClassifyCommandLineUtil.TestParams testParams =
        new ClassifyCommandLineUtil.TestParams();

    /** Result shown by the "View results" button; set by {@link #doMain()}. */
    public Object resultToShow;

    /** Set by the {@code -gui} option; selects the Swing front end in {@link #callMain}. */
    public boolean useGUI;

    /** Task handed to the GUI console; set to {@code this} when the GUI is built. */
    public Console.Task main;

    public ClassifyCommandLineUtil.TestParams getTestParameters() {
      return testParams;
    }

    public void setTestParameters(ClassifyCommandLineUtil.TestParams p) {
      testParams = p;
    }

    /** HTML help text describing the three supported experiment types. */
    public String getTestParamsHelp() {
      return "Define what type of experiment you would like to run: <br>"
          + "Simple - Standard classify experiment <br> "
          + "Multi - Classify Experiment with Multiple labels per example <br>"
          + "Seq - Classify experiment with a Sequential Dataset, where each example has a history, <br> "
          + " and uses a Sequential Learner";
    }

    /** Presentation options; {@link #usage()} documents the {@code -gui} flag. */
    protected class GUIParams extends BasicCommandLineProcessor {
      /**
       * Enables the GUI and replaces the test parameters with the globally registered
       * {@code TestParams.type} if one exists, else with fresh simple-test parameters.
       */
      public void gui() {
        useGUI = true;
        if (ClassifyCommandLineUtil.TestParams.type != null) {
          testParams = ClassifyCommandLineUtil.TestParams.type;
        } else {
          testParams = new ClassifyCommandLineUtil.SimpleTestParams();
        }
      }

      @Override
      public void usage() {
        System.out.println("presentation parameters:");
        System.out.println(" -gui use graphic interface to set parameters");
        System.out.println();
      }
    }

    /** Returns the file name the test data was read from, or {@code null} if none was given. */
    public String getDatasetFilename() {
      return testParams.testDataFilename;
    }

    /** Builds a processor that handles the {@code -gui} flag plus all test parameters. */
    @Override
    public CommandLineProcessor getCLP() {
      return new JointCommandLineProcessor(
          new CommandLineProcessor[] {new GUIParams(), testParams});
    }

    /** Returns whether a test dataset file name has been specified. */
    @Override
    public boolean getLabels() {
      return getDatasetFilename() != null;
    }

    /**
     * Returns a new dataset in which every instance of {@code md} is wrapped in an
     * {@link InstanceFromPrediction} carrying the classifier's best predicted class name.
     * The multi-labels and weights of the original examples are kept.
     *
     * @param multiClassifier classifier used to predict each instance
     * @param md              dataset to annotate (not modified)
     * @return the annotated copy
     */
    public MultiDataset annotateData(MultiClassifier multiClassifier, MultiDataset md) {
      MultiDataset annotatedDataset = new MultiDataset();
      for (Iterator<MultiExample> i = md.multiIterator(); i.hasNext();) {
        MultiExample ex = i.next();
        Instance instance = ex.asInstance();
        MultiClassLabel predicted = multiClassifier.multiLabelClassification(instance);
        Instance annotatedInstance =
            new InstanceFromPrediction(instance, predicted.bestClassName());
        annotatedDataset.addMulti(
            new MultiExample(annotatedInstance, ex.getMultiLabel(), ex.getWeight()));
      }
      return annotatedDataset;
    }

    /**
     * Main action. Validates the parameters, loads the classifier from
     * {@code testParams.loadFrom} and evaluates it on {@code testParams.testData}. Depending on
     * the flags, it also shows or saves the result.
     */
    @Override
    public void doMain() {
      if (testParams.testData == null) {
        System.out.println("The testing data needs to be specified with the -test option.");
        return;
      }
      // Constant-first comparisons: typeString may be null, which means "simple".
      boolean isSeq = "seq".equals(testParams.typeString);
      boolean isMulti = "multi".equals(testParams.typeString);
      if (isSeq && !(testParams.testData instanceof SequenceDataset)) {
        System.out.println("The test data should be a sequence dataset");
        return;
      }
      if (testParams.showData) {
        new ViewerFrame("Test data", testParams.testData.toGUI());
      }
      if (testParams.loadFrom == null) {
        System.out.println(
            "The classifier to test needs to be specified with -classifierFile option.");
        return;
      }
      try {
        Object c = IOUtil.loadSerialized(testParams.loadFrom);
        if (isSeq) {
          Evaluation e = new Evaluation(testParams.testData.getSchema());
          e.extend((SequenceClassifier) c, (SequenceDataset) testParams.testData);
          e.summarize();
          testParams.resultToShow = testParams.resultToSave = e;
        } else if (isMulti) {
          MultiDataset md = (MultiDataset) testParams.testData;
          MultiEvaluation e = new MultiEvaluation(md.getMultiSchema());
          if (testParams.crossDim) {
            md = annotateData((MultiClassifier) c, md);
            new ViewerFrame("Annotated data", md.toGUI());
          }
          e.extend((MultiClassifier) c, md);
          e.summarize();
          testParams.resultToShow = testParams.resultToSave = e;
        } else {
          Evaluation e = new Evaluation(testParams.testData.getSchema());
          e.extend((Classifier) c, testParams.testData, 0);
          e.summarize();
          testParams.resultToShow = testParams.resultToSave = e;
        }
        if (testParams.showTestDetails) {
          // Use the same type flags as the evaluation above, so the casts of c agree with
          // how it was used there.
          if (isSeq) {
            testParams.resultToShow = new ClassifiedSequenceDataset(
                (SequenceClassifier) c, (SequenceDataset) testParams.testData);
          } else if (isMulti) {
            testParams.resultToShow = new MultiClassifiedDataset(
                (MultiClassifier) c, (MultiDataset) testParams.testData);
          } else {
            testParams.resultToShow = new ClassifiedDataset((Classifier) c, testParams.testData);
          }
        }
        resultToShow = testParams.resultToShow;
      } catch (IOException ex) {
        log.error("Can't load classifier from " + testParams.loadFromFilename + ": " + ex, ex);
        return;
      }
      if (testParams.showResult) {
        new ViewerFrame("Result", new SmartVanillaViewer(testParams.resultToShow));
      }
      if (testParams.saveAs != null) {
        if (IOUtil.saveSomehow(testParams.resultToSave, testParams.saveAs)) {
          log.info("Result saved in " + testParams.saveAs);
        } else {
          log.error("Can't save " + testParams.resultToSave.getClass() + " to "
              + testParams.saveAs);
        }
      }
    }

    @Override
    public Object getMainResult() {
      return resultToShow;
    }

    /**
     * Parses {@code args}, then either runs {@link #doMain()} directly or, if {@code -gui} was
     * given, opens a window for editing parameters and running the task.
     */
    public void callMain(final String[] args) {
      try {
        getCLP().processArguments(args);
        if (!useGUI) {
          doMain();
          return;
        }
        main = this;
        final Viewer v = new ComponentViewer() {
          static final long serialVersionUID = 20071015;

          @Override
          public JComponent componentFor(Object o) {
            return buildTaskPanel(o, args);
          }
        };
        v.setContent(this);
        new ViewerFrame(getClass().getName(), v);
      } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Use option -help for help");
      }
    }

    /**
     * Builds the GUI panel. It contains a parameter editor, execution controls, an output
     * console and three progress bars.
     *
     * @param content object handed to {@code componentFor}, edited by the type selector
     * @param args    original command-line arguments, shown in the panel's border
     */
    private JComponent buildTaskPanel(Object content, String[] args) {
      Viewer ts = new TypeSelector(SELECTABLE_TYPES, "selectableTypes.txt",
          DataClassificationTask.class);
      ts.setContent(content);

      JPanel panel = new JPanel();
      panel.setBorder(new TitledBorder(StringUtil.toString(args, "Command line: ", "", " ")));
      panel.setLayout(new GridBagLayout());
      GridBagConstraints gbc;

      // row 0: parameter modification
      JPanel subpanel1 = new JPanel();
      subpanel1.setBorder(new TitledBorder("Parameter modification"));
      subpanel1.add(ts);
      gbc = Viewer.fillerGBC();
      gbc.weighty = 0;
      panel.add(subpanel1, gbc);

      // row 1: execution controls
      JPanel subpanel2 = new JPanel();
      subpanel2.setBorder(new TitledBorder("Execution controls"));

      // disabled initially; passed to the Console below
      final JButton viewButton = new JButton(new AbstractAction("View results") {
        static final long serialVersionUID = 20071015;

        @Override
        public void actionPerformed(ActionEvent event) {
          Viewer rv = new SmartVanillaViewer();
          rv.setContent(getMainResult());
          new ViewerFrame("Result", rv);
        }
      });
      viewButton.setEnabled(false);

      JPanel errorPanel = new JPanel();
      errorPanel.setBorder(new TitledBorder("Error messages and output"));
      final Console console = new Console(main, getDatasetFilename() != null, viewButton);
      errorPanel.add(console.getMainComponent());

      JButton goButton = new JButton(new AbstractAction("Start task") {
        static final long serialVersionUID = 20071015;

        @Override
        public void actionPerformed(ActionEvent event) {
          console.start();
        }
      });

      JButton showLabelsButton = new JButton(new AbstractAction("Show test data") {
        static final long serialVersionUID = 20071015;

        @Override
        public void actionPerformed(ActionEvent ev) {
          new ViewerFrame("Labeled TextBase", new SmartVanillaViewer(testParams.testData));
        }
      });

      JButton clearButton = new JButton(new AbstractAction("Clear window") {
        static final long serialVersionUID = 20071015;

        @Override
        public void actionPerformed(ActionEvent ev) {
          console.clear();
        }
      });

      JButton helpParamsButton = new JButton(new AbstractAction("Parameters") {
        static final long serialVersionUID = 20071015;

        @Override
        public void actionPerformed(ActionEvent ev) {
          // usage() prints to System.out, so capture it temporarily and copy it to the console.
          PrintStream oldSystemOut = System.out;
          ByteArrayOutputStream outBuffer = new ByteArrayOutputStream();
          System.setOut(new PrintStream(outBuffer));
          try {
            getCLP().usage();
            console.append(outBuffer.toString());
          } finally {
            // always restore stdout, even if usage() or append() throws
            System.setOut(oldSystemOut);
          }
        }
      });

      subpanel2.add(goButton);
      subpanel2.add(viewButton);
      subpanel2.add(showLabelsButton);
      subpanel2.add(clearButton);
      subpanel2.add(new JLabel("Help:"));
      subpanel2.add(helpParamsButton);
      gbc = Viewer.fillerGBC();
      gbc.weighty = 0;
      gbc.gridy = 1;
      panel.add(subpanel2, gbc);

      // row 2: output console, the only row that grows vertically
      gbc = Viewer.fillerGBC();
      gbc.weighty = 1;
      gbc.gridy = 2;
      panel.add(errorPanel, gbc);

      // rows 3-5: progress bars registered with ProgressCounter
      JProgressBar[] progressBars =
          new JProgressBar[] {new JProgressBar(), new JProgressBar(), new JProgressBar()};
      ProgressCounter.setGraphicContext(progressBars);
      for (int i = 0; i < progressBars.length; i++) {
        gbc = Viewer.fillerGBC();
        gbc.weighty = 0;
        gbc.gridy = 3 + i;
        panel.add(progressBars[i], gbc);
      }
      return panel;
    }
  }

  public static void main(String[] args) {
    new DataClassificationTask().callMain(args);
  }
}