/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.text.learn;
import java.awt.BorderLayout;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;
import com.wcohen.ss.BasicStringWrapper;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.lookup.SoftDictionary;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.text.AbstractAnnotator;
import edu.cmu.minorthird.text.Annotator;
import edu.cmu.minorthird.text.EmptyLabels;
import edu.cmu.minorthird.text.MonotonicTextLabels;
import edu.cmu.minorthird.text.Span;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
/**
* Learn to annotate based on a conditional semi-markov model, learned from
* examples.
*
* @author William Cohen
*/
/*
* status/limitations: this only learns one label type, with a single binary
* classifier.
*/
public class ConditionalSemiMarkovModel{
private static Logger log=Logger.getLogger(ConditionalSemiMarkovModel.class);
private static final boolean DEBUG=log.isDebugEnabled();
/**
 * A learner for ConditionalSemiMarkovModels.
 *
 * Answered queries are accumulated in a list; the actual training happens in
 * {@link #getAnnotator()}, which repeatedly computes the best segmentation of
 * each training document under the current classifier and feeds the
 * classifier learner the segments on which that segmentation and the labeled
 * segmentation disagree.
 */
static public class CSMMLearner extends AnnotatorLearner{
  private SpanFeatureExtractor fe;
  private OnlineBinaryClassifierLearner classifierLearner;
  private int epochs;
  private int maxSegmentSize=5;
  // temporary storage
  private Iterator<Span> documentLooper;
  private List<AnnotationExample> exampleList;
  // type of annotation to produce
  private String annotationType;

  //
  // constructors
  //

  /** Uses a CSMMSpanFE, a VotedPerceptron, 5 epochs and segments of at most 5 tokens. */
  public CSMMLearner(){
    this(new CSMMSpanFE(),new VotedPerceptron(),5,5,"");
  }

  /** As the default constructor, but with the given number of epochs. */
  public CSMMLearner(int epochs){
    this(new CSMMSpanFE(),new VotedPerceptron(),epochs,5,"");
  }

  /** As the default constructor, but with the given epochs and maximum segment size. */
  public CSMMLearner(int epochs,int maxSegmentSize){
    this(new CSMMSpanFE(),new VotedPerceptron(),epochs,maxSegmentSize,"");
  }

  /** As the default constructor, but requiring the named annotation on the feature extractor. */
  public CSMMLearner(String annotation){
    this(new CSMMSpanFE(),new VotedPerceptron(),5,5,annotation);
  }

  /** Dictionary-based learner with 5 epochs. */
  public CSMMLearner(String dictionaryFile,String distanceNames,
      int maxSegmentSize){
    this(dictionaryFile,distanceNames,5,maxSegmentSize);
  }

  /** Dictionary-based learner with no required annotation. */
  public CSMMLearner(String dictionaryFile,String distanceNames,int epoch,
      int maxSegmentSize){
    this(dictionaryFile,distanceNames,epoch,maxSegmentSize,"");
  }

  /** Dictionary-based learner that does not add training segments to the dictionary. */
  public CSMMLearner(String dictionaryFile,String distanceNames,int epoch,
      int maxSegmentSize,String mixFile){
    this(dictionaryFile,distanceNames,epoch,maxSegmentSize,false,mixFile);
  }

  /**
   * Dictionary-based learner: builds a CSMMWithDictionarySpanFE from the
   * dictionary arguments and trains a VotedPerceptron on its features.
   */
  public CSMMLearner(String dictionaryFile,String distanceNames,
      int epochSize,int maxSegmentSize,boolean addTraining,
      boolean doCrossVal,String mixFile){
    this(new CSMMWithDictionarySpanFE(dictionaryFile,distanceNames,
        addTraining,doCrossVal),new VotedPerceptron(),epochSize,
        maxSegmentSize,mixFile);
  }

