/* Copyright 2003-2004, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.sequential; import edu.cmu.minorthird.classify.*; import edu.cmu.minorthird.classify.algorithms.linear.*; import edu.cmu.minorthird.util.*; import java.util.*; import org.apache.log4j.*; /** * 'Generic' version of Collin's voted perceptron learner. * * <p>As of May 9, 2004, this is a different algorithm, which is much * more like Collin's original method. The 'old' implementation is in * GenericCollinsLearnerV1. * * @author William Cohen */ public class GenericCollinsLearner implements BatchSequenceClassifierLearner,SequenceConstants { private static Logger log = Logger.getLogger(CollinsPerceptronLearner.class); private static final boolean DEBUG = log.isDebugEnabled(); private OnlineClassifierLearner innerLearnerPrototype; private OnlineClassifierLearner[] innerLearner; private int historySize; private int numberOfEpochs; private String[] history; public GenericCollinsLearner() { this(new MarginPerceptron(0.0,false,true)); } public GenericCollinsLearner(OnlineClassifierLearner innerLearner) { this(innerLearner,5); } public GenericCollinsLearner(int epochs) { this(new MarginPerceptron(0.0,false,true),epochs); } public GenericCollinsLearner(OnlineClassifierLearner innerLearner,int epochs) { this(innerLearner,3,epochs); } public GenericCollinsLearner(OnlineClassifierLearner innerLearner,int historySize,int epochs) { this.historySize = historySize; this.innerLearnerPrototype = innerLearner; this.numberOfEpochs = epochs; this.history = new String[historySize]; } @Override public void setSchema(ExampleSchema schema) { ; } // // accessors // public OnlineClassifierLearner getInnerLearner() { return innerLearnerPrototype; } public void setInnerLearner(OnlineClassifierLearner newInnerLearner) { this.innerLearnerPrototype = newInnerLearner; } @Override public int getHistorySize() { return historySize; } public void setHistorySize(int newHistorySize) { this.historySize = newHistorySize; } public int getNumberOfEpochs() { return numberOfEpochs; } public void setNumberOfEpochs(int newNumberOfEpochs) { this.numberOfEpochs = newNumberOfEpochs; } @Override public SequenceClassifier batchTrain(SequenceDataset dataset) { ExampleSchema schema = dataset.getSchema(); innerLearner = SequenceUtils.duplicatePrototypeLearner(innerLearnerPrototype,schema.getNumberOfClasses()); ProgressCounter pc = new ProgressCounter("training sequential "+innerLearnerPrototype.toString(), "sequence",numberOfEpochs*dataset.numberOfSequences()); for (int epoch=0; epoch<numberOfEpochs; epoch++) { dataset.shuffle(); // statistics for curious researchers int sequenceErrors = 0; int transitionErrors = 0; int transitions = 0; for (Iterator<Example[]> i=dataset.sequenceIterator(); i.hasNext(); ) { Example[] sequence = i.next(); Classifier c = new SequenceUtils.MultiClassClassifier(schema,innerLearner); ClassLabel[] viterbi = new BeamSearcher(c,historySize,schema).bestLabelSequence(sequence); if (DEBUG) log.debug("classifier: "+c); if (DEBUG) log.debug("viterbi:\n"+StringUtil.toString(viterbi)); boolean errorOnThisSequence=false; // accumulate weights for transitions associated with each class k Hyperplane[] accumPos = new Hyperplane[schema.getNumberOfClasses()]; Hyperplane[] accumNeg = new Hyperplane[schema.getNumberOfClasses()]; for (int k=0; k<schema.getNumberOfClasses(); k++) { accumPos[k] = new Hyperplane(); accumNeg[k] = new Hyperplane(); } for (int j=0; j<sequence.length; j++) { // is the instance at sequence[j] associated with a difference in the sum // of feature values over the viterbi sequence and the actual one? boolean differenceAtJ = !viterbi[j].isCorrect( sequence[j].getLabel() ); //System.out.println("differenceAtJ for J="+j+" "+differenceAtJ+" - label"); for (int k=1; j-k>=0 && !differenceAtJ && k<=historySize; k++) { if (!viterbi[j-k].isCorrect( sequence[j-k].getLabel() )) { //System.out.println("differenceAtJ for J="+j+" true: k="+k); differenceAtJ = true; } } if (differenceAtJ) { transitionErrors++; errorOnThisSequence=true; InstanceFromSequence.fillHistory( history, sequence, j ); Instance correctXj = new InstanceFromSequence( sequence[j], history ); int correctClassIndex = schema.getClassIndex( sequence[j].getLabel().bestClassName() ); accumPos[correctClassIndex].increment( correctXj, +1.0 ); accumNeg[correctClassIndex].increment( correctXj, -1.0 ); if (DEBUG) log.debug("+ update "+sequence[j].getLabel().bestClassName()+" "+correctXj.getSource()+";"+correctXj); InstanceFromSequence.fillHistory( history, viterbi, j ); Instance wrongXj = new InstanceFromSequence( sequence[j], history ); int wrongClassIndex = schema.getClassIndex( viterbi[j].bestClassName() ); accumPos[wrongClassIndex].increment( wrongXj, -1.0 ); accumNeg[wrongClassIndex].increment( wrongXj, +1.0 ); if (DEBUG) log.debug("- update "+viterbi[j].bestClassName()+" "+wrongXj.getSource()); } } // example sequence j if (errorOnThisSequence) { sequenceErrors++; String subPopId = sequence[0].getSubpopulationId(); Object source = "no source"; for (int k=0; k<schema.getNumberOfClasses(); k++) { //System.out.println("adding class="+k+" example: "+accumPos[k]); innerLearner[k].addExample( new Example( new HyperplaneInstance(accumPos[k],subPopId,source), ClassLabel.positiveLabel(+1.0) )); innerLearner[k].addExample( new Example( new HyperplaneInstance(accumNeg[k],subPopId,source), ClassLabel.negativeLabel(-1.0) )); } } transitions += sequence.length; pc.progress(); } // sequence i System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors +" transitionErrors="+transitionErrors+"/"+transitions); if (transitionErrors==0) break; } // epoch pc.finished(); for (int k=0; k<schema.getNumberOfClasses(); k++) { innerLearner[k].completeTraining(); } Classifier c = new SequenceUtils.MultiClassClassifier(schema,innerLearner); // we can use a CMM here, since the classifier is constructed so // that the same beam search will work return new CMM(c, historySize, schema ); } }