package edu.cmu.minorthird.classify.sequential; import edu.cmu.minorthird.classify.*; import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane; import edu.cmu.minorthird.util.*; import edu.cmu.minorthird.util.gui.*; import javax.swing.*; import javax.swing.border.TitledBorder; import java.io.Serializable; import java.util.Iterator; import org.apache.log4j.*; /** * Sequential learner based on the perceptron algorithm, as described * in Discriminative Training Methods for Hidden Markov Models: Theory * and Experiments with Perceptron Algorithms, Mike Collins, EMNLP * 2002. * * @author William Cohen */ public class CollinsPerceptronLearner implements BatchSequenceClassifierLearner,SequenceConstants { protected static Logger log = Logger.getLogger(CollinsPerceptronLearner.class); protected static final boolean DEBUG = log.isDebugEnabled(); protected int historySize; protected int numberOfEpochs; protected String[] history; public CollinsPerceptronLearner() { this(3,5); } public CollinsPerceptronLearner(int numberOfEpochs) { this(3,numberOfEpochs); } public CollinsPerceptronLearner(int historySize,int numberOfEpochs) { this.historySize = historySize; this.numberOfEpochs = numberOfEpochs; this.history = new String[historySize]; } public int getNumberOfEpochs() { return numberOfEpochs; } public void setNumberOfEpochs(int newNumberOfEpochs) { this.numberOfEpochs = newNumberOfEpochs; } @Override public int getHistorySize() { return historySize; } public void setHistorySize(int newHistorySize) { this.historySize = newHistorySize; } // Help Button public String getHistorySizeHelp() { return "Number of tokens to look back on. <br>The predicted labels for the history are used as features to help classify the current token." ;} @Override public void setSchema(ExampleSchema schema) { ; } @Override public SequenceClassifier batchTrain(SequenceDataset dataset) { ExampleSchema schema = dataset.getSchema(); MultiClassVPClassifier c = new MultiClassVPClassifier(schema); ProgressCounter pc = new ProgressCounter("training sequence perceptron","sequence",numberOfEpochs*dataset.numberOfSequences()); for (int epoch=0; epoch<numberOfEpochs; epoch++) { dataset.shuffle(); // statistics for curious researchers int sequenceErrors = 0; int transitionErrors = 0; int transitions = 0; for (Iterator<Example[]> i=dataset.sequenceIterator(); i.hasNext(); ) { Example[] sequence = i.next(); ClassLabel[] viterbi = new BeamSearcher(c,historySize,schema).bestLabelSequence(sequence); if (DEBUG) log.debug("classifier: "+c); if (DEBUG) log.debug("viterbi:\n"+StringUtil.toString(viterbi)); // At this point, Collin's paper says to add Phi(sequence) - // Phi(viterbi) to the current weight vector W. We're doing // this, with two twists: (a) the features in our instance // vectors phi(sequence,i) vectors are not paired with class // labels. Instead, we compute class-label independent features // and then attach the class label when we 'update' c. // (b) rather than computing Phi(sequence), Phi(viterbi), and // subtracting, we compute the difference directly. boolean errorOnThisSequence=false; for (int j=0; j<sequence.length; j++) { // is the instance at sequence[j] associated with a difference in the sum // of feature values over the viterbi sequence and the actual one? boolean differenceAtJ = !viterbi[j].isCorrect( sequence[j].getLabel() ); //System.out.println("differenceAtJ for J="+j+" "+differenceAtJ+" - label"); for (int k=1; j-k>=0 && !differenceAtJ && k<=historySize; k++) { if (!viterbi[j-k].isCorrect( sequence[j-k].getLabel() )) { //System.out.println("differenceAtJ for J="+j+" true: k="+k); differenceAtJ = true; } } if (differenceAtJ) { // i.e., if phi(sequence,j) != phi(viterbi,j) transitionErrors++; errorOnThisSequence=true; InstanceFromSequence.fillHistory( history, sequence, j ); Instance correctXj = new InstanceFromSequence( sequence[j], history ); c.update( sequence[j].getLabel().bestClassName(), correctXj, 1.0 ); if (DEBUG) log.debug("+ update "+sequence[j].getLabel().bestClassName()+" "+correctXj.getSource()); InstanceFromSequence.fillHistory( history, viterbi, j ); Instance wrongXj = new InstanceFromSequence( sequence[j], history ); c.update( viterbi[j].bestClassName(), wrongXj, -1.0 ); if (DEBUG) log.debug("- update "+viterbi[j].bestClassName()+" "+wrongXj.getSource()); } } // example sequence j // our computation of Phi(sequence)-Phi(viterbi) is complete - the voting scheme // for voted perceptron needs this... c.completeUpdate(); if (errorOnThisSequence) sequenceErrors++; transitions += sequence.length; pc.progress(); } // sequence i System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors +" transitionErrors="+transitionErrors+"/"+transitions); if (transitionErrors==0) break; } // epoch pc.finished(); c.setVoteMode(true); // we can use a CMM here, since the classifier is constructed to the same // beam search will work return new CMM(c, historySize, schema ); } public static class MultiClassVPClassifier implements Classifier,Visible,Serializable { static private final long serialVersionUID = 1; private ExampleSchema schema; private Hyperplane[] s_t, w_t; private int numClasses; private boolean voteMode = false; public MultiClassVPClassifier(ExampleSchema schema) { this.schema = schema; this.numClasses = schema.getNumberOfClasses(); reset(); } public void setVoteMode(boolean flag) { voteMode=flag; } public Hyperplane[] getHyperplanes() { return voteMode? s_t : w_t ; } public ExampleSchema getSchema() { return schema; } public void update(String className, Instance instance, double delta) { int index = schema.getClassIndex(className); w_t[index].increment( instance, delta ); } public void completeUpdate() { for (int i=0; i<numClasses; i++) { s_t[i].increment( w_t[i], 1.0 ); } } @Override public ClassLabel classification(Instance instance) { Hyperplane[] h = voteMode ? s_t : w_t ; ClassLabel label = new ClassLabel(); for (int i=0; i<numClasses; i++) { label.add( schema.getClassName(i), h[i].score(instance) ); } return label; } @Override public String explain(Instance instance) { Hyperplane[] h = voteMode ? s_t : w_t ; StringBuffer buf = new StringBuffer(""); for (int i=0; i<numClasses; i++) { buf.append("Hyperplane for class "+schema.getClassName(i)+":\n"); buf.append( h[i].explain(instance) ); buf.append("\n"); } return buf.toString(); } @Override public Explanation getExplanation(Instance instance) { Hyperplane[] h = voteMode ? s_t : w_t ; Explanation.Node top = new Explanation.Node("CollinsPerceptron Explanation"); for (int i=0; i<numClasses; i++) { Explanation.Node hyp = new Explanation.Node("Hyperplane for class "+schema.getClassName(i)+":\n"); Explanation.Node explanation = h[i].getExplanation(instance).getTopNode(); hyp.add(explanation); top.add(hyp); } Explanation ex = new Explanation(top); return ex; } @Override public Viewer toGUI() { Viewer gui = new ComponentViewer() { static final long serialVersionUID=20080207L; @Override public JComponent componentFor(Object o) { MultiClassVPClassifier c = (MultiClassVPClassifier)o; JPanel main = new JPanel(); for (int i=0; i<numClasses; i++) { JPanel classPanel = new JPanel(); classPanel.setBorder(new TitledBorder("Class "+c.schema.getClassName(i))); Viewer subviewer = voteMode ? s_t[i].toGUI() : w_t[i].toGUI(); subviewer.setSuperView( this ); classPanel.add( subviewer ); main.add(classPanel); } return new JScrollPane(main); } }; gui.setContent(this); return gui; } public void reset() { s_t = new Hyperplane[numClasses]; w_t = new Hyperplane[numClasses]; for (int i=0; i<numClasses; i++) { s_t[i] = new Hyperplane(); w_t[i] = new Hyperplane(); } } @Override public String toString() { return "[MultiClassVPClassifier:"+StringUtil.toString(w_t,"\n","\n]","\n - "); } } }