  /** Dictionary-based learner with the cross-validation flag switched on. */
  public CSMMLearner(String dictionaryFile,String distanceNames,int epoch,
      int maxSegmentSize,boolean addTraining,String mixFile){
    this(dictionaryFile,distanceNames,epoch,maxSegmentSize,addTraining,true,
        mixFile);
  }

  /**
   * Fully general constructor.
   *
   * @param fe feature extractor applied to candidate segments
   * @param classifierLearner online learner that is trained on mistakes
   * @param epochs number of passes over the training examples
   * @param maxSegSz longest segment (in tokens) considered for extraction
   * @param annotation if non-empty, the annotation that the extractor must
   *          require (loaded from <code>annotation+".mixup"</code>); a null
   *          value is treated like the empty string
   * @throws IllegalArgumentException if an annotation is requested but the
   *           extractor is not a CSMMSpanFE
   */
  public CSMMLearner(SpanFeatureExtractor fe,
      OnlineBinaryClassifierLearner classifierLearner,int epochs,
      int maxSegSz,String annotation){
    this.fe=fe;
    if(annotation!=null&&annotation.length()>0){
      // the annotation settings below only exist on CSMMSpanFE, so fail with
      // a clear message instead of an anonymous ClassCastException
      if(!(fe instanceof CSMMSpanFE)){
        throw new IllegalArgumentException("annotation '"+annotation+
            "' requires a CSMMSpanFE feature extractor, but got "+
            (fe==null?"null":fe.getClass().getName()));
      }
      CSMMSpanFE spanFE=(CSMMSpanFE)fe;
      System.out.println("Reading annotations");
      spanFE.setRequiredAnnotation(annotation,annotation+".mixup");
      spanFE.setTokenPropertyFeatures("*"); // use all defined properties
    }
    this.classifierLearner=classifierLearner;
    this.epochs=epochs;
    this.maxSegmentSize=maxSegSz;
    reset();
  }

  //
  // getters/setters for gui
  //
  public OnlineBinaryClassifierLearner getLearner(){
    return classifierLearner;
  }

  public void setLearner(OnlineBinaryClassifierLearner newLearner){
    this.classifierLearner=newLearner;
  }

  public int getEpochs(){
    return epochs;
  }

  public void setEpochs(int newEpochs){
    this.epochs=newEpochs;
  }

  public int getMaxSegmentSize(){
    return maxSegmentSize;
  }

  public void setMaxSegmentSize(int newMaxSize){
    this.maxSegmentSize=newMaxSize;
  }

  @Override
  public SpanFeatureExtractor getSpanFeatureExtractor(){
    return fe;
  }

  @Override
  public void setSpanFeatureExtractor(SpanFeatureExtractor fe){
    this.fe=fe;
  }

  //
  // AnnotatorLearner implementation: query all documents, and accumulate
  // examples in exampleList
  //

  /** Discards all examples accumulated so far. */
  @Override
  public void reset(){
    exampleList=new ArrayList<AnnotationExample>();
  }

  @Override
  public void setDocumentPool(Iterator<Span> documentLooper){
    this.documentLooper=documentLooper;
  }

  /** @return true if the document pool has more documents; false if no pool was set */
  @Override
  public boolean hasNextQuery(){
    return documentLooper!=null&&documentLooper.hasNext();
  }

  @Override
  public Span nextQuery(){
    return documentLooper.next();
  }

  @Override
  public void setAnswer(AnnotationExample answeredQuery){
    exampleList.add(answeredQuery);
  }

  @Override
  public void setAnnotationType(String s){
    this.annotationType=s;
  }

  @Override
  public String getAnnotationType(){
    return annotationType;
  }

