package edu.cmu.minorthird.text.learn; import java.io.Serializable; import java.util.Iterator; import javax.swing.JComponent; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.border.TitledBorder; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.ExampleSchema; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner; import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner; import edu.cmu.minorthird.classify.sequential.SequenceClassifier; import edu.cmu.minorthird.classify.sequential.SequenceDataset; import edu.cmu.minorthird.text.AbstractAnnotator; import edu.cmu.minorthird.text.Annotator; import edu.cmu.minorthird.text.MonotonicTextLabels; import edu.cmu.minorthird.text.Span; import edu.cmu.minorthird.text.TextLabels; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; import edu.cmu.minorthird.util.gui.Visible; /** * Learn an annotation model using a sequence dataset and a * BatchSequenceClassifierLearner. This class reduces extraction learning to * sequential classification of tokens. The scheme for mapping extraction * learning to token learning is determined by the Extraction2TaggingReduction. * * @author William Cohen */ public class SequenceAnnotatorLearner extends AbstractBatchAnnotatorLearner{ private static Logger log=Logger.getLogger(SequenceAnnotatorLearner.class); private static final boolean DEBUG=false; protected BatchSequenceClassifierLearner seqLearner; public SequenceAnnotatorLearner(){ super(); seqLearner=new CollinsPerceptronLearner(); } public SequenceAnnotatorLearner(BatchSequenceClassifierLearner seqLearner, SpanFeatureExtractor fe){ super(fe,new InsideOutsideReduction()); this.seqLearner=seqLearner; } public SequenceAnnotatorLearner(BatchSequenceClassifierLearner seqLearner, SpanFeatureExtractor fe,Extraction2TaggingReduction reduction){ super(fe,reduction); this.seqLearner=seqLearner; } // // getters and setters // private boolean displayDatasetBeforeLearning=false; /** * If set, try and pop up an interactive viewer of the sequential dataset * before learning. */ public boolean getDisplayDatasetBeforeLearning(){ return displayDatasetBeforeLearning; } public void setDisplayDatasetBeforeLearning( boolean newDisplayDatasetBeforeLearning){ this.displayDatasetBeforeLearning=newDisplayDatasetBeforeLearning; } public BatchSequenceClassifierLearner getSequenceClassifierLearner(){ return seqLearner; } public void setSequenceClassifierLearner( BatchSequenceClassifierLearner learner){ this.seqLearner=learner; } // Help Buttons public String getDisplayDatasetBeforeLearningHelp(){ return "Pop up an interactive viewer of the sequential dataset before learning."; } public String getSequenceClassifierLearnerHelp(){ return "The classifierLearner used to classify each token using the <br>predictions of previous tokens as features"; } /** * Return the learned annotator. */ @Override public Annotator getAnnotator(){ ExampleSchema schema=seqData.getSchema(); if(schema.getNumberOfClasses()<=1){ log.error("In the constructed dataset the number of classes is "+ schema.getNumberOfClasses()); log .error("Hint: this probably means that no spans of the specified type are present in your data"); } seqLearner.setSchema(schema); if(displayDatasetBeforeLearning) new ViewerFrame("Sequential Dataset",seqData.toGUI()); SequenceClassifier seqClassifier=seqLearner.batchTrain(seqData); if(DEBUG) log.debug("learned classifier: "+seqClassifier); return new SequenceAnnotatorLearner.SequenceAnnotator(seqClassifier,fe, reduction,annotationType); } /** * A useful subroutine - prepare sequence data the way a * SequenceAnnotatorLearner would prepare it when trained by a * TextLabelsAnnotatorTeacher. * */ static public SequenceDataset prepareSequenceData(TextLabels labels, String spanType,String spanProp,SpanFeatureExtractor fe, final int historySize,Extraction2TaggingReduction reduction){ BatchSequenceClassifierLearner dummy1=new BatchSequenceClassifierLearner(){ @Override public void setSchema(ExampleSchema schema){ } @Override public SequenceClassifier batchTrain(SequenceDataset dataset){ return null; } @Override public int getHistorySize(){ return historySize; } }; SequenceAnnotatorLearner dummy2= new SequenceAnnotatorLearner(dummy1,fe,reduction){ @Override public Annotator getAnnotator(){ return null; } }; new TextLabelsAnnotatorTeacher(labels,spanType,spanProp).train(dummy2); return dummy2.getSequenceDataset(); } // // learned annotator // public static class SequenceAnnotator extends AbstractAnnotator implements Serializable,Visible,ExtractorAnnotator{ private static final long serialVersionUID=2; private SequenceClassifier seqClassifier; private SpanFeatureExtractor fe; private Extraction2TaggingReduction reduction; private String annotationType; public SequenceAnnotator(SequenceClassifier seqClassifier, SpanFeatureExtractor fe,String annotationType){ this(seqClassifier,fe,new InsideOutsideReduction(),annotationType); } public SequenceAnnotator(SequenceClassifier seqClassifier, SpanFeatureExtractor fe,Extraction2TaggingReduction reduction, String annotationType){ this.seqClassifier=seqClassifier; this.fe=fe; this.reduction=reduction; this.annotationType=annotationType; } @Override public String getSpanType(){ return annotationType; } public SpanFeatureExtractor getSpanFeatureExtractor(){ return fe; } public Extraction2TaggingReduction getReduction(){ return reduction; } public SequenceClassifier getSequenceClassifier(){ return seqClassifier; } @Override protected void doAnnotate(MonotonicTextLabels labels){ Iterator<Span> i=labels.getTextBase().documentSpanIterator(); ProgressCounter pc= new ProgressCounter("tagging with classifier","document"); while(i.hasNext()){ Span s=i.next(); Instance[] sequence=new Instance[s.size()]; for(int j=0;j<s.size();j++){ Span tokenSpan=s.subSpan(j,1); sequence[j]=fe.extractInstance(labels,tokenSpan); } ClassLabel[] classLabels=seqClassifier.classification(sequence); for(int j=0;j<classLabels.length;j++){ labels.setProperty(s.getToken(j),reduction.getTokenProp(), classLabels[j].bestClassName()); } pc.progress(); } pc.finished(); reduction.extractFromTags(annotationType,labels); } @Override public String explainAnnotation(TextLabels labels,Span documentSpan){ return "not implemented"; } @Override public String toString(){ return "[SequenceAnnotator "+annotationType+":\n"+seqClassifier+"]"; } @Override public Viewer toGUI(){ Viewer v=new ComponentViewer(){ static final long serialVersionUID=20080306L; @Override public JComponent componentFor(Object o){ SequenceAnnotator sa=(SequenceAnnotator)o; JPanel mainPanel=new JPanel(); mainPanel.setBorder(new TitledBorder("Sequence Annotator")); Viewer subView=new SmartVanillaViewer(sa.seqClassifier); subView.setSuperView(this); mainPanel.add(subView); return new JScrollPane(mainPanel); } }; v.setContent(this); return v; } } static public void main(String[] args){ try{ SequenceAnnotator a= (SequenceAnnotator)edu.cmu.minorthird.util.IOUtil .loadSerialized(new java.io.File(args[0])); a.annotationType=args[1]; edu.cmu.minorthird.util.IOUtil .saveSerialized(a,new java.io.File(args[2])); }catch(Exception ex){ ex.printStackTrace(); System.out.println("usage: inputFile new-annotation-type outputfile"); } } }