/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.sequential;
import java.awt.BorderLayout;
import java.io.Serializable;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
/**
* A conditional markov model classifier.
*
* @author William Cohen
*/
public class CMM implements ConfidenceReportingSequenceClassifier,SequenceConstants,Visible,Serializable
{
    private static final long serialVersionUID = 20080207L;
    static Logger log = Logger.getLogger(CMM.class);
    /** Compile-time switch for the diagnostic printout in {@link #confidence}. */
    static private final boolean DEBUG = false;

    /** Performs the actual sequence labeling; built once in the constructor. */
    private final BeamSearcher searcher;
    private final int historySize;
    private final Classifier classifier;
    /**
     * Beam size reported by the GUI view. NOTE(review): this value is not passed to
     * {@link BeamSearcher} anywhere in this class — confirm it matches the searcher's setting.
     */
    private int beamSize = 10;

    /**
     * Creates a CMM backed by a {@link BeamSearcher} over the given classifier.
     *
     * @param classifier  the per-token classifier used by the searcher
     * @param historySize the history size handed to the searcher
     * @param schema      the example schema handed to the searcher
     */
    public CMM(Classifier classifier,int historySize,ExampleSchema schema)
    {
        this.searcher = new BeamSearcher(classifier,historySize,schema);
        this.classifier = classifier;
        this.historySize = historySize;
    }

    /** @return the classifier this CMM was constructed with */
    public Classifier getClassifier() { return classifier; }

    /** @return the history size this CMM was constructed with */
    public int getHistorySize() { return historySize; }

    /**
     * Labels a sequence using the beam searcher's best label sequence.
     *
     * @param sequence the instances to label
     * @return one predicted label per instance, as returned by the searcher
     */
    @Override
    public ClassLabel[] classification(Instance[] sequence)
    {
        return searcher.bestLabelSequence(sequence);
    }

    /**
     * Measures how much the predicted labeling outscores the best labeling found
     * by a beam search constrained with {@code alternateClasses}.
     *
     * @param sequence         the instances being labeled
     * @param predictedClasses the predicted labels, parallel to {@code sequence}
     * @param alternateClasses the constraint labels for the search, parallel to {@code sequence}
     * @param lo               start of the subsequence of interest (inclusive)
     * @param hi               end of the subsequence of interest (exclusive); must exceed {@code lo}
     * @return total predicted weight of {@code predictedClasses} minus that of the
     *         constrained labeling; never negative
     * @throws IllegalArgumentException if the arrays are not parallel or {@code lo..hi}
     *         is not a non-empty range within the sequence
     * @throws IllegalStateException if the constrained labeling outscores the prediction
     */
    @Override
    public double confidence(Instance[] sequence,ClassLabel[] predictedClasses,ClassLabel[] alternateClasses,int lo,int hi)
    {
        if (predictedClasses.length!=alternateClasses.length || predictedClasses.length!=sequence.length) {
            throw new IllegalArgumentException("predictedClasses, alternateClasses, sequence should be parallel arrays");
        }
        if (lo<0 || lo>sequence.length || hi<0 || hi>sequence.length || hi<=lo) {
            throw new IllegalArgumentException("lo..hi must be define a subsequence");
        }
        // NOTE(review): lo/hi are validated but the weights below are summed over the
        // whole sequence, not just lo..hi — confirm this is the intended contract.
        searcher.doSearch(sequence,alternateClasses);
        ClassLabel[] constrainedPrediction = searcher.viterbi(0);
        double weightOfPrediction =
            ConfidenceUtils.sumPredictedWeights(predictedClasses,0,predictedClasses.length);
        double weightOfConstrainedPrediction =
            ConfidenceUtils.sumPredictedWeights(constrainedPrediction,0,constrainedPrediction.length);
        if (DEBUG) {
            for (int ii=0; ii<sequence.length; ii++) {
                System.out.println("pred="+predictedClasses[ii] +"\t"+
                                   "conp="+constrainedPrediction[ii] +"\t"+
                                   sequence[ii].getSource());
            }
            System.out.println(weightOfPrediction +"\t"+ weightOfConstrainedPrediction + "\t diff="+(weightOfPrediction-weightOfConstrainedPrediction));
        }
        // A constrained search should never beat the unconstrained prediction.
        if (weightOfConstrainedPrediction>weightOfPrediction) {
            throw new IllegalStateException("constrained beam search should have returned a lower-scoring prediction?");
        }
        return weightOfPrediction-weightOfConstrainedPrediction;
    }

    /**
     * Returns the searcher's textual explanation of how {@code sequence} is labeled.
     *
     * @param sequence the instances to explain
     * @return the explanation text produced by the beam searcher
     */
    @Override
    public String explain(Instance[] sequence)
    {
        return searcher.explain(sequence);
    }

    /**
     * Builds a structured explanation rooted at a "CMM Explanation" node. The child
     * node is the searcher's explanation tree when available, otherwise a node
     * wrapping the searcher's plain-text explanation.
     *
     * @param sequence the instances to explain
     * @return a new explanation tree
     */
    @Override
    public Explanation getExplanation(Instance[] sequence) {
        Explanation.Node top = new Explanation.Node("CMM Explanation");
        Explanation searcherExplanation = searcher.getExplanation(sequence);
        Explanation.Node searcherEx =
            (searcherExplanation == null) ? null : searcherExplanation.getTopNode();
        if (searcherEx == null) {
            // Fall back to the plain-text explanation when no tree is available.
            searcherEx = new Explanation.Node(searcher.explain(sequence));
        }
        top.add(searcherEx);
        return new Explanation(top);
    }

    /**
     * Creates a Swing viewer showing this CMM's parameters above a view of its
     * underlying classifier.
     *
     * @return a viewer whose content is this CMM
     */
    @Override
    public Viewer toGUI()
    {
        Viewer v = new ComponentViewer() {
            static final long serialVersionUID=20080207L;
            @Override
            public JComponent componentFor(Object o) {
                CMM cmm = (CMM)o;
                JPanel mainPanel = new JPanel();
                mainPanel.setLayout(new BorderLayout());
                // Read every displayed field from the viewed object, not the enclosing instance.
                mainPanel.add(
                    new JLabel("CMM: historySize="+cmm.historySize+" beamSize="+cmm.beamSize),
                    BorderLayout.NORTH);
                Viewer subView = new SmartVanillaViewer(cmm.classifier);
                subView.setSuperView(this);
                mainPanel.add(subView,BorderLayout.SOUTH);
                mainPanel.setBorder(new TitledBorder("Conditional Markov Model"));
                return new JScrollPane(mainPanel);
            }
        };
        v.setContent(this);
        return v;
    }
}