package edu.cmu.minorthird.text.learn;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.text.AbstractAnnotator;
import edu.cmu.minorthird.text.Details;
import edu.cmu.minorthird.text.MixupFinder;
import edu.cmu.minorthird.text.MonotonicTextLabels;
import edu.cmu.minorthird.text.Span;
import edu.cmu.minorthird.text.SpanFinder;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.text.mixup.Mixup;
import edu.cmu.minorthird.util.ProgressCounter;
/**
 * Annotator based on classifiers for start, end, and length.
 *
 * <p>Every token of every document is scored by a "start" classifier and by an
 * "end" classifier. A candidate span that starts at token {@code j} and has
 * length {@code len} is scored as
 * {@code startScore + lengthScore(len) + endScore(j+len-1)}. Here
 * {@code lengthScore} is a log-odds derived from a smoothed histogram of
 * training span lengths. Candidates whose combined score exceeds the threshold
 * are added to the labels under {@link #getSpanType()}. Each added span is also
 * printed to standard output together with its surrounding context.
 *
 * @author William Cohen
 */
public class StartEndLengthAnnotator extends AbstractAnnotator implements
    ExtractorAnnotator{

  /**
   * Single cut-off applied in three places:
   * <ul>
   * <li>to the start score alone;</li>
   * <li>to start score plus length score;</li>
   * <li>to the combined start, length and end score.</li>
   * </ul>
   */
  private double threshold=0.5;

  /** Finds single-token spans. */
  static private final SpanFinder tokenFinder;
  static{
    try{
      tokenFinder=new MixupFinder(new Mixup("...[any]..."));
    }catch(Mixup.ParseException e){
      // keep the parse failure as the cause so the real problem is visible
      throw new IllegalStateException("illegal tokenFinder",e);
    }
  }

  /** Scores candidate span lengths. */
  private final LengthScorer lengthScorer;
  private final SpanFeatureExtractor fe;
  private final BinaryClassifier start,end;
  private final String annotationType;

  /**
   * Create an annotator.
   *
   * @param start classifier scoring a single token as a span start
   * @param end classifier scoring a single token as a span end
   * @param fe feature extractor applied to single-token spans
   * @param lengthMap count of training spans observed for each length
   * @param totalPosSpans denominator used to turn those counts into
   *          probabilities
   * @param annotationType span type assigned to extracted spans
   */
  public StartEndLengthAnnotator(BinaryClassifier start,BinaryClassifier end,
      SpanFeatureExtractor fe,Map<Integer,Integer> lengthMap,int totalPosSpans,
      String annotationType){
    this.fe=fe;
    this.start=start;
    this.end=end;
    this.lengthScorer=new LengthScorer(lengthMap,totalPosSpans);
    this.annotationType=annotationType;
  }

  /** Sets the threshold used when filtering and accepting candidates. */
  public void setThreshold(double t){
    this.threshold=t;
  }

  /** Returns the threshold used when filtering and accepting candidates. */
  public double getThreshold(){
    return threshold;
  }

  @Override
  public String getSpanType(){
    return annotationType;
  }

  /** Return something that finds beginnings (for debugging). */
  public SpanFinder getStartFinder(){
    return new FilteredFinder(start,fe,tokenFinder);
  }

  /** Return something that finds ends (for debugging). */
  public SpanFinder getEndFinder(){
    return new FilteredFinder(end,fe,tokenFinder);
  }

  @Override
  protected void doAnnotate(MonotonicTextLabels labels){
    Iterator<Span> i=labels.getTextBase().documentSpanIterator();
    ProgressCounter pc=new ProgressCounter("annotate","document");
    int maxLen=lengthScorer.maxLength();
    while(i.hasNext()){
      Span document=i.next();
      int size=document.size();
      // score every token as a possible start and as a possible end
      double[] startPred=new double[size];
      double[] endPred=new double[size];
      for(int j=0;j<size;j++){
        Instance instance=fe.extractInstance(labels,document.subSpan(j,1));
        startPred[j]=start.score(instance);
        endPred[j]=end.score(instance);
      }
      // pair each plausible start with every end at most maxLen tokens away
      for(int j=0;j<size;j++){
        double startScore=startPred[j];
        if(startScore<threshold){
          continue;
        }
        for(int len=1;j+len<=size&&len<=maxLen;len++){
          double lengthScore=lengthScorer.score(len);
          if(lengthScore+startScore<threshold){
            continue;
          }
          double endScore=endPred[j+len-1];
          double finalScore=startScore+lengthScore+endScore;
          if(finalScore>threshold){
            // context strings are only needed for spans that are reported
            System.out.println("output ["+finalScore+"] "+
                showCandidate(document,j,len));
            // put a high-scoring combination in the labels
            Map<String,Double> m=new TreeMap<String,Double>();
            m.put("start",Double.valueOf(startScore));
            m.put("end",Double.valueOf(endScore));
            m.put("length",Double.valueOf(lengthScore));
            labels.addToType(document.subSpan(j,len),annotationType,
                new Details(finalScore,m));
          }
        }
      }
      pc.progress();
    }
    pc.finished();
  }

  /**
   * Builds a "left|candidate|right" string for a candidate span. The left and
   * right parts hold up to five tokens of context on each side.
   *
   * @param document the document containing the candidate
   * @param lo index of the candidate's first token
   * @param len number of tokens in the candidate
   */
  private static String showCandidate(Span document,int lo,int len){
    int left=Math.max(0,lo-5);
    String lContext=document.subSpan(left,lo-left).asString();
    String rContext=
        document.subSpan(lo+len,Math.min(5,document.size()-lo-len)).asString();
    String cContext=document.subSpan(lo,len).asString();
    return lContext+"|"+cContext+"|"+rContext;
  }

  @Override
  public String explainAnnotation(TextLabels labels,Span documentSpan){
    return "not implemented";
  }

  @Override
  public String toString(){
    return "[StartEndLen: "+start+";"+end+";"+lengthScorer+"]";
  }

  /**
   * Scores lengths using a smoothed histogram. The empirical length
   * distribution is mixed with a uniform distribution over
   * {@code 1..maxLength}, and the log-odds of the result is returned.
   */
  private static class LengthScorer{
    /** Observed count for each span length (kept by reference). */
    private final Map<Integer,Integer> lengthFreqMap;
    /** Denominator for the empirical probabilities. */
    private final int numLengths;
    /** Weight given to the uniform component. */
    private final double mixingFactor=0.1;
    /** Largest key in lengthFreqMap, or 0 if the map is empty. */
    private final int maxLength;

    public LengthScorer(Map<Integer,Integer> lengthFreqMap,int totalPosSpans){
      this.lengthFreqMap=lengthFreqMap;
      this.numLengths=totalPosSpans;
      int longest=0;
      for(Integer len:lengthFreqMap.keySet()){
        longest=Math.max(longest,len.intValue());
      }
      this.maxLength=longest;
    }

    /** Largest observed length; candidates longer than this are never scored. */
    public int maxLength(){
      return maxLength;
    }

    /**
     * Return the log-odds of Prob(len) under the smoothed histogram.
     *
     * <p>Not meaningful when {@code maxLength()==0}, because the uniform term
     * divides by zero; {@code doAnnotate} never calls it in that case. If 1 is
     * the only observed length and its count equals the total, the smoothed
     * probability is exactly 1.0 and the result is positive infinity.
     */
    public double score(int len){
      Integer freq=lengthFreqMap.get(len);
      // a non-positive total would otherwise turn the empirical term into
      // Infinity/NaN and poison every candidate score
      double empiricalProb=
          (freq==null||numLengths<=0)?0.0:freq.doubleValue()/numLengths;
      double smoothedProb=
          (mixingFactor/maxLength)+(1-mixingFactor)*empiricalProb;
      return Math.log(smoothedProb/(1.0-smoothedProb));
    }

    @Override
    public String toString(){
      return "[LengthScorer: "+maxLength+";"+lengthFreqMap+"]";
    }
  }
}