package edu.cmu.minorthird.ui;
import java.io.IOException;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.text.learn.ClassifierAnnotator;
import edu.cmu.minorthird.util.CommandLineProcessor;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.JointCommandLineProcessor;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
/**
 * Train a text classifier.
 *
 * <p>Command-line entry point that builds a {@link Dataset} from labeled
 * text, trains a {@link Classifier} on it with the configured learner, wraps
 * the result in a {@link ClassifierAnnotator} and, when a save target is
 * configured, serializes that annotator.
 *
 * <p>NOTE(review): the {@code base} and {@code gui} parameter groups used
 * below are not declared in this class; presumably they are inherited from
 * {@link UIMain} - confirm there.
 *
 * @author William Cohen
 */
public class TrainClassifier extends UIMain{

  // private data needed to train a classifier

  /** Where (if anywhere) the trained annotator is saved. */
  private CommandLineUtil.SaveParams save=new CommandLineUtil.SaveParams();

  /** Which span property or span type supplies the classification signal. */
  private CommandLineUtil.ClassificationSignalParams signal=
      new CommandLineUtil.ClassificationSignalParams(base);

  /** Learner, feature extractor, output name and display options. */
  private CommandLineUtil.TrainClassifierParams train=
      new CommandLineUtil.TrainClassifierParams();

  /** The classifier produced by the most recent {@link #doMain()}; null before that. */
  private Classifier classifier=null;

  /**
   * @return the parameter group controlling where the result is saved
   */
  public CommandLineUtil.SaveParams getSaveParameters(){
    return save;
  }

  /**
   * @param p replacement for the save parameter group
   */
  public void setSaveParameters(CommandLineUtil.SaveParams p){
    save=p;
  }

  /**
   * @return the parameter group describing the classification signal
   */
  public CommandLineUtil.ClassificationSignalParams getSignalParameters(){
    return signal;
  }

  /**
   * @param p replacement for the classification-signal parameter group
   */
  public void setSignalParameters(CommandLineUtil.ClassificationSignalParams p){
    signal=p;
  }

  /**
   * @return the training-specific parameter group
   */
  public CommandLineUtil.TrainClassifierParams getAdditionalParameters(){
    return train;
  }

  /**
   * @param p replacement for the training-specific parameter group
   */
  public void setAdditionalParameters(CommandLineUtil.TrainClassifierParams p){
    train=p;
  }

  /**
   * @return an HTML fragment linking to the TrainClassifier tutorial
   */
  public String getTrainClassifierHelp(){
    return "<A HREF=\"http://minorthird.sourceforge.net/tutorials/TrainClassifier%20Tutorial.htm\">TrainClassifier Tutorial</A></html>";
  }

  /**
   * Combines every parameter group this tool uses into one processor.
   *
   * @return a joint processor over the gui, base, save, signal and train groups
   */
  @Override
  public CommandLineProcessor getCLP(){
    return new JointCommandLineProcessor(new CommandLineProcessor[]{gui,base,
        save,signal,train});
  }

  //
  // do the experiment
  //

  /**
   * Validates the options, builds the dataset, trains the classifier and
   * optionally displays and saves the result.
   *
   * @throws IllegalArgumentException if no learner is set, if not exactly one
   *         of -spanProp / -spanType is set, or if saving the annotator fails
   *         (in which case the underlying {@link IOException} is the cause)
   */
  @Override
  public void doMain(){
    // check that inputs are valid
    if(train.learner==null){
      throw new IllegalArgumentException("-learner must be specified");
    }
    if(signal.spanProp==null&&signal.spanType==null){
      throw new IllegalArgumentException(
          "one of -spanProp or -spanType must be specified");
    }
    if(signal.spanProp!=null&&signal.spanType!=null){
      throw new IllegalArgumentException(
          "only one of -spanProp or -spanType can be specified");
    }

    // construct the dataset
    Dataset d=
        CommandLineUtil.toDataset(base.labels,train.fe,signal.spanProp,
            signal.spanType,signal.candidateType);
    if(train.showData){
      System.out.println("Trying to show the Dataset");
      // The frame is constructed only for its side effect; the reference is
      // not needed afterwards.
      new ViewerFrame("Dataset",d.toGUI());
    }

    // Disabled: sequential-dataset variant of the display step above.
    /*Dataset seqDataset =
      CommandLineUtil.toSeqDataset(base.labels,train.fe,signal.spanProp,"combined");
    if (train.showData) {
      System.out.println("Trying to create Sequential Dataset");
      new ViewerFrame("SequenceDataset", seqDataset.toGUI());
    }*/

    // train the classifier
    classifier=new DatasetClassifierTeacher(d).train(train.learner);
    if(base.showResult){
      Viewer cv=new SmartVanillaViewer();
      cv.setContent(classifier);
      new ViewerFrame("Classifier",cv);
    }

    // wrap the classifier so it can be applied to text as an annotator
    String type=signal.getOutputType(train.output);
    String prop=signal.getOutputProp(train.output);
    ClassifierAnnotator ann=
        new ClassifierAnnotator(train.fe,classifier,type,prop,
            signal.candidateType);

    if(save.saveAs!=null){
      try{
        IOUtil.saveSerialized(ann,save.saveAs);
      }catch(IOException e){
        // Keep the original message text, but chain the IOException so the
        // real I/O failure and its stack trace are not lost.
        throw new IllegalArgumentException("can't save to "+save.saveAs+": "+e,e);
      }
    }
  }

  /**
   * @return the classifier trained by {@link #doMain()}, or null if it has
   *         not produced one yet
   */
  @Override
  public Object getMainResult(){
    return classifier;
  }

  /**
   * Command-line entry point; hands the arguments to {@code callMain}.
   *
   * @param args command-line arguments
   */
  public static void main(String[] args){
    new TrainClassifier().callMain(args);
  }
}