import edu.cmu.minorthird.util.*; import edu.cmu.minorthird.text.*; import edu.cmu.minorthird.text.gui.*; import edu.cmu.minorthird.text.learn.*; import edu.cmu.minorthird.text.learn.experiments.*; import edu.cmu.minorthird.text.mixup.*; import edu.cmu.minorthird.classify.*; import edu.cmu.minorthird.classify.experiments.*; import edu.cmu.minorthird.classify.algorithms.linear.*; import edu.cmu.minorthird.classify.algorithms.trees.*; import java.util.*; import java.io.*; public class LearnImagePtrExtractor { static { Mixup.maxNumberOfMatchesPerToken = 20; } /** Heuristic used to find candidate image pointers */ static public SpanFinder candidateFinder; /** Computes features used by learner */ static public MixupProgram featureProgram; // initialize static { try { candidateFinder = new MixupFinder( new Mixup("... [L eq('(') !eq(')'){1,15}R eq(')') R] ...") ); featureProgram = new MixupProgram( new File("lib/features.mixup" ) ); } catch (Exception e) { throw new IllegalStateException("mixup or io error: "+e); } } /** Create the learner */ static private BatchFilteredFinderLearner makeAnnotatorLearner(BinaryClassifierLearner classifierLearner) { BatchFilteredFinderLearner annotatorLearner = new BatchFilteredFinderLearner( new ImgPtrFE(), classifierLearner, candidateFinder ); return annotatorLearner; } static public String predictedClassName(String className) { return "predicted"+className.substring(0,1).toUpperCase()+className.substring(1); } /** Load the initial labels */ static public MutableTextLabels loadLabels() throws IOException,Mixup.ParseException,java.text.ParseException { // load the data and labels //TextBase base = TextBaseLoader.loadDocPerLine(new File("data/captions/lines.txt"),true); TextBase base = new TextBaseLoader().load(new File("data/captions/caption-lines")); TextLabelsLoader eloader = new TextLabelsLoader(); MutableTextLabels labels = eloader.loadOps(base,new File("labels/imgptr.env")); return labels; } static public void main(String argv[]) throws IOException,Mixup.ParseException,java.text.ParseException { // load the labels and compute the features MutableTextLabels labels = loadLabels(); MixupInterpreter interp = new MixupInterpreter(featureProgram); interp.eval(labels); if (argv.length>0 && "-expt".equals(argv[0])) { String className = argv.length>=2 ? argv[1] : "regional"; String learnerName = argv.length>=3 ? argv[2] : "new AdaBoost()"; BinaryClassifierLearner learner = (BinaryClassifierLearner)Expt.toLearner(learnerName); AnnotatorLearner annnotatorLearner = makeAnnotatorLearner(learner); String predClassName = predictedClassName(className); TextLabelsExperiment expt = new TextLabelsExperiment(labels,new CrossValSplitter(10), annnotatorLearner, className, predClassName); expt.doExperiment(); TextBaseViewer.view( expt.getTestLabels() ); } else if (argv.length>0 && "-save".equals(argv[0])) { String[] classNames = new String[] { "regional","local" }; for (int i=0; i<classNames.length; i++) { System.out.println("Training classifier for "+classNames[i]+" imgptrs"); BinaryClassifierLearner learner = new AdaBoost(); BatchFilteredFinderLearner annotatorLearner = makeAnnotatorLearner(learner); new TextLabelsAnnotatorTeacher(labels,classNames[i]).train(annotatorLearner); Classifier classifier = annotatorLearner.getClassifier(); IOUtil.saveSerialized((Serializable)classifier,new File("lib/"+classNames[i]+"Filter.ser")); } } else { System.out.println("usage: -expt className learner"); System.out.println("usage: -save"); } } }