/* Copyright 2003-2004, Carnegie Mellon, All Rights Reserved */

package edu.cmu.minorthird.classify.sequential;

import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.log4j.Logger;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.sequential.Segmentation.Segment;
import edu.cmu.minorthird.util.ProgressCounter;

/**
 *
 * Semi-markov version of GenericCollinsLearner.
 *
 * @author William Cohen
 */

public class SegmentGenericCollinsLearner implements BatchSegmenterLearner,SequenceConstants
{
	// BUGFIX: was Logger.getLogger(CollinsPerceptronLearner.class), a copy-paste
	// slip that logged (and gated DEBUG) under the wrong class's category.
	private static Logger log = Logger.getLogger(SegmentGenericCollinsLearner.class);
	private static final boolean DEBUG = log.isDebugEnabled();

	/** Prototype learner, duplicated once per class at training time. */
	private OnlineClassifierLearner innerLearnerPrototype;
	/** One learner per class in the schema, created in batchTrain. */
	private OnlineClassifierLearner[] innerLearner;
	private int numberOfEpochs;
	/** Longest segment (in tokens) considered by the semi-markov search. */
	private int maxSegmentSize;

	public SegmentGenericCollinsLearner()
	{
		this(new MarginPerceptron(0.0,false,true));
	}

	public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner)
	{
		this(innerLearner,5);
	}

	public SegmentGenericCollinsLearner(int epochs)
	{
		this(new MarginPerceptron(0.0,false,true),epochs);
	}

	public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner,int epochs)
	{
		this(innerLearner,4,epochs);
	}

	public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner,int maxSegmentSize,int epochs)
	{
		this.maxSegmentSize = maxSegmentSize;
		this.innerLearnerPrototype = innerLearner;
		this.numberOfEpochs = epochs;
	}

	@Override
	public void setSchema(ExampleSchema schema)
	{
		// intentionally a no-op: the schema is taken from the dataset in batchTrain
	}

	//
	// accessors
	//

	public OnlineClassifierLearner getInnerLearner()
	{
		return innerLearnerPrototype;
	}

	public void setInnerLearner(OnlineClassifierLearner newInnerLearner)
	{
		this.innerLearnerPrototype = newInnerLearner;
	}

	public int getHistorySize()
	{
		return 1;
	}

	public int getNumberOfEpochs()
	{
		return numberOfEpochs;
	}

	public void setNumberOfEpochs(int newNumberOfEpochs)
	{
		this.numberOfEpochs = newNumberOfEpochs;
	}

	/**
	 * Train a semi-markov segmenter with the Collins voted-perceptron scheme:
	 * for each sequence, run Viterbi with the current model, compare against
	 * the correct segmentation, and feed the (positive/negative) accumulated
	 * feature differences to the per-class inner learners.
	 *
	 * @param dataset the segment-level training data
	 * @return a Viterbi-based segmenter wrapping the trained per-class classifiers
	 */
	@Override
	public Segmenter batchTrain(SegmentDataset dataset)
	{
		ExampleSchema schema = dataset.getSchema();
		innerLearner = SequenceUtils.duplicatePrototypeLearner(innerLearnerPrototype,schema.getNumberOfClasses());
		ProgressCounter pc =
			new ProgressCounter("training segments "+innerLearnerPrototype.toString(),
													"sequence",numberOfEpochs*dataset.getNumberOfSegmentGroups());

		for (int epoch=0; epoch<numberOfEpochs; epoch++) {
			//dataset.shuffle();

			// statistics for curious researchers
			int sequenceErrors = 0;
			int transitionErrors = 0;
			int transitions = 0;

			for (Iterator<CandidateSegmentGroup> i=dataset.candidateSegmentGroupIterator(); i.hasNext(); ) {

				Classifier c = new SequenceUtils.MultiClassClassifier(schema,innerLearner);
				if (DEBUG) log.debug("classifier is: "+c);
				CandidateSegmentGroup g = i.next();
				Segmentation viterbi =
					new SegmentCollinsPerceptronLearner.ViterbiSearcher(c,schema,maxSegmentSize).bestSegments(g);
				if (DEBUG) log.debug("viterbi "+maxSegmentSize+"\n"+viterbi);

				Segmentation correct = correctSegments(g,schema,maxSegmentSize);
				if (DEBUG) log.debug("correct segments:\n"+correct);

				boolean errorOnThisSequence=false;

				// accumulate weights for transitions associated with each class k
				Hyperplane[] accumPos = new Hyperplane[schema.getNumberOfClasses()];
				Hyperplane[] accumNeg = new Hyperplane[schema.getNumberOfClasses()];
				for (int k=0; k<schema.getNumberOfClasses(); k++) {
					accumPos[k] = new Hyperplane();
					accumNeg[k] = new Hyperplane();
				}

				// false positives: segments Viterbi proposed that are not correct
				int fp = compareSegmentsAndIncrement(schema, viterbi, correct, accumNeg, +1, g);
				if (fp>0) errorOnThisSequence = true;
				// false negatives: correct segments Viterbi missed
				int fn = compareSegmentsAndIncrement(schema, correct, viterbi, accumPos, +1, g);
				if (fn>0) errorOnThisSequence = true;
				transitionErrors += fp+fn;

				if (errorOnThisSequence) {
					// BUGFIX: sequenceErrors was previously incremented both here and
					// in a separate "if (errorOnThisSequence)" just above, so every
					// erroneous sequence was counted twice in the reported statistic.
					sequenceErrors++;
					String subPopId = g.getSubpopulationId();
					Object source = "no source";
					for (int k=0; k<schema.getNumberOfClasses(); k++) {
						innerLearner[k].addExample(
							new Example(
								new HyperplaneInstance(accumPos[k],subPopId,source),
								ClassLabel.positiveLabel(+1.0) ));
						innerLearner[k].addExample(
							new Example(
								new HyperplaneInstance(accumNeg[k],subPopId,source),
								ClassLabel.negativeLabel(-1.0) ));
					}
				}

				transitions += correct.size();
				pc.progress();

			} // sequence i

			System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors
												 +" transitionErrors="+transitionErrors+"/"+transitions);

			if (transitionErrors==0) break;

		} // epoch

		pc.finished();

		for (int k=0; k<schema.getNumberOfClasses(); k++) {
			innerLearner[k].completeTraining();
		}

		Classifier c = new SequenceUtils.MultiClassClassifier(schema,innerLearner);
		return new SegmentCollinsPerceptronLearner.ViterbiSegmenter(c, schema, maxSegmentSize);
	}

	/** Compare the target segments to the 'otherSegments', and update
	 * the classifier by sum_x [delta*x], for each example x
	 * corresponding to a target segment that's not in otherSegments.
	 *
	 * @param schema class schema, used to name the previous class of each segment
	 * @param segments the target segmentation
	 * @param otherSegments the segmentation compared against
	 * @param accum per-class hyperplanes to increment with mismatched segments
	 * @param delta weight applied to each mismatched segment's instance
	 * @param g source of the subsequence examples
	 * @return the number of target segments not matched (same span AND same
	 * previous class) in otherSegments
	 */
	private int compareSegmentsAndIncrement(
		ExampleSchema schema,Segmentation segments,Segmentation otherSegments,
		Hyperplane[] accum,double delta,CandidateSegmentGroup g)
	{
		int errors = 0;
		// first, work out the name of the previous class for each segment
		Map<Segment,String> map = previousClassMap(segments,schema);
		Map<Segment,String> otherMap = previousClassMap(otherSegments,schema);
		String[] history = new String[1];
		for (Iterator<Segment> j=segments.iterator(); j.hasNext(); ) {
			Segmentation.Segment seg = j.next();
			String previousClass = map.get(seg);
			// a segment is an error if the span is absent from otherSegments,
			// or present but preceded by a different class
			if (seg.lo>=0 && (!otherSegments.contains(seg) || !otherMap.get(seg).equals(previousClass))) {
				errors++;
				history[0] = previousClass;
				Instance instance = new InstanceFromSequence( g.getSubsequenceExample(seg.lo,seg.hi), history);
				if (DEBUG) log.debug("class "+schema.getClassName(seg.y)+" update "+delta+" for: "+instance.getSource());
				accum[seg.y].increment( instance, delta );
			}
		}
		return errors;
	}

	/** Build a mapping from segment to string name of previous segment.
	 * This should let you look up segments which are logically
	 * equivalent, as well as ones which are pointer-equivalent (==)
	 */
	private Map<Segment,String> previousClassMap(Segmentation segments,ExampleSchema schema)
	{
		// use a treemap so that logically equivalent segments be mapped to same previousClass
		Map<Segment,String> map = new TreeMap<Segment,String>();
		Segmentation.Segment previousSeg = null;
		for (Iterator<Segment> j=segments.iterator(); j.hasNext(); ) {
			Segmentation.Segment seg = j.next();
			// the first segment's "previous class" is the distinguished NULL class
			String previousClassName = previousSeg==null ? NULL_CLASS_NAME : schema.getClassName(previousSeg.y);
			map.put( seg, previousClassName);
			previousSeg = seg;
		}
		return map;
	}

	/** Collect the correct segments for this example. These are defined as
	 * all segments with non-NEGATIVE labels, and all unit-length negative labels
	 * not inside a positives label.
	 */
	private Segmentation correctSegments(CandidateSegmentGroup g,ExampleSchema schema,int maxSegmentSize)
	{
		Segmentation result = new Segmentation(schema);
		int pos, len;
		for (pos=0; pos<g.getSequenceLength(); ) {
			boolean addedASegmentStartingAtPos = false;
			// take the shortest non-negative segment starting at pos, if any
			for (len=1; !addedASegmentStartingAtPos && len<=maxSegmentSize; len++) {
				Instance inst = g.getSubsequenceInstance(pos,pos+len);
				ClassLabel label = g.getSubsequenceLabel(pos,pos+len);
				if (inst!=null && !label.isNegative()) {
					result.add( new Segmentation.Segment(pos,pos+len,schema.getClassIndex(label.bestClassName())) );
					addedASegmentStartingAtPos = true;
					pos += len;
				}
			}
			if (!addedASegmentStartingAtPos) {
				// no positive segment starts here: emit a unit-length negative segment
				result.add( new Segmentation.Segment(pos,pos+1,schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME)) );
				pos += 1;
			}
		}
		return result;
	}
}