  /**
   * Learning takes place here: the classifier learner is reset and then
   * trained for the configured number of epochs over the stored examples.
   *
   * @return an annotator wrapping the feature extractor and the classifier
   *         obtained after the last epoch
   */
  @Override
  public Annotator getAnnotator(){
    classifierLearner.reset();
    log.debug("processing "+exampleList.size()+" examples for "+epochs+
        " epochs");
    ProgressCounter pc=
        new ProgressCounter("training CSMM","document",epochs*
            exampleList.size());
    // let a dictionary-based extractor see the training segments first
    if(fe instanceof CSMMWithDictionarySpanFE){
      ((CSMMWithDictionarySpanFE)fe).train(exampleList.iterator());
    }
    for(int i=0;i<epochs;i++){
      for(AnnotationExample example:exampleList){
        updateFromExample(example);
        pc.progress();
      }
      if(DEBUG){
        new ViewerFrame("classifier after epoch "+i,new SmartVanillaViewer(
            classifierLearner.getBinaryClassifier()));
      }
    }
    // the counter spans all epochs, so it is completed once, after the last one
    pc.finished();
    return new CSMMAnnotator(fe,classifierLearner.getBinaryClassifier(),
        annotationType,maxSegmentSize);
  }

  // one training step: compare the best segmentation under the current
  // classifier with the labeled one and train on the differences
  private void updateFromExample(AnnotationExample example){
    Span doc=example.getDocumentSpan();
    if(DEBUG){
      log.debug("updating from "+doc);
    }
    // get best segmentation, given current classifier
    Segments viterbi=
        bestSegments(doc,example.getLabels(),fe,classifierLearner
            .getBinaryClassifier(),maxSegmentSize);
    if(DEBUG){
      log.debug("viterbi solution:\n"+viterbi);
    }
    Segments correct=correctSegments(example);
    if(DEBUG){
      log.debug("correct spans:\n"+correct);
    }
    // predicted but not labeled: negative examples
    addDisagreements(example,viterbi,correct,-1,"false pos: ");
    // labeled but not predicted: positive examples
    addDisagreements(example,correct,viterbi,+1,"false neg: ");
  }

  // add an example labeled numberLabel for every span of 'proposed' that is
  // missing from 'reference'; the preceding span of 'proposed' is its context
  private void addDisagreements(AnnotationExample example,Segments proposed,
      Segments reference,double numberLabel,String debugTag){
    Span previousSpan=null;
    for(Iterator<Span> k=proposed.iterator();k.hasNext();){
      Span span=k.next();
      if(!reference.contains(span)){
        if(DEBUG){
          log.debug(debugTag+span);
        }
        classifierLearner.addExample(exampleFor(example,span,previousSpan,
            numberLabel));
      }
      previousSpan=span;
    }
  }

  // build an example from a span and its context; the "previous label" is
  // positive only when prevSpan ends exactly where span begins
  private Example exampleFor(AnnotationExample example,Span span,
      Span prevSpan,double numberLabel){
    Instance instance=fe.extractInstance(example.getLabels(),span);
    boolean adjacentToPrevious=
        prevSpan!=null&&
            prevSpan.getRightBoundary().equals(span.getLeftBoundary());
    String prevLabel=
        adjacentToPrevious?ExampleSchema.POS_CLASS_NAME
            :ExampleSchema.NEG_CLASS_NAME;
    Instance instanceFromSeq=
        new InstanceFromSequence(instance,new String[]{prevLabel});
    if(DEBUG){
      log.debug("example for "+span+": "+instanceFromSeq);
    }
    return new Example(instanceFromSeq,ClassLabel.binaryLabel(numberLabel));
  }

