/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.experiments;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.DatasetLoader;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.util.BasicCommandLineProcessor;
import edu.cmu.minorthird.util.CommandLineProcessor;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ViewerFrame;
/**
 * Simple experiment on a classifier: a learner, a training set, and either a
 * {@link Splitter} (e.g. for cross-validation) or a fixed test set.
 *
 * <p>Can be run from the command line via {@link #main(String[])}, e.g.
 * <pre>
 *   -learn "new NaiveBayes()" -train sample:toy -split k10 -show eval
 * </pre>
 *
 * @author William Cohen
 */
public class Expt implements CommandLineProcessor.Configurable
{
  /** Packages imported into the BeanShell interpreter before a learner expression is evaluated. */
  private static final String[] LEARNER_IMPORTS = {
    "edu.cmu.minorthird.classify.*",
    "edu.cmu.minorthird.classify.algorithms.linear.*",
    "edu.cmu.minorthird.classify.algorithms.trees.*",
    "edu.cmu.minorthird.classify.algorithms.ranking.*",
    "edu.cmu.minorthird.classify.algorithms.knn.*",
    "edu.cmu.minorthird.classify.algorithms.svm.*",
    "edu.cmu.minorthird.classify.transform.*",
    "edu.cmu.minorthird.classify.semisupervised.*",
  };

  private Dataset trainData=null, testData=null;
  private Splitter<Example> splitter=null;
  private ClassifierLearner learner=null;
  // Raw command-line strings; copied into the Evaluation's properties by evaluation().
  // They stay null when the experiment is built with the programmatic constructors.
  private String splitterArg=null,trainArg=null,testArg=null,learnerArg=null;

  private class MyCLP extends BasicCommandLineProcessor
  {
    // NOTE(review): the option handlers below are commented out, so the options listed in
    // usage() are not handled by methods of this class. Confirm whether BasicCommandLineProcessor
    // handles them before relying on getCLP().
    /*
    public void train(String s) {
      try {
        trainData = toDataset(s);
        trainArg = s;
      } catch (IOException ex) {
        throw new IllegalArgumentException("Error loading "+s+": "+ex);
      }
    }
    public void test(String s) {
      try {
        testData = toDataset(s);
        splitter = new FixedTestSetSplitter<Example>(testData.iterator());
        testArg = s;
      } catch (IOException ex) {
        throw new IllegalArgumentException("Error loading "+s+": "+ex);
      }
    }
    public void splitter(String s) {
      splitterArg = s;
      splitter = toSplitter(s);
    }
    public void learner(String s) {
      learner = toLearner(s);
      learnerArg = s;
    }
    */
    @Override
    public void usage() {
      System.out.println("classify.Expt parameters:");
      System.out.println(" -train FILE training data is in FILE");
      System.out.println(" [-test FILE] test data is in FILE");
      System.out.println(" [-splitter SPLITTER] do cross-validation study with the SPLITTER");
      System.out.println(" [-learner LEARNER] use learner defined by bean-shell command");
      System.out.println();
    }
  }

  @Override
  public CommandLineProcessor getCLP() { return new MyCLP(); }

  /**
   * Creates an experiment that trains on {@code trainData} and evaluates on {@code testData}.
   *
   * @param learner   the learner to train
   * @param trainData training examples
   * @param testData  held-out test examples; if null, no splitter is configured
   */
  public Expt(ClassifierLearner learner,Dataset trainData,Dataset testData)
  {
    this.learner = learner;
    this.trainData = trainData;
    this.testData = testData;
    // Mirror the "-test" command-line option. evaluation() and crossValidatedDataset()
    // only consult the splitter, so the test set must be wrapped in one to be used at all.
    this.splitter = (testData==null) ? null : new FixedTestSetSplitter<Example>(testData.iterator());
  }

  /**
   * Creates an experiment that evaluates {@code learner} on {@code trainData} using {@code splitter}.
   *
   * @param learner   the learner to train
   * @param trainData the examples to split into train/test partitions
   * @param splitter  how to partition {@code trainData}
   */
  public Expt(ClassifierLearner learner,Dataset trainData,Splitter<Example> splitter)
  {
    this.learner = learner;
    this.trainData = trainData;
    this.splitter = splitter;
  }

