package edu.cmu.minorthird.classify.sequential;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.ProgressCounter;

/**
 * Sequential learner based on the perceptron algorithm that takes the
 * top-k viterbi paths and subtracts those whose score lies within a
 * margin of beta of the score of the correct label sequence.
 *
 * <p>For every training sequence the learner runs a beam search, keeps the
 * incorrect paths among the first {@code topK} solutions whose score is at
 * least {@code corrScore * (1 - beta)}, and then, at every position where one
 * of those paths disagrees with the truth (at the position itself or within
 * the history window), adds the correct instance with weight {@code 1.0} and
 * subtracts each kept path's instance with weight {@code 1.0 / #paths}.
 *
 * @author Sunita Sarawagi
 */
public class MarginPerceptronLearner extends CollinsPerceptronLearner {

  private static final int DEFAULT_HISTORY_SIZE = 3;
  private static final int DEFAULT_NUMBER_OF_EPOCHS = 5;
  private static final float DEFAULT_BETA = (float) 0.05;
  private static final int DEFAULT_TOP_K = 10;

  /**
   * Relative margin: a beam-search path is considered only while its score is
   * at least {@code corrScore * (1 - beta)}.
   */
  float beta = DEFAULT_BETA;

  /** Upper bound on the number of beam-search solutions examined per sequence. */
  int topK = DEFAULT_TOP_K;

  /** Creates a learner with history size 3, 5 epochs, beta 0.05 and top-k 10. */
  public MarginPerceptronLearner() {
    this(DEFAULT_HISTORY_SIZE, DEFAULT_NUMBER_OF_EPOCHS, DEFAULT_BETA);
  }

  /**
   * Creates a learner with history size 3, beta 0.05 and top-k 10.
   *
   * @param numberOfEpochs maximum number of passes over the training data
   */
  public MarginPerceptronLearner(int numberOfEpochs) {
    this(DEFAULT_HISTORY_SIZE, numberOfEpochs, DEFAULT_BETA);
  }