  // the correct segments, as defined by the example: every span of the
  // example's input type in the example's document
  private Segments correctSegments(AnnotationExample example){
    Set<Span> set=new TreeSet<Span>();
    String id=example.getDocumentSpan().getDocumentId();
    String type=example.getInputType();
    for(Iterator<Span> i=example.getLabels().instanceIterator(type,id);i
        .hasNext();){
      set.add(i.next());
    }
    return new Segments(set);
  }
} // class CSMMLearner
/**
 * Annotates documents using a learned model: every segment chosen by
 * {@link #bestSegments} for a document is added to the labels under the
 * configured annotation type.
 */
static public class CSMMAnnotator extends AbstractAnnotator implements
    Visible,ExtractorAnnotator,Serializable{
  static private final long serialVersionUID=20080306L;
  private SpanFeatureExtractor fe;
  private BinaryClassifier classifier;
  private String annotationType;
  private int maxSegSize;

  /**
   * @return a viewer showing the maximum segment size and the underlying
   *         classifier of the annotator it displays
   */
  @Override
  public Viewer toGUI(){
    Viewer v=new ComponentViewer(){
      static final long serialVersionUID=20080306L;
      @Override
      public JComponent componentFor(Object o){
        CSMMAnnotator ann=(CSMMAnnotator)o;
        JPanel mainPanel=new JPanel();
        mainPanel.setLayout(new BorderLayout());
        // read the size from the viewed object, like the classifier below,
        // so both parts of the panel describe the same annotator
        mainPanel.add(new JLabel("CSMM: segsize "+ann.maxSegSize),
            BorderLayout.NORTH);
        Viewer subView=new SmartVanillaViewer(ann.classifier);
        subView.setSuperView(this);
        mainPanel.add(subView,BorderLayout.SOUTH);
        mainPanel
            .setBorder(new TitledBorder("Conditional Semi-Markov-Model"));
        return new JScrollPane(mainPanel);
      }
    };
    v.setContent(this);
    return v;
  }

  /**
   * @param fe feature extractor applied to candidate segments
   * @param classifier classifier used to score candidate segments
   * @param annotationType type under which extracted segments are recorded
   * @param maxSegSize longest segment (in tokens) considered for extraction
   */
  public CSMMAnnotator(SpanFeatureExtractor fe,BinaryClassifier classifier,
      String annotationType,int maxSegSize){
    this.fe=fe;
    this.classifier=classifier;
    this.annotationType=annotationType;
    this.maxSegSize=maxSegSize;
  }

  @Override
  public String getSpanType(){
    return annotationType;
  }

  /** Segments every document of the labels' text base and records the result. */
  @Override
  public void doAnnotate(MonotonicTextLabels labels){
    ProgressCounter pc=
        new ProgressCounter("annotating","document",labels.getTextBase()
            .size());
    Iterator<Span> docs=labels.getTextBase().documentSpanIterator();
    while(docs.hasNext()){
      Span doc=docs.next();
      Segments viterbi=bestSegments(doc,labels,fe,classifier,maxSegSize);
      for(Iterator<Span> j=viterbi.iterator();j.hasNext();){
        labels.addToType(j.next(),annotationType);
      }
      pc.progress();
    }
    pc.finished();
  }

  @Override
  public String explainAnnotation(TextLabels labels,Span documentSpan){
    return "not implemented";
  }
}
/**
 * Viterbi search for the highest-scoring segmentation of a document.
 *
 * For t=0..size and y in {0,1}, fty[t][y] holds the best score of any
 * segmentation of tokens 0..t whose last segment has class y. Class-0
 * segments are always one token long; class-1 segments may span up to
 * maxSegSize tokens. Only the class-1 segments on the best path are returned.
 *
 * @param documentSpan the document to segment
 * @param labels labels handed to the feature extractor
 * @param fe feature extractor applied to candidate segments
 * @param classifier scores candidate class-1 segments
 * @param maxSegSize longest class-1 segment, in tokens
 * @return the class-1 segments of the best segmentation
 */
static public Segments bestSegments(Span documentSpan,TextLabels labels,
    SpanFeatureExtractor fe,BinaryClassifier classifier,int maxSegSize){
  final int numTokens=documentSpan.size();
  final double notYetReached=-99999; // could be -Double.MAX_VALUE;

  // initialize: every cell starts at the floor score with no back pointer
  // (the trace array is null-filled on allocation)
  double[][] fty=new double[numTokens+1][2];
  BackPointer[][] trace=new BackPointer[numTokens+1][2];
  for(double[] row:fty){
    row[0]=notYetReached;
    row[1]=notYetReached;
  }
  fty[0][0]=0;
  fty[0][1]=0;

  // fill the fty matrix
  for(int end=0;end<=numTokens;end++){
    for(int cls=0;cls<2;cls++){
      final int longest=(cls==0)?1:maxSegSize;
      final int earliestStart=Math.max(0,end-longest);
      for(int prevCls=0;prevCls<2;prevCls++){
        for(int start=earliestStart;start<end;start++){
          Span candidate=documentSpan.subSpan(start,end-start);
          double candidateScore=
              score(labels,prevCls,cls,start,end,candidate,fe,classifier);
          if(candidateScore+fty[start][prevCls]>fty[end][cls]){
            fty[end][cls]=candidateScore+fty[start][prevCls];
            trace[end][cls]=new BackPointer(candidate,start,prevCls);
          }
        }
      }
    }
  }

  // follow the back pointers from the better of the two final cells
  int currentCls=(fty[numTokens][1]>fty[numTokens][0])?1:0;
  Set<Span> extracted=new TreeSet<Span>();
  BackPointer pointer=trace[numTokens][currentCls];
  while(pointer!=null){
    pointer.onBestPath=true;
    if(currentCls==1){
      extracted.add(pointer.span);
    }
    currentCls=pointer.lastY;
    pointer=trace[pointer.lastT][pointer.lastY];
  }
  if(DEBUG){
    dumpStuff(fty,trace);
  }
  return new Segments(extracted);
}
// debugging aid: prints one line per (t,y) cell of the viterbi tables with
// its score, its back pointer and a marker for cells on the best path
private static void dumpStuff(double[][] fty,BackPointer[][] trace){
  java.text.DecimalFormat format=new java.text.DecimalFormat("####.###");
  System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
  for(int t=0;t<fty.length;t++){
    for(int y=0;y<2;y++){
      BackPointer cell=trace[t][y];
      String spanText;
      int prevT;
      int prevY;
      String marker;
      if(cell==null){
        // cells without a back pointer are shown with placeholder values
        spanText="*NULL*";
        prevT=-1;
        prevY=-1;
        marker="";
      }else{
        spanText=cell.span.asString();
        prevT=cell.lastT;
        prevY=cell.lastY;
        marker=cell.onBestPath?"<==":"";
      }
      System.out.println(t+"."+y+"\t"+format.format(fty[t][y])+"\t"+prevT+
          "."+prevY+" '"+spanText+"' "+marker);
    }
  }
}
/**
 * Score of one candidate segment, used by viterbi. Class-0 segments always
 * score 0; class-1 segments are scored by the classifier on the segment's
 * features combined with the class of the preceding segment.
 *
 * @param labels labels handed to the feature extractor
 * @param lastY class of the preceding segment (1 maps to the positive label)
 * @param y class proposed for this segment
 * @param lastT start index of the segment (not used in the computation)
 * @param t end index of the segment (not used in the computation)
 * @param segment the candidate segment
 * @param fe feature extractor applied to the segment
 * @param cls classifier that scores the resulting instance
 * @return 0 when y is 0, otherwise the classifier's score
 */
static private double score(TextLabels labels,int lastY,int y,int lastT,
    int t,Span segment,SpanFeatureExtractor fe,BinaryClassifier cls){
  if(y==0){
    return 0;
  }
  String prevLabel=
      lastY==1?ExampleSchema.POS_CLASS_NAME:ExampleSchema.NEG_CLASS_NAME;
  Instance segmentInstance=
      new InstanceFromSequence(fe.extractInstance(labels,segment),
          new String[]{prevLabel});
  // evaluate the classifier once and reuse the value for logging
  double segmentScore=cls.score(segmentInstance);
  if(DEBUG){
    log.debug("score: "+segmentScore+"\t"+segment);
  }
  return segmentScore;
}
// one viterbi table entry: the segment that produced the cell's score, the
// (lastT,lastY) cell it extends, and whether it lies on the best path
private static class BackPointer{
  public Span span;
  public int lastT;
  public int lastY;
  public boolean onBestPath;