  /** Convert a set of command-line arguments to an 'experiment'.
   * Examples:
   * <pre>
   * -learn "new NaiveBayes()" -train sample:toy -split k10           (k-fold CV)
   * -learn "new NaiveBayes()" -train sample:toy.train -test sample:toy.test
   * -learn "new AdaBoost(new DecisionTreeLearner())" -train file:foo.data -split r70
   * -learn "new AdaBoost()" -train seqfile:foo.data -split r70
   * </pre>
   * Options are matched by prefix ("-tr", "-te", "-spl", "-lea"). Other options starting
   * with "-" are skipped together with their value. At most one of -test and -split may be
   * given. If neither is given, the learner is evaluated on the training data itself.
   *
   * @throws IOException if a dataset file cannot be loaded
   * @throws IllegalArgumentException if required options are missing, an option lacks its
   *         value, or both -test and -split are given
   */
  public Expt(String[] args) throws IOException
  {
    int pos = 0;
    while (pos<args.length) {
      String opt = args[pos++];
      if (opt.startsWith("-tr")) {
        trainArg = optionValue(args,pos++,opt);
        trainData = toDataset(trainArg);
      } else if (opt.startsWith("-te")) {
        if (splitter!=null) throw new IllegalArgumentException("only one of splitter, testData allowed");
        testArg = optionValue(args,pos++,opt);
        testData = toDataset(testArg);
        splitter = new FixedTestSetSplitter<Example>(testData.iterator());
      } else if (opt.startsWith("-spl")) {
        if (splitter!=null) throw new IllegalArgumentException("only one of splitter, testData allowed");
        splitterArg = optionValue(args,pos++,opt);
        splitter = toSplitter(splitterArg);
      } else if (opt.startsWith("-lea")) {
        learnerArg = optionValue(args,pos++,opt);
        learner = toLearner(learnerArg);
      } else if (opt.startsWith("-")) {
        pos++; // unrelated option (e.g. -show/-save handled by main): skip its value
      }
    }
    if (trainData==null || learner==null)
      throw new IllegalArgumentException("learner and trainData must be specified");
    if (testData==null && splitter==null) {
      // no held-out data requested: test on the training set itself
      splitter = new FixedTestSetSplitter<Example>(trainData.iterator());
    }
  }

  /**
   * Returns the value that follows option {@code opt} at {@code args[pos]}.
   *
   * @throws IllegalArgumentException if the option is the last argument
   */
  private static String optionValue(String[] args,int pos,String opt)
  {
    if (pos>=args.length) throw new IllegalArgumentException("option "+opt+" requires a value");
    return args[pos];
  }

  /**
   * Runs the experiment via {@link Tester#evaluate} and records, as properties of the result,
   * whichever command-line arguments were supplied.
   */
  public Evaluation evaluation()
  {
    Evaluation v = Tester.evaluate(learner,trainData,splitter);
    // The arg strings are null when built programmatically, so only record those present.
    if (learnerArg!=null) v.setProperty("learner",learnerArg);
    if (trainArg!=null) v.setProperty("train",trainArg);
    if (splitterArg!=null) v.setProperty("splitter",splitterArg);
    if (testArg!=null) v.setProperty("test",testArg);
    return v;
  }

  /**
   * Builds a {@link CrossValidatedDataset} for this experiment's learner, data and splitter.
   *
   * @param saveTrain passed through to the CrossValidatedDataset constructor
   */
  public CrossValidatedDataset crossValidatedDataset(boolean saveTrain)
  {
    return new CrossValidatedDataset(learner,trainData,splitter,saveTrain);
  }

  /** Trains the learner on the full training set and returns the resulting classifier. */
  public Classifier getClassifier()
  {
    return new DatasetClassifierTeacher(trainData).train(learner);
  }

  @Override
  public String toString()
  {
    return
      "[Expt:\n learner:"+learner+"\n splitter:"+splitter+"\n train:\n"+trainData+" test:\n"+testData+"Expt]";
  }

  /** Decode splitter names.
   * <ul>
   * <li>{@code k<N>}: {@link CrossValSplitter} with N folds
   * <li>{@code r<N>}: {@link RandomSplitter} with fraction N/100
   * <li>{@code l...}: {@link LeaveOneOutSplitter}
   * </ul>
   *
   * @param splitterName the encoded splitter name
   * @param clazz        type token fixing the element type T (not otherwise used)
   * @throws IllegalArgumentException if the name is null, empty, or not recognized
   */
  public static <T> Splitter<T> toSplitter(String splitterName,Class<T> clazz)
  {
    if (splitterName==null || splitterName.length()==0)
      throw new IllegalArgumentException("illegal splitterName '"+splitterName+"'");
    char kind = splitterName.charAt(0);
    if (kind=='k') {
      int folds = StringUtil.atoi(splitterName.substring(1));
      return new CrossValSplitter<T>(folds);
    }
    if (kind=='r') {
      double pct = StringUtil.atoi(splitterName.substring(1)) / 100.0;
      return new RandomSplitter<T>(pct);
    }
    // 's<N>' (stratified cross-validation via StratifiedCrossValSplitter) is currently disabled.
    if (kind=='l') {
      return new LeaveOneOutSplitter<T>();
    }
    throw new IllegalArgumentException("illegal splitterName '"+splitterName+"'");
  }

  /** Decode a splitter name for {@link Example}s; see {@link #toSplitter(String, Class)}. */
  public static Splitter<Example> toSplitter(String splitterName){
    return toSplitter(splitterName,Example.class);
  }

