package edu.cmu.minorthird.text.learn.experiments; import java.io.File; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.TreeSet; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner; import edu.cmu.minorthird.classify.Splitter; import edu.cmu.minorthird.classify.experiments.Expt; import edu.cmu.minorthird.classify.experiments.FixedTestSetSplitter; import edu.cmu.minorthird.classify.experiments.RandomSplitter; import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner; import edu.cmu.minorthird.classify.sequential.GenericCollinsLearner; import edu.cmu.minorthird.text.Annotator; import edu.cmu.minorthird.text.FancyLoader; import edu.cmu.minorthird.text.MonotonicTextLabels; import edu.cmu.minorthird.text.NestedTextLabels; import edu.cmu.minorthird.text.Span; import edu.cmu.minorthird.text.SpanDifference; import edu.cmu.minorthird.text.TextLabels; import edu.cmu.minorthird.text.TextLabelsLoader; import edu.cmu.minorthird.text.gui.TextBaseViewer; import edu.cmu.minorthird.text.learn.AnnotatorLearner; import edu.cmu.minorthird.text.learn.AnnotatorTeacher; import edu.cmu.minorthird.text.learn.Extraction2TaggingReduction; import edu.cmu.minorthird.text.learn.InsideOutsideReduction; import edu.cmu.minorthird.text.learn.SampleFE; import edu.cmu.minorthird.text.learn.SequenceAnnotatorLearner; import edu.cmu.minorthird.text.learn.TextLabelsAnnotatorTeacher; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.StringUtil; import edu.cmu.minorthird.util.gui.ParallelViewer; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.TransformedViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; import edu.cmu.minorthird.util.gui.Visible; /** * Run an annotation-learning experiment based on pre-labeled text. * * @author William Cohen */ public class TextLabelsExperiment implements Visible{ private SampleFE.ExtractionFE fe=new SampleFE.ExtractionFE(); private Extraction2TaggingReduction reduction=new InsideOutsideReduction(); private int classWindowSize=3; private TextLabels labels; private Splitter<Span> splitter; private AnnotatorLearner learner; private String inputType,inputProp; private TextLabels testLabelsUsedInSplitter; private MonotonicTextLabels fullTestLabels; private MonotonicTextLabels[] testLabels; private String outputLabel; private Annotator[] annotators; private static Logger log=Logger.getLogger(TextLabelsExperiment.class); private ExtractionEvaluation extractionEval=new ExtractionEvaluation(); /** * @param labels * The labels and base to be annotated in the example These are the * training examples * @param splitter * splitter for the documents in the labels to create test vs. train * @param learnerName * AnnotatorLearner algorithm object to use * @param spanType * spanType in the TextLabels to use as training data. (I.e., the * spanType to learn. * @param spanProp * span property in the TextLabels to use as training data. * @param outputLabel * the spanType that will be assigned to spans predicted to be of * type inputLabel by the learner (I.e., the output type associated * with the learned annotator.) */ public TextLabelsExperiment(TextLabels labels,Splitter<Span> splitter, String learnerName,String spanType,String spanProp,String outputLabel){ this(labels,splitter,learnerName,spanType,spanProp,outputLabel,null); } public TextLabelsExperiment(TextLabels labels,Splitter<Span> splitter, String learnerName,String spanType,String spanProp,String outputLabel, Extraction2TaggingReduction reduce){ if(reduce!=null) this.reduction=reduce; this.labels=labels; this.splitter=splitter; this.inputType=spanType; this.inputProp=spanProp; this.outputLabel=outputLabel; this.learner=toAnnotatorLearner(learnerName); learner.setAnnotationType(outputLabel); } public TextLabelsExperiment(TextLabels labels,Splitter<Span> splitter, AnnotatorLearner learner,String inputType,String outputLabel){ this(labels,splitter,null,learner,inputType,null,outputLabel); } /** * @param labels * TextLabels to train on * @param splitter * how to partition labels into train/test * @param testLabels * if splitter is a FixedTestSetSplitter, these are the labels for * the test cases. Otherwise the labels for the test cases are given * in the "labels" input. * @param learner * the learner to user * @param inputType * the spanType to learn to extract (if non-null) * @param inputProp * the spanProp to learn to extract and label (if non-null) * @param outputLabel * the spanType/spanProp used for predictions */ public TextLabelsExperiment(TextLabels labels,Splitter<Span> splitter, TextLabels testLabels,AnnotatorLearner learner,String inputType, String inputProp,String outputLabel){ this.labels=labels; this.splitter=splitter; this.testLabelsUsedInSplitter=testLabels; this.inputType=inputType; this.inputProp=inputProp; this.outputLabel=outputLabel; this.learner=learner; learner.setAnnotationType(outputLabel); } public SampleFE.ExtractionFE getFE(){ return fe; } public void doExperiment(){ splitter.split(labels.getTextBase().documentSpanIterator()); annotators=new Annotator[splitter.getNumPartitions()]; Set<Span> allTestDocuments=new TreeSet<Span>(); for(int i=0;i<splitter.getNumPartitions();i++){ for(Iterator<Span> j=splitter.getTest(i);j.hasNext();){ // System.out.println("adding test case to allTestDocuments"); allTestDocuments.add(j.next()); } } // Progress counter ProgressCounter progressCounter= new ProgressCounter("train/test experiment","fold",splitter .getNumPartitions()); // figure out what the test set should be try{ // for most splitters, the test set will be a subset of the original // TextBase SubTextBase fullTestBase= new SubTextBase(labels.getTextBase(),allTestDocuments.iterator()); fullTestLabels= new NestedTextLabels(new SubTextLabels(fullTestBase,labels)); testLabels=new MonotonicTextLabels[splitter.getNumPartitions()]; }catch(SubTextBase.UnknownDocumentException ex){ // the other supported case is a fixed test set if(testLabelsUsedInSplitter==null) throw new IllegalArgumentException("exception: "+ex); if(!(splitter instanceof FixedTestSetSplitter)) throw new IllegalArgumentException("illegal splitter "+splitter); fullTestLabels=new NestedTextLabels(testLabelsUsedInSplitter); testLabels=new MonotonicTextLabels[1]; testLabels[0]=fullTestLabels; } for(int i=0;i<splitter.getNumPartitions();i++){ log.info("For partition "+(i+1)+" of "+splitter.getNumPartitions()); log.info("Creating teacher and train partition..."); SubTextLabels trainLabels=null; try{ SubTextBase trainBase= new SubTextBase(labels.getTextBase(),splitter.getTrain(i)); trainLabels=new SubTextLabels(trainBase,labels); }catch(SubTextBase.UnknownDocumentException ex){ throw new IllegalStateException("error building trainBase "+i+": "+ex); } AnnotatorTeacher teacher= new TextLabelsAnnotatorTeacher(trainLabels,inputType,inputProp); log.info("Training annotator: inputType="+inputType+" inputProp="+inputProp); annotators[i]=teacher.train(learner); // log.info("annotators["+i+"]="+annotators[i]); log.info("Creating test partition..."); try{ SubTextBase testBase= new SubTextBase(labels.getTextBase(),splitter.getTest(i)); testLabels[i]=new MonotonicSubTextLabels(testBase,fullTestLabels); }catch(SubTextBase.UnknownDocumentException ex){ // do nothing since testLabels[i] is already set } log.info("Labeling test partition, size="+testLabels[i].getTextBase().size()); annotators[i].annotate(testLabels[i]); log.info("Evaluating test partition..."); measurePrecisionRecall("Test partition "+(i+1),testLabels[i],false); // step progress counter progressCounter.progress(); } measurePrecisionRecall("Overall performance",fullTestLabels,true); // end progress counter progressCounter.finished(); } public ExtractionEvaluation getEvaluation(){ return extractionEval; } @Override public Viewer toGUI(){ ParallelViewer v=new ParallelViewer(); for(int i=0;i<annotators.length;i++){ final int index=i; v.addSubView("Annotator "+(i+1),new TransformedViewer( new SmartVanillaViewer()){ static final long serialVersionUID=20080306L; @Override public Object transform(Object o){ return annotators[index]; } }); v.addSubView("Test set "+(i+1),new TransformedViewer( new SmartVanillaViewer()){ static final long serialVersionUID=20080306L; @Override public Object transform(Object o){ return testLabels[index]; } }); } v.addSubView("Full test set", new TransformedViewer(new SmartVanillaViewer()){ static final long serialVersionUID=20080306L; @Override public Object transform(Object o){ return fullTestLabels; } }); v.addSubView("Evaluation",new TransformedViewer(new SmartVanillaViewer()){ static final long serialVersionUID=20080306L; @Override public Object transform(Object o){ return extractionEval; } }); v.setContent(this); return v; } // private void measurePrecisionRecall(String tag,TextLabels labels){ // measurePrecisionRecall(tag,labels,false); // } private void measurePrecisionRecall(String tag,TextLabels labels, boolean isOverallMeasure){ if(inputType!=null){ // only need one span difference here SpanDifference sd= new SpanDifference(labels.instanceIterator(outputLabel),labels .instanceIterator(inputType),labels.closureIterator(inputType)); System.out.println(tag+":"); System.out.println(sd.toSummary()); extractionEval.extend(tag,sd,isOverallMeasure); }else{ // will need one span difference for each possible property value Set<String> propValues=new HashSet<String>(); for(Iterator<Span> i=labels.getSpansWithProperty(inputProp);i.hasNext();){ Span s=i.next(); propValues.add(labels.getProperty(s,inputProp)); } SpanDifference[] sd=new SpanDifference[propValues.size()]; int k=0; for(Iterator<String> i=propValues.iterator();i.hasNext();k++){ String val=i.next(); sd[k]= new SpanDifference(propertyIterator(labels,outputLabel,val), propertyIterator(labels,inputProp,val),labels.getTextBase() .documentSpanIterator()); String tag1=tag+" for "+inputProp+":"+val; System.out.println(tag1+":"); System.out.println(sd[k].toSummary()); extractionEval.extend(tag1,sd[k],false); } SpanDifference sdAll=new SpanDifference(sd); String tag1=tag+" (micro-averaged) for "+inputProp; System.out.println(tag1+":"); System.out.println(sdAll.toSummary()); extractionEval.extend(tag1,sdAll,isOverallMeasure); } if(isOverallMeasure) extractionEval.measureTotalSize(labels.getTextBase()); } private Iterator<Span> propertyIterator(TextLabels labels,String prop, String value){ List<Span> accum=new ArrayList<Span>(); for(Iterator<Span> i=labels.getSpansWithProperty(prop);i.hasNext();){ Span s=i.next(); if(value==null||value.equals(labels.getProperty(s,prop))){ accum.add(s); } } return accum.iterator(); } public AnnotatorLearner toAnnotatorLearner(String s){ try{ OnlineBinaryClassifierLearner learner= (OnlineBinaryClassifierLearner)Expt.toLearner(s); BatchSequenceClassifierLearner seqLearner= new GenericCollinsLearner(learner,classWindowSize); return new SequenceAnnotatorLearner(seqLearner,fe,reduction); }catch(IllegalArgumentException ex){ /* that's ok, maybe it's something else */; } try{ BatchSequenceClassifierLearner seqLearner= (BatchSequenceClassifierLearner)SequenceAnnotatorExpt.toSeqLearner(s); return new SequenceAnnotatorLearner(seqLearner,fe,reduction); }catch(IllegalArgumentException ex){ /* that's ok, maybe it's something else */; } try{ bsh.Interpreter interp=new bsh.Interpreter(); interp.eval("import edu.cmu.minorthird.text.*;"); interp.eval("import edu.cmu.minorthird.text.learn.*;"); interp.eval("import edu.cmu.minorthird.classify.*;"); interp.eval("import edu.cmu.minorthird.classify.experiments.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.linear.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.trees.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.knn.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.svm.*;"); interp.eval("import edu.cmu.minorthird.classify.sequential.*;"); return (AnnotatorLearner)interp.eval(s); }catch(bsh.EvalError e){ throw new IllegalArgumentException("error parsing learnerName '"+s+ "':\n"+e); } } static public BatchSequenceClassifierLearner toSeqLearner(String learnerName){ try{ bsh.Interpreter interp=new bsh.Interpreter(); interp.eval("import edu.cmu.minorthird.classify.*;"); interp.eval("import edu.cmu.minorthird.classify.experiments.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.linear.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.trees.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.knn.*;"); interp.eval("import edu.cmu.minorthird.classify.algorithms.svm.*;"); interp.eval("import edu.cmu.minorthird.classify.sequential.*;"); interp.eval("import edu.cmu.minorthird.classify.transform.*;"); return (BatchSequenceClassifierLearner)interp.eval(learnerName); }catch(bsh.EvalError e){ throw new IllegalArgumentException("error parsing learnerName '"+ learnerName+"':\n"+e); } } public TextLabels getTestLabels(){ return fullTestLabels; } public static void main(String[] args){ Splitter<Span> splitter=new RandomSplitter<Span>(0.7); String outputLabel="_prediction"; String learnerName="new CollinsPerceptronLearner()"; TextLabels labels=null; String spanType=null,spanProp=null,saveFileName=null,show=null,annotationNeeded= null; List<String> featureMods=new ArrayList<String>(); Extraction2TaggingReduction reduction=null; try{ int pos=0; while(pos<args.length){ String opt=args[pos++]; if(opt.startsWith("-lab")){ labels=FancyLoader.loadTextLabels(args[pos++]); }else if(opt.startsWith("-lea")){ learnerName=args[pos++]; }else if(opt.startsWith("-split")){ splitter=Expt.toSplitter(args[pos++],Span.class); }else if(opt.startsWith("-in")){ spanType=args[pos++]; }else if(opt.startsWith("-spanT")){ spanType=args[pos++]; }else if(opt.startsWith("-spanP")){ spanProp=args[pos++]; }else if(opt.startsWith("-out")){ outputLabel=args[pos++]; }else if(opt.startsWith("-save")){ saveFileName=args[pos++]; }else if(opt.startsWith("-show")){ show=args[pos++]; }else if(opt.startsWith("-mix")){ annotationNeeded=args[pos++]; }else if(opt.startsWith("-fe")){ featureMods.add(args[pos++]); }else if(opt.startsWith("-reduction")){ try{ bsh.Interpreter interp=new bsh.Interpreter(); interp.eval("import edu.cmu.minorthird.text.learn.*;"); reduction=(Extraction2TaggingReduction)interp.eval(args[pos++]); }catch(bsh.EvalError e){ throw new IllegalArgumentException("error parsing reductionName '"+ args[pos-1]+"':\n"+e); } }else{ usage(); return; } } if(labels==null||learnerName==null||splitter==null|| (spanProp==null&&spanType==null)||outputLabel==null){ usage(); return; } if(spanProp!=null&&spanType!=null){ usage(); return; } TextLabelsExperiment expt= new TextLabelsExperiment(labels,splitter,learnerName,spanType, spanProp,outputLabel,reduction); if(annotationNeeded!=null){ expt.getFE().setRequiredAnnotation(annotationNeeded); expt.getFE().setAnnotationProvider(annotationNeeded+".mixup"); expt.getFE().setTokenPropertyFeatures("*"); // use all defined // properties labels.require(annotationNeeded,annotationNeeded+".mixup"); } for(Iterator<String> i=featureMods.iterator();i.hasNext();){ String mod=i.next(); if(mod.startsWith("window=")){ expt.getFE().setFeatureWindowSize( StringUtil.atoi(mod.substring("window=".length()))); System.out.println("fe windowSize => "+ expt.getFE().getFeatureWindowSize()); }else if(mod.startsWith("charType")){ expt.getFE().setUseCharType( mod.substring("charType".length(),1).equals("+")); System.out.println("fe windowSize => "+expt.getFE().getUseCharType()); }else if(mod.startsWith("charPattern")){ expt.getFE().setUseCompressedCharType( mod.substring("charPattern".length(),1).equals("+")); System.out.println("fe windowSize => "+ expt.getFE().getUseCompressedCharType()); }else{ usage(); return; } } expt.doExperiment(); if(saveFileName!=null){ new TextLabelsLoader().saveTypesAsOps(expt.getTestLabels(),new File( saveFileName)); } if(show!=null){ TextBaseViewer.view(expt.getTestLabels()); if(show.startsWith("all")){ new ViewerFrame("Experiment",expt.toGUI()); } } }catch(Exception e){ e.printStackTrace(); usage(); return; } } private static void usage(){ String[] usageLines= new String[]{ "usage: options are:", " -label labelsKey dataset to load", " -spanType type defines the extraction target", " -spanProp prop defines the extraction target (specify exactly one of -spanType or -spanProp)", " -learn learner Java code to construct the learner, which could be an ", " an AnnotatorLearner, a BatchSequenceClassifierLearner, or an OnlineClassifierLearner", " - a BatchSequenceClassifierLearner is used to defined a SequenceAnnotatorLearner", " and an OnlineClassifierLearner is used to define a GenericCollinsLearner", " optional, default \"new CollinsPerceptronLearner()\"", " -out outputLabel label assigned to predictions", " optional, default _prediction", " -split splitter splitter to use, in format used by minorthird.classify.experiments.Expt.toSplitter()", " optional, default r70", " -save fileName file to save extended TextLabels in (train data + predictions)", " optional", " -show xxxx how much detail on experiment to show - xxx=all shows the most", " optional", " -mix yyyy augment feature extracture to first execute 'require yyyy,yyyy.mixup'", " optional", " -fe zzzz change default feature extractor with one of these options zzzz:", " window=K charType+ charType- charPattern+ charPattern-",}; for(int i=0;i<usageLines.length;i++) System.out.println(usageLines[i]); } }