  public BackPointer(Span span,int lastT,int lastY){
    this.onBestPath=false; // set later while tracing back the best path
    this.lastY=lastY;
    this.lastT=lastT;
    this.span=span;
  }
}
/**
 * Converts a span to an instance. On top of the features of
 * SampleFE.ExtractionFE this emits features describing the whole span (its
 * lower-cased text, its size, optional character-type patterns), its first
 * and last tokens, and the configured token properties.
 */
static public class CSMMSpanFE extends SampleFE.ExtractionFE{
  static final long serialVersionUID=20080306L;

  public CSMMSpanFE(){
    super();
  }

  public CSMMSpanFE(int windowSize){
    super(windowSize);
  }

  /**
   * Requires the annotation named by the argument, loaded from that name
   * plus ".mixup", and uses all token properties ("*") as features.
   */
  public CSMMSpanFE(String mixupFile){
    super();
    String mixupFileName=mixupFile+".mixup";
    setRequiredAnnotation(mixupFile,mixupFileName);
    setTokenPropertyFeatures("*");
  }

  /** Extracts features from the span against an empty set of labels. */
  @Override
  public void extractFeatures(Span span){
    TextLabels noLabels=new EmptyLabels();
    extractFeatures(noLabels,span);
  }

  @Override
  public void extractFeatures(TextLabels labels,Span span){
    super.extractFeatures(labels,span);

    // whole span: lower-cased text, then optional character-type views
    from(span).eq().lc().emit();
    if(useCharType){
      from(span).eq().charTypes().emit();
    }
    if(useCompressedCharType){
      from(span).eq().charTypePattern().emit();
    }

    // whole span: length properties
    from(span).size().emit();
    from(span).exactSize().emit();

    // boundary tokens: token(0) and token(-1)
    from(span).token(0).eq().lc().emit();
    from(span).token(-1).eq().lc().emit();
    if(useCharType){
      from(span).token(0).eq().charTypes().lc().emit();
      from(span).token(-1).eq().charTypes().lc().emit();
    }
    if(useCompressedCharType){
      from(span).token(0).eq().charTypePattern().lc().emit();
      from(span).token(-1).eq().charTypePattern().lc().emit();
    }

    // marked token properties: boundary tokens first, then the tokens of the
    // sub-span that starts at index 1 and has length size-2
    for(String property:tokenPropertyFeatures){
      from(span).token(0).prop(property).emit();
      from(span).token(-1).prop(property).emit();
      int innerLength=span.size()-2;
      from(span).subSpan(1,innerLength).tokens().prop(property).emit();
    }
  }
}
/**
 * Feature extractor for providing distance-based features on terms.
 * Dictionary can be specified either as an external file or by using the
 * training spans. - Sunita Sarawagi
 *
 * For every configured distance function one numeric feature is added: the
 * distance between the span's text and its closest dictionary match.
 */
static public class CSMMWithDictionarySpanFE extends CSMMSpanFE{
  static final long serialVersionUID=20080306L;
  boolean addTrainingSegsToDictionary;
  boolean useCrossVal;
  SoftDictionary dictionary;
  StringDistance distances[];
  Feature features[];

