package edu.cmu.minorthird.classify.sequential; import java.io.Serializable; import java.util.Iterator; import javax.swing.JComponent; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.border.TitledBorder; import edu.cmu.minorthird.classify.BinaryClassifier; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.Classifier; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.ExampleSchema; import edu.cmu.minorthird.classify.Explanation; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner; import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.Visible; /** * Generic version of Collin's voted perceptron learner. * * @author William Cohen */ public class GenericCollinsLearnerV1 implements BatchSequenceClassifierLearner,SequenceConstants { private OnlineBinaryClassifierLearner innerLearnerPrototype; private OnlineBinaryClassifierLearner[] innerLearner; private int historySize; private int numberOfEpochs; private String[] history; public GenericCollinsLearnerV1() { this(3,5); } public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner innerLearner,int historySize) { this(innerLearner,historySize,5); } public GenericCollinsLearnerV1(int historySize,int epochs) { this(new VotedPerceptron(),historySize,epochs); } public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner innerLearner,int historySize,int epochs) { this.historySize = historySize; this.innerLearnerPrototype = innerLearner; this.numberOfEpochs = epochs; this.history = new String[historySize]; } @Override public void setSchema(ExampleSchema schema) { ; } // // accessors // public OnlineBinaryClassifierLearner getInnerLearner() { return innerLearnerPrototype; } public void setInnerLearner(OnlineBinaryClassifierLearner newInnerLearner) { this.innerLearnerPrototype = newInnerLearner; } @Override public int getHistorySize() { return historySize; } public void setHistorySize(int newHistorySize) { this.historySize = newHistorySize; } public int getNumberOfEpochs() { return numberOfEpochs; } public void setNumberOfEpochs(int newNumberOfEpochs) { this.numberOfEpochs = newNumberOfEpochs; } @Override public SequenceClassifier batchTrain(SequenceDataset dataset) { ExampleSchema schema = dataset.getSchema(); try { innerLearner = new OnlineBinaryClassifierLearner[ schema.getNumberOfClasses() ]; for (int i=0; i<schema.getNumberOfClasses(); i++) { innerLearner[i] = (OnlineBinaryClassifierLearner)innerLearnerPrototype.copy(); innerLearner[i].reset(); } } catch (Exception ex) { throw new IllegalArgumentException("innerLearner must be cloneable"); } ProgressCounter pc = new ProgressCounter("training sequential "+innerLearnerPrototype.toString(), "sequence",numberOfEpochs*dataset.numberOfSequences()); for (int epoch=0; epoch<numberOfEpochs; epoch++) { // statistics for curious researchers int sequenceErrors = 0; int transitionErrors = 0; int transitions = 0; for (Iterator<Example[]> i=dataset.sequenceIterator(); i.hasNext(); ) { Example[] sequence = i.next(); Classifier c = new MultiClassClassifier(schema,innerLearner); ClassLabel[] viterbi = new BeamSearcher(c,historySize,schema).bestLabelSequence(sequence); boolean errorOnThisSequence=false; for (int j=0; j<sequence.length; j++) { // is the instance at sequence[j] associated with a difference in the sum // of feature values over the viterbi sequence and the actual one? boolean differenceAtJ = !viterbi[j].isCorrect( sequence[j].getLabel() ); for (int k=1; j-k>=0 && !differenceAtJ && k<=historySize; k++) { if (!viterbi[j-k].isCorrect( sequence[j-k].getLabel() )) { differenceAtJ = true; } } if (differenceAtJ) { transitionErrors++; errorOnThisSequence=true; InstanceFromSequence.fillHistory( history, sequence, j ); Instance correctXj = new InstanceFromSequence( sequence[j], history ); int correctClassIndex = schema.getClassIndex( sequence[j].getLabel().bestClassName() ); innerLearner[correctClassIndex].addExample( new Example( correctXj, ClassLabel.binaryLabel(1.0) ) ); InstanceFromSequence.fillHistory( history, viterbi, j ); Instance wrongXj = new InstanceFromSequence( sequence[j], history ); int wrongClassIndex = schema.getClassIndex( viterbi[j].bestClassName() ); innerLearner[wrongClassIndex].addExample( new Example( wrongXj, ClassLabel.binaryLabel(-1.0)) ); } } // example sequence j if (errorOnThisSequence) sequenceErrors++; transitions += sequence.length; pc.progress(); } // sequence i System.out.println("Epoch "+epoch+": sequenceErr="+sequenceErrors +" transitionErrors="+transitionErrors+"/"+transitions); if (transitionErrors==0) break; } // epoch pc.finished(); // we can use a CMM here, since the classifier is constructed to the same // beam search will work Classifier c = new MultiClassClassifier(schema,innerLearner); return new CMM(c, historySize, schema ); } public static class MultiClassClassifier implements Classifier,Visible,Serializable { static private final long serialVersionUID = 1; private ExampleSchema schema; private BinaryClassifier[] innerClassifier; private int numClasses; public MultiClassClassifier(ExampleSchema schema,BinaryClassifier[] learners) { this.schema = schema; this.numClasses = schema.getNumberOfClasses(); innerClassifier = learners; } public MultiClassClassifier(ExampleSchema schema,OnlineBinaryClassifierLearner[] innerLearner) { this.schema = schema; this.numClasses = schema.getNumberOfClasses(); innerClassifier = new BinaryClassifier[ numClasses ]; for (int i=0; i<numClasses; i++) { innerClassifier[i] = innerLearner[i].getBinaryClassifier(); } } @Override public ClassLabel classification(Instance instance) { ClassLabel label = new ClassLabel(); for (int i=0; i<numClasses; i++) { label.add( schema.getClassName(i), innerClassifier[i].score(instance) ); } return label; } @Override public String explain(Instance instance) { StringBuffer buf = new StringBuffer(""); for (int i=0; i<numClasses; i++) { buf.append("Classifier for class "+schema.getClassName(i)+":\n"); buf.append( innerClassifier[i].explain(instance) ); buf.append("\n"); } return buf.toString(); } @Override public Explanation getExplanation(Instance instance) { Explanation.Node top = new Explanation.Node("GenericCollins Explanation"); for (int i=0; i<numClasses; i++) { Explanation.Node classifier = new Explanation.Node("Classifier for class "+schema.getClassName(i)); Explanation.Node classEx = innerClassifier[i].getExplanation(instance).getTopNode(); classifier.add(classEx); top.add(classifier); } Explanation ex = new Explanation(top); return ex; } @Override public Viewer toGUI() { Viewer gui = new ComponentViewer() { static final long serialVersionUID=20080207L; @Override public JComponent componentFor(Object o) { MultiClassClassifier c = (MultiClassClassifier)o; JPanel main = new JPanel(); for (int i=0; i<numClasses; i++) { JPanel classPanel = new JPanel(); classPanel.setBorder(new TitledBorder("Class "+c.schema.getClassName(i))); Viewer subviewer = new SmartVanillaViewer( c.innerClassifier[i] ); subviewer.setSuperView( this ); classPanel.add( subviewer ); main.add(classPanel); } return new JScrollPane(main); } }; gui.setContent(this); return gui; } } }