package edu.cmu.minorthird.classify.sequential; import java.awt.BorderLayout; import java.io.Serializable; import java.util.Iterator; import java.util.StringTokenizer; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.border.TitledBorder; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.ExampleSchema; import edu.cmu.minorthird.classify.Explanation; import edu.cmu.minorthird.classify.Feature; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.Visible; /** * Sequential learner based on the CRF algorithm. Source for the iitb.CRF * package available from http://crf.sourceforge.net. * * @author Sunita Sarawagi */ public class CRFLearner implements BatchSequenceClassifierLearner,SequenceConstants,SequenceClassifier,Visible,Serializable{ static private final long serialVersionUID = 1; int histsize = 1; ExampleSchema schema; iitb.CRF.CRF crfModel; java.util.Properties defaults; java.util.Properties options; private static final boolean CONVERT_TO_MINORTHIRD_HYPERPLANE=true; public CRFLearner() { defaults = new java.util.Properties(); defaults.setProperty("modelGraph", "naive"); defaults.setProperty("debugLvl", "1"); //defaults.setProperty("trainer", "ll"); options = defaults; } public CRFLearner(String args) { this(args,1); } public CRFLearner(String args, int histsize) { this(); this.histsize = histsize; StringTokenizer argTok = new StringTokenizer(args, " "); options = new java.util.Properties(defaults); while (argTok.hasMoreTokens()) { options.setProperty(argTok.nextToken(),argTok.nextToken()); } } public CRFLearner(String args[]) { this(); options = new java.util.Properties(defaults); for (int i = 0; i < args.length-1; i+=2) { options.setProperty(args[i], args[i+1]); } } public void setLogSpaceOption() { options.setProperty("trainer", "ll"); //option for german multi data (very large dataset!) } public void removeLogSpaceOption() { options.remove("trainer"); } @Override public void setSchema(ExampleSchema schema){ this.schema=schema; } public ExampleSchema getSchema(){ return schema; } @Override public int getHistorySize() {return histsize;} public void setMaxIters (int newMaxIters) { defaults.setProperty("maxIters", Integer.toString(newMaxIters)); } public int getMaxIters () { String maxIters = defaults.getProperty("maxIters"); if (maxIters != null) return Integer.parseInt(maxIters); return 100; } public String maxItersHelp = new String("Number of training iterations over the training set; default set to 100"); public String getMaxItersHelp() { return maxItersHelp; } public boolean getUseHighPrecisionArithmetic() { String value = defaults.getProperty("trainer"); if ((value != null) && (value.equals("ll"))) return true; return false; } public void setUseHighPrecisionArithmetic (boolean newUseHighPrecisionArithmetic) { if (newUseHighPrecisionArithmetic == true) this.setLogSpaceOption(); else this.removeLogSpaceOption(); } public String useHighPrecisionArithmeticHelp = new String("Make the learner use high precision arithmetic."); public String getUseHighPrecisionArithmeticHelp() { return useHighPrecisionArithmeticHelp; } class DataSequenceC implements iitb.CRF.DataSequence { Instance[] sequence; int labels[]; void init(Instance[] tokens) { sequence = tokens; if (tokens != null) { if ((labels == null) || (tokens.length > labels.length)) { labels = new int[tokens.length]; } } } @Override public int length() { return sequence.length; } @Override public int y(int i) { return labels[i]; } @Override public Object x(int i) { return sequence[i]; } @Override public void set_y(int i, int label) { labels[i] = label; } }; class TrainDataSequenceC extends DataSequenceC { void init(Example[] tokens) { super.init(tokens); if (tokens != null) { for (int i = 0; i < sequence.length; i++) { labels[i] = schema.getClassIndex(tokens[i].getLabel().bestClassName()); } } } }; class TestDataSequenceC extends DataSequenceC { TestDataSequenceC(Instance[] tokens) { init(tokens); } ClassLabel[] getLabels() { ClassLabel[] clabels = new ClassLabel[sequence.length]; for (int i = 0; i < sequence.length; i++) { clabels[i] = new ClassLabel(schema.getClassName(labels[i])); } return clabels; } }; class CRFDataIter implements iitb.CRF.DataIter { Iterator<Example[]> iter; SequenceDataset dataset; TrainDataSequenceC sequence; int dataSize; CRFDataIter(SequenceDataset ds) { dataset = ds; dataSize = ds.size(); sequence = new TrainDataSequenceC(); } @Override public void startScan() { iter=dataset.sequenceIterator(); } @Override public boolean hasNext() { return iter.hasNext(); } @Override public iitb.CRF.DataSequence next() { sequence.init(iter.next()); return sequence; } }; class MTFeatureTypes extends iitb.Model.FeatureTypes { static final long serialVersionUID=20080207L; Iterator<Feature> featureLooper; Feature feature; int numStates; Instance example; int stateId; MTFeatureTypes(iitb.Model.FeatureGenImpl gen) { super(gen); numStates = model.numStates(); } void advance() { stateId++; if (stateId < numStates) return; if (featureLooper.hasNext()) { feature = featureLooper.next(); stateId = 0; } else { feature = null; featureLooper=null; } } boolean startScan() { stateId = -1; if (featureLooper.hasNext()) { feature = featureLooper.next(); advance(); } else { feature = null; return false; } return true; } @Override public boolean startScanFeaturesAt(iitb.CRF.DataSequence data, int prevPos, int pos) { example = (Instance)data.x(pos); featureLooper = example.featureIterator(); return startScan(); } @Override public boolean hasNext() { return ((stateId < numStates) && (feature != null)); } @Override public void next(iitb.Model.FeatureImpl f) { f.yend = stateId; f.ystart = -1; f.val = (float)example.getWeight(feature); setFeatureIdentifier(feature.getID()*numStates+stateId, stateId, feature,f); advance(); } }; public class MTFeatureGenImpl extends iitb.Model.FeatureGenImpl { static final long serialVersionUID=20080207L; public MTFeatureGenImpl(String modelSpecs, int numLabels, String[] labelNames) throws Exception { super(modelSpecs,numLabels,false); Feature features[] = new Feature[labelNames.length]; for (int i = 0; i < labelNames.length; i++) { features[i] = new Feature(new String[]{ HISTORY_FEATURE, "1", labelNames[i]}); } addFeature(new iitb.Model.EdgeFeatures(this, features)); addFeature(new iitb.Model.StartFeatures(this, new Feature(new String[]{ HISTORY_FEATURE, "1", NULL_CLASS_NAME}))); //wwc: I don't think this feature should be used for minorthird.... //addFeature(new iitb.Model.EndFeatures(model, new Feature("E"))); if (histsize > 1) { //uncomment this for all n-gram history features, //addFeature(new iitb.Model.EdgeHistFeatures(model, HISTORY_FEATURE,labelNames,histsize)); // this is for minorthird style linear history features... Feature histFeatures[][] = new Feature[histsize][labelNames.length]; for (int k = 1; k < histsize; k++) { for (int i = 0; i < labelNames.length; i++) histFeatures[k][i] = new Feature(new String[]{ HISTORY_FEATURE, Integer.toString((k+1)), labelNames[i]}); } addFeature(new iitb.Model.EdgeLinearHistFeatures(this, histFeatures, histsize)); } addFeature(new MTFeatureTypes(this)); } }; iitb.Model.FeatureGenImpl featureGen; SequenceClassifier cmmClassifier = null; double[] crfWs; iitb.CRF.DataIter allocModel(SequenceDataset dataset) throws Exception { featureGen = new MTFeatureGenImpl(options.getProperty("modelGraph"),schema.getNumberOfClasses(),schema.validClassNames()); //options.setProperty("trainer", "ll"); //option for german multi data (very large dataset!) System.out.println("Property: " + options.getProperty("trainer")); crfModel = new iitb.CRF.CRF(featureGen.numStates(),histsize,featureGen,options); return new CRFDataIter(dataset); } @Override public SequenceClassifier batchTrain(SequenceDataset dataset) { try { schema = dataset.getSchema(); return doTrain(allocModel(dataset)); } catch (Exception e) { e.printStackTrace(); throw new IllegalStateException("error in CRF: "+e); } } SequenceClassifier doTrain(iitb.CRF.DataIter trainData) throws Exception { featureGen.train(trainData); ProgressCounter pc = new ProgressCounter("training CRF","iteration"); crfWs = crfModel.train(trainData); pc.finished(); if (CONVERT_TO_MINORTHIRD_HYPERPLANE) return toMinorthirdClassifier(); else return this; } private SequenceClassifier toMinorthirdClassifier() { Hyperplane[] w_t; int numClasses = schema.getNumberOfClasses(); w_t = new Hyperplane[numClasses]; for (int i=0; i<numClasses; i++) { w_t[i] = new Hyperplane(); w_t[i].setBias(0); } for (int fIndex = 0; fIndex < crfWs.length; fIndex++) { Feature feature = (Feature)featureGen.featureIdentifier(fIndex).name; int classIndex = featureGen.featureIdentifier(fIndex).stateId; w_t[classIndex].increment(feature,crfWs[fIndex]); } return new CMM(new SequenceUtils.MultiClassClassifier(schema,w_t), histsize, schema ); } /** Return a predicted type for each element of the sequence. */ @Override public ClassLabel[] classification(Instance[] sequence) { TestDataSequenceC seq = new TestDataSequenceC(sequence); crfModel.apply(seq); featureGen.mapStatesToLabels(seq); return seq.getLabels(); } /** Return some string that 'explains' the classification */ @Override public String explain(Instance[] sequence) { if (cmmClassifier==null) cmmClassifier = toMinorthirdClassifier(); return cmmClassifier.explain(sequence); } @Override public Explanation getExplanation(Instance[] sequence) { if (cmmClassifier==null) cmmClassifier = toMinorthirdClassifier(); Explanation.Node top = new Explanation.Node("CRF Explanation"); Explanation.Node cmmEx = cmmClassifier.getExplanation(sequence).getTopNode(); if(cmmEx == null) cmmEx = new Explanation.Node(cmmClassifier.explain(sequence)); top.add(cmmEx); Explanation ex = new Explanation(top); return ex; } @Override public Viewer toGUI() { Viewer v = new ComponentViewer() { static final long serialVersionUID=20080207L; @Override public JComponent componentFor(Object o) { // CRFLearner cmm = (CRFLearner)o; JPanel mainPanel = new JPanel(); mainPanel.setLayout(new BorderLayout()); mainPanel.add( new JLabel("CRFLearner: historySize=1"), BorderLayout.NORTH); Viewer subView = new SmartVanillaViewer(toMinorthirdClassifier()); subView.setSuperView(this); mainPanel.add(subView,BorderLayout.SOUTH); mainPanel.setBorder(new TitledBorder("CRFLearner")); return new JScrollPane(mainPanel); } }; v.setContent(this); return v; } }