  /**
   * Uses only the dictionary file; training segments are not added.
   *
   * @param dictionaryFile file to load the dictionary from; may be empty
   * @param distanceNames "/" separated list of distance functions that we
   *          want to apply
   */
  public CSMMWithDictionarySpanFE(String dictionaryFile,String distanceNames){
    this(dictionaryFile,distanceNames,false,false);
  }

  /**
   * @param dictionaryFile file to load the dictionary from; if empty,
   *          nothing is loaded
   * @param distanceNames "/" separated list of distance functions that we
   *          want to apply
   * @param addTraining whether {@link #train} adds the training segments to
   *          the dictionary
   * @param useCrossValArg when set together with addTraining, lookups pass
   *          the span's document id to the dictionary instead of null
   */
  public CSMMWithDictionarySpanFE(String dictionaryFile,String distanceNames,
      boolean addTraining,boolean useCrossValArg){
    super();
    addTrainingSegsToDictionary=addTraining;
    useCrossVal=useCrossValArg;
    try{
      dictionary=new SoftDictionary();
      distances=DistanceLearnerFactory.buildArray(distanceNames);
      if(dictionaryFile.length()>0){
        dictionary.load(new File(dictionaryFile));
        trainDistances();
      }
    }catch(IOException e){
      // keep going without the file's entries, but say which file failed
      System.err.println("CSMMWithDictionarySpanFE: could not load dictionary '"+
          dictionaryFile+"'");
      e.printStackTrace();
    }
    // Build the feature table even after a failed load; otherwise
    // extractFeatures would hit a null array as soon as the dictionary
    // returns a match.
    if(distances==null){
      distances=new StringDistance[0];
    }
    // now create features corresponding to each distance function.
    features=new Feature[distances.length];
    for(int d=0;d<distances.length;d++){
      // save the feature name
      features[d]=new Feature(distances[d].toString());
    }
  }

