/* Copyright 2003-2004, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.sequential;
import java.awt.BorderLayout;
import java.io.Serializable;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.sequential.Segmentation.Segment;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
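/**
* Batch learner for semi-Markov segmenters. It trains a multi-class
* voted-perceptron classifier (CollinsPerceptronLearner.MultiClassVPClassifier)
* by repeatedly comparing the Viterbi-best segmentation of each training
* sequence with the correct segmentation and updating the classifier on the
* segments where the two disagree.
*
* @author William Cohen
*/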
public class SegmentCollinsPerceptronLearner implements BatchSegmenterLearner,SequenceConstants
{
/** Logger shared by the learner and its nested helper classes; never reassigned, hence final. */
private static final Logger log = Logger.getLogger(SegmentCollinsPerceptronLearner.class);
/**
 * Whether debug logging was enabled when this class was initialized. The
 * value is computed once, so later changes to the log level are not seen.
 * Must be declared after {@link #log}, which it reads.
 */
private static final boolean DEBUG = log.isDebugEnabled();
/** Maximum number of passes that batchTrain makes over the training data. */
private int numberOfEpochs;
/** When true, batchTrain calls setVoteMode(true) on the classifier before the first epoch. */
private boolean updatedViterbi = false;
/**
 * Creates a learner that makes at most the given number of passes over the data.
 *
 * @param epochs maximum number of training epochs
 */
public SegmentCollinsPerceptronLearner(int epochs)
{
	this.numberOfEpochs = epochs;
}

/**
 * Creates a learner with an explicit epoch limit and vote-mode choice.
 *
 * @param epochs maximum number of training epochs
 * @param updatedViterbi if true, the classifier is put in vote mode before training starts
 */
public SegmentCollinsPerceptronLearner(int epochs, boolean updatedViterbi)
{
	this(epochs);
	this.updatedViterbi = updatedViterbi;
}

/** Creates a learner with the default limit of 5 epochs. */
public SegmentCollinsPerceptronLearner()
{
	this(5);
}
/**
 * Does nothing: this learner ignores the schema supplied here and instead
 * reads the schema from the dataset passed to batchTrain.
 *
 * @param schema ignored
 */
@Override
public void setSchema(ExampleSchema schema)
{
	// intentionally empty
}
//
// accessors
//
/** @return the maximum number of training epochs */
public int getNumberOfEpochs()
{
	return this.numberOfEpochs;
}

/** @param newNumberOfEpochs the maximum number of training epochs to use from now on */
public void setNumberOfEpochs(int newNumberOfEpochs)
{
	this.numberOfEpochs = newNumberOfEpochs;
}

/** @return the history size, which is always 1 (only the previous segment's class is used) */
public int getHistorySize()
{
	return 1;
}
//
// training scheme
//
/**
 * Trains a segmenter on the given dataset.
 *
 * For at most numberOfEpochs passes, each candidate segment group is
 * segmented with the current classifier, the result is compared with the
 * correct segmentation, and the classifier is updated with delta -1.0 for
 * predicted segments that are not correct and +1.0 for correct segments
 * that were not predicted. Training stops early after an epoch with no
 * such errors. Per-epoch error counts are printed to standard output.
 *
 * @param dataset the training data; also supplies the schema and the maximum window size
 * @return a ViterbiSegmenter wrapping the trained classifier, with vote mode set to true
 */
@Override
public Segmenter batchTrain(SegmentDataset dataset)
{
	int maxSegmentSize = dataset.getMaxWindowSize();
	ExampleSchema schema = dataset.getSchema();
	if (DEBUG) log.debug("schema: "+schema);
	CollinsPerceptronLearner.MultiClassVPClassifier c = new CollinsPerceptronLearner.MultiClassVPClassifier(schema);
	// The searcher only stores references to c, schema and maxSegmentSize, so a
	// single instance sees every later update to c; no need to rebuild it per sequence.
	ViterbiSearcher searcher = new ViterbiSearcher(c,schema,maxSegmentSize);
	ProgressCounter pc =
		new ProgressCounter("training semi-markov voted-perceptron",
			"sequence",numberOfEpochs*dataset.getNumberOfSegmentGroups());

	if (updatedViterbi)
		c.setVoteMode(true);

	for (int epoch=0; epoch<numberOfEpochs; epoch++)
	{
		// Note from the original author: shuffling the dataset here
		// (dataset.shuffle()) seemed to lower performance by a lot - reason unknown.

		// statistics for curious researchers
		int sequenceErrors = 0;
		int transitionErrors = 0;
		int transitions = 0;

		for (Iterator<CandidateSegmentGroup> i=dataset.candidateSegmentGroupIterator(); i.hasNext(); )
		{
			CandidateSegmentGroup g = i.next();
			if (DEBUG) log.debug("classifier is: "+c);
			Segmentation viterbi = searcher.bestSegments(g);
			if (DEBUG) log.debug("viterbi:\n"+viterbi);
			Segmentation correct = correctSegments(g,schema,maxSegmentSize);
			if (DEBUG) log.debug("correct segments:\n"+correct);

			// false positives: predicted segments that are wrong are pushed down
			int fp = compareSegmentsAndRevise(c, schema, viterbi, correct, -1.0, g);
			// false negatives: correct segments that were missed are pushed up
			int fn = compareSegmentsAndRevise(c, schema, correct, viterbi, +1.0, g);
			if (fp>0 || fn>0) sequenceErrors++;
			transitionErrors += fp + fn;
			transitions += correct.size();

			c.completeUpdate();
			pc.progress();
		} // sequence i

		System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors
			+" transitionErrors="+transitionErrors+"/"+transitions);

		// a full pass without any update means further epochs would change nothing
		if (transitionErrors==0) break;
	} // epoch
	pc.finished();

	c.setVoteMode(true);

	// construct the classifier
	return new ViterbiSegmenter(c,schema,maxSegmentSize);
}
/**
 * Updates the classifier for every segment of 'segments' that is missing
 * from 'otherSegments' or that appears there with a different preceding
 * class. For each such segment the classifier is updated by delta on the
 * instance built from the segment's example and the name of its preceding
 * class. Segments whose lo is negative are skipped.
 *
 * @param classifier the classifier to update
 * @param schema maps class indices to class names
 * @param segments the segmentation whose segments may trigger updates
 * @param otherSegments the segmentation to compare against
 * @param delta the update weight passed to the classifier
 * @param g the group that supplies the example for each segment
 * @return the number of segments that triggered an update
 */
private int compareSegmentsAndRevise(
	CollinsPerceptronLearner.MultiClassVPClassifier classifier,ExampleSchema schema,
	Segmentation segments,Segmentation otherSegments,double delta,CandidateSegmentGroup g)
{
	// name of the preceding class for each segment, in both segmentations
	Map<Segment,String> precedingClass = previousClassMap(segments,schema);
	Map<Segment,String> otherPrecedingClass = previousClassMap(otherSegments,schema);

	String[] history = new String[1];
	int updates = 0;
	Iterator<Segment> it = segments.iterator();
	while (it.hasNext()) {
		Segmentation.Segment segment = it.next();
		String previousClass = precedingClass.get(segment);
		if (segment.lo<0) {
			continue;
		}
		boolean agrees = otherSegments.contains(segment)
			&& otherPrecedingClass.get(segment).equals(previousClass);
		if (agrees) {
			continue;
		}
		updates++;
		history[0] = previousClass;
		Instance instance = new InstanceFromSequence( g.getSubsequenceExample(segment.lo,segment.hi), history);
		if (DEBUG) log.debug("update "+delta+" for: "+instance.getSource());
		String className = schema.getClassName( segment.y );
		classifier.update( className, instance, delta );
	}
	return updates;
}
/**
 * Maps each segment to the class name of the segment that precedes it in
 * iteration order; the first segment is mapped to NULL_CLASS_NAME.
 * The map is a TreeMap, so a lookup finds the entry for any segment that
 * is logically equivalent to the key, not only the identical object.
 *
 * @param segments the segmentation to index
 * @param schema maps class indices to class names
 * @return a map from segment to the name of the preceding segment's class
 */
private Map<Segment,String> previousClassMap(Segmentation segments,ExampleSchema schema)
{
	Map<Segment,String> precedingClassOf = new TreeMap<Segment,String>();
	Segmentation.Segment predecessor = null;
	Iterator<Segment> it = segments.iterator();
	while (it.hasNext()) {
		Segmentation.Segment current = it.next();
		String precedingClassName;
		if (predecessor!=null) {
			precedingClassName = schema.getClassName(predecessor.y);
		} else {
			precedingClassName = NULL_CLASS_NAME;
		}
		precedingClassOf.put(current, precedingClassName);
		predecessor = current;
	}
	return precedingClassOf;
}
/**
 * Builds the correct segmentation of a group. Scanning left to right, the
 * shortest window (of length 1..maxSegmentSize) that starts at the current
 * position, has an instance and carries a non-negative label becomes a
 * segment with that label's best class; if no such window exists, a
 * unit-length segment of the negative class is emitted instead. Scanning
 * then resumes right after the emitted segment.
 *
 * @param g the group to segment
 * @param schema maps class names to class indices
 * @param maxSegmentSize the longest window length that is tried
 * @return the segmentation implied by the labels in g
 */
private Segmentation correctSegments(CandidateSegmentGroup g,ExampleSchema schema,int maxSegmentSize)
{
	Segmentation result = new Segmentation(schema);
	int start = 0;
	while (start<g.getSequenceLength()) {
		int emittedLength = 0; // 0 means no labeled window was found at 'start'
		for (int width=1; width<=maxSegmentSize; width++) {
			int end = start+width;
			Instance candidate = g.getSubsequenceInstance(start,end);
			ClassLabel candidateLabel = g.getSubsequenceLabel(start,end);
			if (candidate==null || candidateLabel.isNegative()) {
				continue;
			}
			int classIndex = schema.getClassIndex(candidateLabel.bestClassName());
			result.add( new Segmentation.Segment(start,end,classIndex) );
			emittedLength = width;
			break;
		}
		if (emittedLength==0) {
			int negativeIndex = schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME);
			result.add( new Segmentation.Segment(start,start+1,negativeIndex) );
			emittedLength = 1;
		}
		start += emittedLength;
	}
	return result;
}
/**
 * Dynamic-programming search for the highest-scoring segmentation of a
 * candidate segment group under a given classifier.
 *
 * NOTE(review): every class is accepted as the "previous class" at
 * position 0 (all start cells score 0), whereas training builds the
 * history of the first segment from NULL_CLASS_NAME - confirm that this
 * asymmetry is intended.
 */
static public class ViterbiSearcher
{
	private Classifier classifier;
	private ExampleSchema schema;
	private int maxSegmentSize;

	/**
	 * @param classifier scores a segment instance (with one-element class history) for every class
	 * @param schema defines the classes and their indices
	 * @param maxSegmentSize longest segment considered for any class other than the negative class
	 */
	public ViterbiSearcher(Classifier classifier,ExampleSchema schema,int maxSegmentSize)
	{
		this.classifier = classifier;
		this.schema = schema;
		this.maxSegmentSize = maxSegmentSize;
	}

	/**
	 * Finds the best segmentation of the whole sequence in g.
	 *
	 * fty[t][y] holds the highest score of any segmentation of positions
	 * 0..t whose last segment has class y; trace[t][y] records the last
	 * segment of that segmentation. Segments of the negative class are
	 * limited to length 1, all others to maxSegmentSize.
	 *
	 * @param g the group to segment
	 * @return the best segmentation, or an empty segmentation if the
	 *         sequence is empty or no chain of available windows covers it
	 */
	public Segmentation bestSegments(CandidateSegmentGroup g)
	{
		int seqLen = g.getSequenceLength();
		int ny = schema.getNumberOfClasses();
		int backgroundClass = schema.getClassIndex( ExampleSchema.NEG_CLASS_NAME );

		// per-class lookups that do not change during the search
		String[] classNames = new String[ny];
		int[] maxLenForY = new int[ny];
		int widestWindow = 0;
		for (int y=0; y<ny; y++) {
			classNames[y] = schema.getClassName(y);
			maxLenForY[y] = y==backgroundClass ? 1 : maxSegmentSize;
			widestWindow = Math.max(widestWindow, maxLenForY[y]);
		}

		// Unreachable cells are marked with NEGATIVE_INFINITY rather than a large
		// finite number: a finite sentinel can out-score a genuine but very negative
		// path, and sentinel+score can out-score the sentinel itself, which would
		// create back-pointer chains that never reach position 0.
		double[][] fty = new double[seqLen+1][ny];
		BackPointer[][] trace = new BackPointer[seqLen+1][ny];
		for (int t=0; t<=seqLen; t++) {
			for (int y=0; y<ny; y++) {
				fty[t][y] = t==0 ? 0 : Double.NEGATIVE_INFINITY;
				trace[t][y] = null;
			}
		}

		// fill the matrix; nothing can end at t==0, so start at t==1
		String[] history = new String[1];
		for (int t=1; t<=seqLen; t++) {
			for (int lastY=0; lastY<ny; lastY++) {
				history[0] = classNames[lastY];
				for (int lastT=Math.max(0, t-widestWindow); lastT<t; lastT++) {
					double prefixScore = fty[lastT][lastY];
					if (prefixScore==Double.NEGATIVE_INFINITY) continue; // no path ends here
					Instance segmentInstance = g.getSubsequenceInstance(lastT, t);
					if (segmentInstance==null) continue; // window not available in g
					// The classification depends on the window and the previous class
					// only, so it is computed once and shared by all candidate classes y.
					ClassLabel segmentLabel =
						classifier.classification(new InstanceFromSequence(segmentInstance,history));
					int segmentLength = t-lastT;
					for (int y=0; y<ny; y++) {
						if (segmentLength>maxLenForY[y]) continue;
						double score = segmentLabel.getWeight( classNames[y] ) + prefixScore;
						// strict '>' keeps the first maximum in (lastY,lastT) order
						if (score > fty[t][y]) {
							fty[t][y] = score;
							trace[t][y] = new BackPointer(lastT,t,lastY);
						}
					}
				}
			}
		}

		// pick the best class for the segment that ends at the end of the sequence
		int bestEndY = -1;
		double bestEndYScore = Double.NEGATIVE_INFINITY;
		for (int y=0; y<ny; y++) {
			if (fty[seqLen][y] > bestEndYScore) {
				bestEndYScore = fty[seqLen][y];
				bestEndY = y;
			}
		}

		// follow the back pointers from the end of the sequence to position 0;
		// bestEndY stays -1 when the end of the sequence is unreachable
		Segmentation result = new Segmentation(schema);
		if (bestEndY>=0) {
			int y = bestEndY;
			for (BackPointer bp = trace[seqLen][y]; bp!=null; bp=trace[bp.lastT][bp.lastY]) {
				bp.onBestPath = true;
				result.add( new Segmentation.Segment(bp.lastT,bp.t,y) );
				y = bp.lastY;
			}
		}
		if (DEBUG) dumpStuff(g,fty,trace);
		return result;
	}
}
/**
 * One cell of the Viterbi trace: the last segment of the best path into
 * that cell, together with the class of the segment before it.
 */
private static class BackPointer {
	/** Start position of the segment. */
	public int lastT;
	/** End position of the segment. */
	public int t;
	/** Class index of the preceding segment. */
	public int lastY;
	/** Set to true by the searcher for pointers on the winning path. */
	public boolean onBestPath = false;

	public BackPointer(int lastT, int t,int lastY) {
		this.lastT = lastT;
		this.t = t;
		this.lastY = lastY;
	}
}
/**
 * Prints the Viterbi matrix to standard output, one line per cell: the
 * cell index, its score, the cell its back pointer leads to, the source of
 * the example spanned by the back pointer, and a marker for cells on the
 * best path. Cells without a back pointer are shown as -1.-1 with a
 * "*NULL*" example.
 *
 * @param g the group the matrix was computed for
 * @param fty the score matrix, indexed by position and class
 * @param trace the back pointers, indexed like fty
 */
private static void dumpStuff(CandidateSegmentGroup g, double[][] fty, BackPointer[][] trace)
{
	Example placeholder = new Example(new MutableInstance("*NULL*"),new ClassLabel("*NULL*"));
	java.text.DecimalFormat scoreFormat = new java.text.DecimalFormat("####.###");
	System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
	for (int t=0; t<fty.length; t++) {
		for (int y=0; y<fty[t].length; y++) {
			BackPointer pointer = trace[t][y];
			Example spanned = placeholder;
			int fromT = -1;
			int fromY = -1;
			boolean onBestPath = false;
			if (pointer!=null) {
				spanned = g.getSubsequenceExample(pointer.lastT,pointer.t);
				fromT = pointer.lastT;
				fromY = pointer.lastY;
				onBestPath = pointer.onBestPath;
			}
			StringBuilder line = new StringBuilder();
			line.append(t).append(".").append(y).append("\t");
			line.append(scoreFormat.format(fty[t][y])).append("\t");
			line.append(fromT).append(".").append(fromY).append("\t'");
			line.append(spanned.getSource()).append("' ");
			if (onBestPath) {
				line.append("<==");
			}
			System.out.println(line.toString());
		}
	}
}
/**
 * Segmenter that runs a ViterbiSearcher with a fixed classifier, schema
 * and maximum segment size. Field names are part of the serialized form
 * and should not be renamed.
 */
public static class ViterbiSegmenter implements Segmenter,Visible,Serializable
{
	static private final long serialVersionUID = 20080207L;

	private Classifier c;
	private ExampleSchema schema;
	private int maxSegSize;

	/**
	 * @param c classifier used to score candidate segments
	 * @param schema the classes the classifier distinguishes
	 * @param maxSegSize maximum segment size handed to the searcher
	 */
	public ViterbiSegmenter(Classifier c,ExampleSchema schema,int maxSegSize)
	{
		this.c = c;
		this.schema = schema;
		this.maxSegSize = maxSegSize;
	}

	/** Returns the best segmentation of g found by a ViterbiSearcher. */
	@Override
	public Segmentation segmentation(CandidateSegmentGroup g)
	{
		ViterbiSearcher searcher = new ViterbiSearcher(c,schema,maxSegSize);
		return searcher.bestSegments(g);
	}

	/** Explanations are not supported; a fixed placeholder string is returned. */
	@Override
	public String explain(CandidateSegmentGroup g)
	{
		return "not implemented yet";
	}

	/**
	 * Builds a viewer showing the maximum segment size above a view of the
	 * underlying classifier.
	 */
	@Override
	public Viewer toGUI()
	{
		Viewer viewer = new ComponentViewer() {
			static final long serialVersionUID=20080207L;
			@Override
			public JComponent componentFor(Object o) {
				ViterbiSegmenter segmenter = (ViterbiSegmenter)o;
				JPanel panel = new JPanel(new BorderLayout());
				JLabel heading = new JLabel("ViterbiSegmenter: maxSegSize="+segmenter.maxSegSize);
				panel.add(heading,BorderLayout.NORTH);
				// the classifier gets its own sub-view, attached to this viewer
				Viewer classifierView = new SmartVanillaViewer(segmenter.c);
				classifierView.setSuperView(this);
				panel.add(classifierView,BorderLayout.SOUTH);
				panel.setBorder(new TitledBorder("ViterbiSegmenter"));
				return new JScrollPane(panel);
			}
		};
		viewer.setContent(this);
		return viewer;
	}
}
}