  /**
   * Creates a learner with top-k 10.
   *
   * @param historySize number of previous labels used as history
   * @param numberOfEpochs maximum number of passes over the training data
   * @param beta relative margin used to select competing paths
   */
  public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta) {
    this(historySize, numberOfEpochs, beta, DEFAULT_TOP_K);
  }

  /**
   * Creates a fully specified learner.
   *
   * @param historySize number of previous labels used as history
   * @param numberOfEpochs maximum number of passes over the training data
   * @param beta relative margin used to select competing paths
   * @param topK maximum number of beam-search solutions examined per sequence
   */
  public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta, int topK) {
    super(historySize, numberOfEpochs);
    this.beta = beta;
    this.topK = topK;
  }

  /**
   * Trains a voted-perceptron classifier over the sequences of {@code dataset}
   * and wraps it in a {@link CMM}.
   *
   * <p>Training stops early after the first epoch in which no position needed
   * an update. Per-epoch error statistics are printed to standard output.
   *
   * @param dataset the labeled training sequences
   * @return a CMM built from the trained classifier, the history size and the
   *     dataset's schema
   */
  @Override
  public SequenceClassifier batchTrain(SequenceDataset dataset) {
    ExampleSchema schema = dataset.getSchema();
    MultiClassVPClassifier c = new MultiClassVPClassifier(schema);
    ProgressCounter pc =
        new ProgressCounter(
            "training sequence perceptron",
            "sequence",
            getNumberOfEpochs() * dataset.numberOfSequences());

    // Scratch list of competing (incorrect) paths; refilled for every sequence.
    List<ClassLabel[]> viterbiS = new ArrayList<ClassLabel[]>();

    for (int epoch = 0; epoch < getNumberOfEpochs(); epoch++) {
      // statistics for curious researchers
      int sequenceErrors = 0;
      int transitionErrors = 0;
      int transitions = 0;

      for (Iterator<Example[]> i = dataset.sequenceIterator(); i.hasNext(); ) {
        Example[] sequence = i.next();

        collectPathsWithinMargin(sequence, c, schema, viterbiS);
        if (DEBUG) {
          log.debug("added: " + viterbiS.size());
        }

        int errorsOnThisSequence = 0;
        if (!viterbiS.isEmpty()) {
          errorsOnThisSequence = applyMarginUpdates(sequence, viterbiS, c);
        }

        // The voted perceptron needs completeUpdate() after every sequence,
        // whether or not any update was made for it.
        c.completeUpdate();

        if (errorsOnThisSequence > 0) {
          sequenceErrors++;
        }
        transitionErrors += errorsOnThisSequence;
        transitions += sequence.length;
        pc.progress();
      } // sequence i

      System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors
          +" transitionErrors="+transitionErrors+"/"+transitions);

      if (transitionErrors == 0) {
        break;
      }
    } // epoch

    pc.finished();
    c.setVoteMode(true);

    // The trained classifier is handed to a CMM together with the same history
    // size and schema that were used for the beam search during training.
    return new CMM(c, getHistorySize(), schema);
  }

  /**
   * Runs a beam search over {@code sequence} and stores in {@code viterbiS}
   * the incorrect paths among the first {@code topK} solutions whose score is
   * at least {@code corrScore * (1 - beta)}. Scanning stops at the first
   * solution that falls below that threshold.
   *
   * <p>NOTE(review): the threshold is relative to the correct score, so for a
   * negative correct score (and 0 &lt; beta &lt; 1) it lies above the correct
   * score itself -- confirm this is the intended margin.
   */
  private void collectPathsWithinMargin(
      Example[] sequence,
      MultiClassVPClassifier c,
      ExampleSchema schema,
      List<ClassLabel[]> viterbiS) {
    BeamSearcher beam = new BeamSearcher(c, getHistorySize(), schema);
    beam.doSearch(sequence);

    float corrScore = getScore(sequence, c);
    if (DEBUG) {
      log.debug("corrScore: " + corrScore);
    }

    viterbiS.clear();
    float threshold = corrScore * (1 - beta);
    int maxNum = Math.min(beam.getNumberOfSolutionsFound(), topK);
    for (int k = 0; k < maxNum; k++) {
      ClassLabel[] viterbi = beam.viterbi(k);
      float thisScore = beam.score(k);
      if (DEBUG) {
        log.debug("viterbi: "+k + " score " + thisScore);
        log.debug(sequenceToString(viterbi));
      }
      if (thisScore < threshold) {
        break;
      }
      if (!isCorrect(viterbi, sequence)) {
        viterbiS.add(viterbi);
      }
    }
  }

  /**
   * Updates {@code c} at every position where at least one path in
   * {@code viterbiS} disagrees with the truth: the correct instance is added
   * with weight 1.0 and each path's instance is subtracted with weight
   * 1.0 / (number of paths).
   *
   * @return the number of positions at which an update was made
   */
  private int applyMarginUpdates(
      Example[] sequence, List<ClassLabel[]> viterbiS, MultiClassVPClassifier c) {
    // Every competing path shares the same (negative) update weight.
    double wrongWeight = -1.0 / viterbiS.size();
    int updatedPositions = 0;

    for (int j = 0; j < sequence.length; j++) {
      if (!anyPathDiffersAt(viterbiS, sequence, j)) {
        continue;
      }
      // i.e., phi(sequence,j) != phi(viterbi,j) for some competing path
      updatedPositions++;

      InstanceFromSequence.fillHistory(history, sequence, j);
      Instance correctXj = new InstanceFromSequence(sequence[j], history);
      c.update(sequence[j].getLabel().bestClassName(), correctXj, 1.0);

      for (ClassLabel[] viterbi : viterbiS) {
        InstanceFromSequence.fillHistory(history, viterbi, j);
        Instance wrongXj = new InstanceFromSequence(sequence[j], history);
        c.update(viterbi[j].bestClassName(), wrongXj, wrongWeight);
      }
    } // example sequence j

    return updatedPositions;
  }

  /** Returns true if any path in {@code viterbiS} differs from the truth around position j. */
  private boolean anyPathDiffersAt(List<ClassLabel[]> viterbiS, Example[] sequence, int j) {
    for (ClassLabel[] viterbi : viterbiS) {
      if (pathDiffersAt(viterbi, sequence, j)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Returns true if {@code viterbi} mislabels position j or any of the up to
   * {@code getHistorySize()} positions preceding it.
   */
  private boolean pathDiffersAt(ClassLabel[] viterbi, Example[] sequence, int j) {
    int lookBack = Math.min(j, getHistorySize());
    for (int k = 0; k <= lookBack; k++) {
      if (!viterbi[j - k].isCorrect(sequence[j - k].getLabel())) {
        return true;
      }
    }
    return false;
  }

  /**
   * Scores the correct labeling of {@code sequence}: the sum, over all
   * positions, of the weight {@code classifier} assigns to the true label of
   * the instance built from the example and its true label history.
   */
  float getScore(Example[] sequence, Classifier classifier) {
    float score = 0;
    for (int j = 0; j < sequence.length; j++) {
      InstanceFromSequence.fillHistory(history, sequence, j);
      Instance correctXj = new InstanceFromSequence(sequence[j], history);
      String trueClass = sequence[j].getLabel().bestClassName();
      // Compound assignment keeps the running sum in float precision.
      score += classifier.classification(correctXj).getWeight(trueClass);
    }
    return score;
  }

  /** Returns true if {@code viterbi} is correct at every position of {@code sequence}. */
  boolean isCorrect(ClassLabel[] viterbi, Example[] sequence) {
    for (int j = 0; j < sequence.length; j++) {
      if (!viterbi[j].isCorrect(sequence[j].getLabel())) {
        return false;
      }
    }
    return true;
  }

  /** Renders a path as its best class names, each followed by a single space. */
  String sequenceToString(ClassLabel[] viterbi) {
    StringBuilder path = new StringBuilder();
    for (int j = 0; j < viterbi.length; j++) {
      path.append(viterbi[j].bestClassName()).append(" ");
    }
    return path.toString();
  }
}