  /** Trains every distance that is also a distance learner, using the dictionary's teacher. */
  public void trainDistances(){
    for(int d=0;d<distances.length;d++){
      // train anything that's also a distance learner
      if(distances[d] instanceof StringDistanceLearner){
        distances[d]=
            dictionary.getTeacher()
                .train((StringDistanceLearner)distances[d]);
      }
    }
  }

  /**
   * Adds the labeled segments of the examples to the dictionary, keyed by
   * document id, and retrains the distances. Does nothing unless training
   * segments were requested at construction time.
   *
   * @param iter the training examples
   */
  public void train(Iterator<AnnotationExample> iter){
    if(!addTrainingSegsToDictionary){
      return;
    }
    while(iter.hasNext()){
      AnnotationExample example=iter.next();
      String id=example.getDocumentSpan().getDocumentId();
      String type=example.getInputType();
      for(Iterator<Span> i=example.getLabels().instanceIterator(type,id);i
          .hasNext();){
        String thisSeg=i.next().asString();
        dictionary.put(id,thisSeg,null);
      }
    }
    trainDistances();
  }

  @Override
  public void extractFeatures(TextLabels labels,Span span){
    super.extractFeatures(labels,span);
    StringWrapper spanString=new BasicStringWrapper(span.asString());
    // the document id is only passed when training segments were added and
    // cross-validation was requested; otherwise the lookup uses null
    String id=
        ((addTrainingSegsToDictionary&&useCrossVal)?span.getDocumentId():null);
    Object closestMatch=dictionary.lookup(id,spanString);
    if(closestMatch!=null){
      // create various types of similarity measures.
      for(int d=0;d<distances.length;d++){
        double score=
            distances[d].score(spanString,(StringWrapper)closestMatch);
        if(score!=0){
          // instance has been created by the parent.
          instance.addNumeric(features[d],score);
        }
      }
    }
  }
}
/**
 * A proposed segmentation of a document: a thin wrapper around the set of
 * spans it was built from (the set is used directly, not copied).
 */
static public class Segments{
  private Set<Span> spanSet;

  public Segments(Set<Span> spanSet){
    this.spanSet=spanSet;
  }

  /** @return an iterator over the wrapped set, in that set's order */
  public Iterator<Span> iterator(){
    Iterator<Span> members=spanSet.iterator();
    return members;
  }

  /** @return whether the wrapped set contains the span */
  public boolean contains(Span span){
    boolean present=spanSet.contains(span);
    return present;
  }

  @Override
  public String toString(){
    StringBuilder text=new StringBuilder("[Segments: ");
    text.append(spanSet.toString());
    text.append("]");
    return text.toString();
  }
}
};