package edu.cmu.minorthird.classify.ranking;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.DatasetLoader;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.Expt;
import edu.cmu.minorthird.util.BasicCommandLineProcessor;
import edu.cmu.minorthird.util.CommandLineProcessor;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ViewerFrame;

/**
 * Command-line driver for a train/test experiment with a ranking learner.
 *
 * <p>The experiment loads a dataset from a file, partitions it with a
 * {@link Splitter}, trains a classifier with a {@link BatchRankingLearner}
 * on the training portion of each partition, and accumulates the results
 * obtained on the matching test portion into a {@link RankingEvaluation}.
 * The evaluation is printed as a table and can optionally be shown in a
 * GUI window and/or saved to a file.
 */
public class RankingExpt {

	/** Path of the dataset file; must be set before the experiment runs. */
	private String dataFileName = null;

	/** Learner trained on each fold; defaults to a {@link RankingPerceptron}. */
	private BatchRankingLearner classifierLearner = new RankingPerceptron();

	/** Strategy used to partition the data; defaults to a {@link CrossValSplitter}. */
	private Splitter<Example> splitter = new CrossValSplitter<Example>();

	/** Path the evaluation is saved to, or {@code null} to skip saving. */
	private String saveFile = null;

	/** Whether the evaluation should also be displayed in a viewer window. */
	private boolean guiFlag = false;

	/**
	 * Returns a command-line processor whose options configure this experiment.
	 *
	 * @return a new {@link MyCLP} bound to this instance
	 */
	public CommandLineProcessor getCLP() {
		return new MyCLP();
	}

	/**
	 * Command-line options for the experiment. Each public method stores its
	 * argument in the matching field of the enclosing {@link RankingExpt}.
	 * NOTE(review): the methods are presumably dispatched by name from
	 * {@link BasicCommandLineProcessor} - confirm before renaming any of them.
	 */
	public class MyCLP extends BasicCommandLineProcessor {

		/** Sets the path of the dataset file to load. */
		public void data(String s) {
			dataFileName = s;
		}

		/** Sets the splitter, as parsed by {@link Expt#toSplitter}. */
		public void splitter(String s) {
			splitter = Expt.toSplitter(s);
		}

		/** Sets the file the evaluation will be saved to. */
		public void saveAs(String s) {
			saveFile = s;
		}

		/**
		 * Sets the learner, as parsed by {@link Expt#toLearner}. The parsed
		 * learner is cast to {@link BatchRankingLearner}, so any other kind
		 * of learner fails with a {@link ClassCastException}.
		 */
		public void learner(String s) {
			classifierLearner = (BatchRankingLearner) Expt.toLearner(s);
		}

		/** Requests that the evaluation be displayed in a viewer window. */
		public void gui() {
			guiFlag = true;
		}
	}

	/**
	 * Runs the experiment: for every partition produced by the splitter, trains
	 * on that partition's training set and evaluates on its test set.
	 *
	 * @return the evaluation accumulated over all partitions
	 * @throws IOException if the dataset file cannot be read
	 * @throws NumberFormatException if the dataset loader rejects a numeric field
	 * @throws IllegalStateException if no dataset file name has been configured
	 */
	private RankingEvaluation doExpt() throws IOException, NumberFormatException {
		if (dataFileName == null) {
			// Fail with a clear message instead of a bare NullPointerException from new File(null).
			throw new IllegalStateException("no data file specified (set it with the 'data' option)");
		}

		Dataset data = DatasetLoader.loadFile(new File(dataFileName));
		System.out.println("loaded " + data.size() + " examples");

		RankingEvaluation eval = new RankingEvaluation();
		Dataset.Split split = data.split(splitter);
		int numFolds = split.getNumPartitions();
		ProgressCounter pc = new ProgressCounter("train/test", "fold", numFolds);
		for (int k = 0; k < numFolds; k++) {
			// Train only on this fold's training partition. Training on the full
			// dataset would let the test examples leak into the model and
			// invalidate the held-out evaluation.
			Dataset train = split.getTrain(k);
			Dataset test = split.getTest(k);
			DatasetClassifierTeacher teacher = new DatasetClassifierTeacher(train);
			BinaryClassifier classifier = (BinaryClassifier) teacher.train(classifierLearner);
			doTest(classifier, test, eval);
			pc.progress();
		}
		pc.finished();
		return eval;
	}

	/**
	 * Evaluates a classifier on one test set, one ranking at a time.
	 *
	 * @param classifier the trained classifier to evaluate
	 * @param test the test examples, grouped into rankings by
	 *        {@link BatchRankingLearner#splitIntoRankings}
	 * @param eval the evaluation extended with each ranking's result
	 */
	private void doTest(BinaryClassifier classifier, Dataset test, RankingEvaluation eval) {
		Map<String, List<Example>> bySubpopMap = BatchRankingLearner.splitIntoRankings(test);
		for (Map.Entry<String, List<Example>> ranking : bySubpopMap.entrySet()) {
			eval.extend(ranking.getKey(), ranking.getValue(), classifier);
		}
	}

	/**
	 * Entry point: configures an experiment from the command line, runs it,
	 * prints the evaluation table, and optionally shows and/or saves the result.
	 *
	 * @param args command-line arguments handled by {@link #getCLP()}
	 * @throws IOException if the dataset cannot be read or the result cannot be saved
	 * @throws NumberFormatException if the dataset loader rejects a numeric field
	 */
	public static void main(String[] args) throws IOException, NumberFormatException {
		RankingExpt x = new RankingExpt();
		x.getCLP().processArguments(args);
		RankingEvaluation eval = x.doExpt();
		System.out.println(eval.toTable());
		if (x.guiFlag) {
			new ViewerFrame("result", eval.toGUI());
		}
		if (x.saveFile != null) {
			IOUtil.saveSomehow(eval, new File(x.saveFile), true);
		}
	}
}