  /** Decode dataset names. Allowed names are:
   *
   *<ul>
   * <li>sample:foo
   * <li>sample:foo.test
   * <li>sample:foo.train
   * <li>file:bar
   * <li>seqfile:bar
   * <li>bar (bar is a filename)
   *</ul>
   *
   * NOTE(review): names are split on ':', so a path that itself contains ':' (for example a
   * Windows drive letter) is rejected even with a "file:" prefix.
   *
   * @throws IOException if a data file cannot be loaded
   * @throws IllegalArgumentException if the name matches none of the forms above
   */
  public static Dataset toDataset(String datasetName) throws IOException
  {
    String[] words = datasetName.split("\\:");
    if (words.length==1) {
      // bar
      return DatasetLoader.loadFile(new File(words[0]));
    }
    if (words.length==2 && "file".equals(words[0])) {
      // file:bar
      return DatasetLoader.loadFile(new File(words[1]));
    }
    if (words.length==2 && "seqfile".equals(words[0])) {
      // seqfile:bar
      return DatasetLoader.loadSequence(new File(words[1]));
    }
    if ("sample".equals(words[0])) {
      String[] parts = words[1].split("\\.");
      if (parts.length==1) {
        // sample:foo
        return SampleDatasets.sampleData(parts[0],false);
      } else if ("test".equals(parts[1])) {
        // sample:foo.test
        return SampleDatasets.sampleData(parts[0],true);
      } else if ("train".equals(parts[1])) {
        // sample:foo.train
        return SampleDatasets.sampleData(parts[0],false);
      }
    }
    throw new IllegalArgumentException("illegal datasetName: "+datasetName);
  }

  /**
   * Decode learner name, which should be a legitimate java constructor,
   * e.g. <code>new NaiveBayes()</code>.
   *
   * <p>SECURITY: the string is evaluated as arbitrary BeanShell code; pass only trusted input.
   *
   * @return the learner, or null if the expression evaluates to null
   * @throws IllegalArgumentException if the expression fails to evaluate, or evaluates to
   *         something other than a ClassifierLearner
   */
  public static ClassifierLearner toLearner(String learnerName)
  {
    Object result;
    try {
      bsh.Interpreter interp = new bsh.Interpreter();
      for (String pkg : LEARNER_IMPORTS) {
        interp.eval("import "+pkg+";");
      }
      result = interp.eval(learnerName);
    } catch (bsh.EvalError e) {
      // keep the EvalError as the cause so its stack/location info is not lost
      throw new IllegalArgumentException("error parsing learnerName '"+learnerName+"':\n"+e, e);
    }
    if (result!=null && !(result instanceof ClassifierLearner))
      throw new IllegalArgumentException(
        "learnerName '"+learnerName+"' evaluates to "+result.getClass().getName()+", not a ClassifierLearner");
    return (ClassifierLearner)result;
  }

  /**
   * Command-line entry point. Accepts the options of {@link #Expt(String[])} plus:
   * <ul>
   * <li>{@code -show eval|all|all+}: open a viewer on the evaluation or cross-validated dataset
   * <li>{@code -save eval|classifier} together with {@code -file FILE}: serialize the result
   * </ul>
   * Any failure prints a stack trace followed by a usage line.
   */
  public static void main(String[] args)
  {
    try {
      Expt expt = new Expt(args);
      Serializable toSave = null;
      File saveFile = null;
      int pos = 0;
      while (pos<args.length) {
        String opt = args[pos++];
        if (opt.startsWith("-show")) {
          String what = optionValue(args,pos++,opt);
          if (what.startsWith("eval")) {
            Evaluation v = expt.evaluation();
            new ViewerFrame("Evaluation", v.toGUI());
          } else if (what.startsWith("all")) {
            boolean saveTrain = "all+".equals(what);
            CrossValidatedDataset cdv = expt.crossValidatedDataset(saveTrain);
            new ViewerFrame("CrossValidatedDataset", cdv.toGUI());
          } else {
            throw new IllegalArgumentException("can't show '"+what+"'");
          }
        } else if (opt.startsWith("-save")) {
          String what = optionValue(args,pos++,opt);
          if (what.startsWith("eval")) {
            toSave = expt.evaluation();
          } else if (what.startsWith("cla")) {
            Classifier c = expt.getClassifier();
            if (c!=null && !(c instanceof Serializable))
              throw new IllegalArgumentException("classifier "+c.getClass().getName()+" is not Serializable");
            toSave = (Serializable) c;
          } else {
            throw new IllegalArgumentException("can't save '"+what+"'");
          }
        } else if (opt.startsWith("-file")) {
          saveFile = new File(optionValue(args,pos++,opt));
        } else if (opt.startsWith("-")) {
          pos++; // option consumed by the Expt constructor: skip its value
        }
      }
      // -save and -file must be given together; check before writing anything
      if ((saveFile==null) != (toSave==null)) {
        throw new IllegalArgumentException("must specify -file FILE with -save WHAT");
      }
      if (saveFile!=null) {
        IOUtil.saveSerialized(toSave,saveFile);
      }
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println(
        "usage: -learn L -train D1 [-split S] [-test D] [-show eval|all|all+] [-save eval|classifier]");
    }
  }
}