package iitb.MaxentClassifier; import iitb.CRF.CRF; import iitb.Utils.Options; import java.io.IOException; import java.util.Iterator; import java.util.Vector; /** * * This class shows how to use the CRF package iitb.CRF for basic maxent * classification where the features are provided as attributes of the * instances to be classified. The number of classes can be more than two. * * @author Sunita Sarawagi * */ public class MaxentClassifier { protected FeatureGenRecord featureGen; protected CRF crfModel; protected DataDesc dataDesc; protected Options opts; public MaxentClassifier(Options opts) throws Exception { dataDesc = new DataDesc(opts); this.opts = opts; // read all parameters featureGen = new FeatureGenRecord(dataDesc.numColumns, dataDesc.numLabels); featureGen.addBias=1; if (opts.getProperty("class-prior")!=null) featureGen.addBias = opts.getInt("class-prior"); } protected void train(String trainFile) throws IOException { train(FileData.read(trainFile,dataDesc)); } public void train(Vector trainRecs) { crfModel = new CRF(dataDesc.numLabels,featureGen,opts); // read training data from the given file. double params[] = crfModel.train(new DataSet(trainRecs)); System.out.println("Trained model"); for (int i = 0; i < params.length; i++) System.out.println(featureGen.featureName(i) + " " + params[i]); } void test(String testFile) throws IOException { FileData fData = new FileData(); fData.openForRead(testFile,dataDesc); test(fData.iterator(),false); } public void test(Iterator<DataRecord> dataIter, boolean testOnly) throws IOException { int confMat[][] = new int[dataDesc.numLabels][dataDesc.numLabels]; while (dataIter.hasNext()) { DataRecord dataRecord = (DataRecord) dataIter.next(); int trueLabel = dataRecord.y(); crfModel.apply(dataRecord); // System.out.println(trueLabel + " true:pred " + dataRecord.y()); confMat[trueLabel][dataRecord.y()]++; if (testOnly) dataRecord.set_y(0, trueLabel); } // output confusion matrix etc directly. System.out.println("Confusion matrix "); for(int i=0 ; i<dataDesc.numLabels ; i++) { System.out.print(i); for(int j=0 ; j<dataDesc.numLabels ; j++) { System.out.print("\t"+confMat[i][j]); } System.out.println(); } } public static void main(String args[]) { try { Options opts = new Options(args); MaxentClassifier maxent = new MaxentClassifier(opts); maxent.train(opts.getMandatoryProperty("trainFile")); System.out.println("Finished training...Starting test"); maxent.test(opts.getMandatoryProperty("testFile")); } catch (Exception e) { e.printStackTrace(); } } };