package edu.cmu.minorthird.classify.sequential;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
//// so here when you call the MultiClassHMMClassifier, it's like return MultiClassHMMClassifier( dataset)
public class MultiClassHMMClassifier implements SequenceClassifier,SequenceConstants,Visible,Serializable
{
static final long serialVersionUID=20080207L;
private ExampleSchema schema;
public HMM hmmModel;
private int numStates;
private int numEmissions;
String[] state;
double[][] aprob;
double[][] eprob;
ArrayList<String[]> training_seq;
private Hashtable<String,String> dict_tok;
// private Hashtable<String,String> dict_tok2idx;
// private Hashtable<String,String> dict_idx2tok;
/* HMM needs the dataset, to build the Hashtables and init all the matrix*/
public MultiClassHMMClassifier(SequenceDataset dataset)
{
this.schema = dataset.getSchema();
/* schema.validClassNameSet has private access so we can't change it here,
but we'll add the state 'start' to it conceptually*/
// adding 1 state corresponding to 'start of sentence' will be done in hmm
this.numStates = schema.getNumberOfClasses();
this.state = new String[numStates];
for(int i=0;i<schema.getNumberOfClasses();i++){
state[i] = schema.getClassName(i);
}
this.dict_tok = new Hashtable<String,String>();
training_seq = new ArrayList<String[]>();
for (Iterator<Example[]> i=dataset.sequenceIterator(); i.hasNext(); ) {
Example[] sequence = i.next();
String[] tok = new String[sequence.length];
// int labels[] = new int[sequence.length];
int size;
String token;
for (int j=0; j<sequence.length; j++) {
// ClassLabel label = sequence[j].getLabel();
size = sequence[j].numericFeatureIterator().next().size();
token = sequence[j].numericFeatureIterator().next().getPart(size-1);
tok[j] = token;
if ( dict_tok.containsKey(token) ){
int cnt = Integer.parseInt(dict_tok.get(token));
cnt++;
dict_tok.put(token,String.valueOf(cnt));
}else{
dict_tok.put(token,"1");
}
// labels[j] = schema.getClassIndex(sequence[j].getLabel().bestClassName());
}
training_seq.add( tok);
}
dict_tok.put("UNSEEN","1");
this.numEmissions=dict_tok.size();
/* initHMM, could be estimate AB based on the dataset*/
aprob = new double[numStates][numStates];
eprob = new double[numStates][numEmissions];
hmmModel = new HMM(state, aprob, dict_tok, eprob);
}
/*baum welch for hmm*/
public void baumwelch( final double threshold) {
ArrayList<String[]> training_data = new ArrayList<String[]>( this.training_seq.size());
for( int i=0; i<training_seq.size();i++){
training_data.add( hmmModel.convert_Ob_seq( training_seq.get(i) ) );
}
hmmModel = HMM.baumwelch( training_data, this.state, this.dict_tok, threshold);
return;
}
// I think here you need the viterbi to get the ClassLabel[] for the instance[], something to
//take the place of h[i].score(instance)
@Override
public ClassLabel[] classification(Instance[] sequence)
{
ClassLabel[] label = new ClassLabel[sequence.length];
String[] ob_seq = new String[sequence.length];
for (int i=0; i<sequence.length; i++) {
int size = sequence[i].numericFeatureIterator().next().size();
ob_seq[i] = sequence[i].numericFeatureIterator().next().getPart(size-1);
System.out.println("ob_seq["+i+"] is "+ob_seq[i]);
}
// System.out.println("End of one call ");
String[] seq;
seq = hmmModel.convert_Ob_seq( ob_seq );
Viterbi vit = new Viterbi(hmmModel, seq);
//vit.print(new SystemOut());
String[] tag_seq = vit.getPath();
for (int i=0; i<tag_seq.length; i++) {
label[i]= new ClassLabel(tag_seq[i]);
System.out.println("tag_seq["+i+"] is "+tag_seq[i]);
}
// label[i]= new ClassLabel("NEG");
// label.add( schema.getClassName(i), h[i].score(instance) );
// System.out.println("Name of Classes is "+schema.getClassName(i));
return label;
}
/** Return some string that 'explains' the classification,
this function is also required to be re-written*/
@Override
public String explain(Instance[] instance)
{
StringBuffer buf = new StringBuffer("");
for (int i=0; i<numStates; i++) {
buf.append("Hyperplane for class "+schema.getClassName(i)+":\n");
// buf.append( h[i].explain(instance) );
buf.append("\n");
}
return buf.toString();
}
@Override
public Explanation getExplanation(Instance[] instance) {
Explanation.Node top = new Explanation.Node("MultiClassHMM Explanation");
for (int i=0; i<numStates; i++) {
Explanation.Node classEx = new Explanation.Node("Hyperplane for class "+schema.getClassName(i)+":\n");
top.add(classEx);
}
Explanation ex = new Explanation(top);
return ex;
}
/*This one is for visualization*/
@Override
public Viewer toGUI()
{
Viewer gui = new ComponentViewer() {
static final long serialVersionUID=20080207L;
@Override
public JComponent componentFor(Object o) {
MultiClassHMMClassifier c = (MultiClassHMMClassifier)o;
JPanel main = new JPanel();
for (int i=0; i<numStates; i++) {
JPanel classPanel = new JPanel();
classPanel.setBorder(new TitledBorder("Class "+c.schema.getClassName(i)));
// Viewer subviewer = voteMode ? s_t[i].toGUI() : w_t[i].toGUI();
// subviewer.setSuperView( this );
// classPanel.add( subviewer );
main.add(classPanel);
}
return new JScrollPane(main);
}
};
gui.setContent(this);
return gui;
}
/*This one could be used to output the model, say all the matrix*/
@Override
public String toString()
{
return "[MultiClassHMMClassifier:";
}
}