package iitb.CRF; import java.io.*; import java.lang.reflect.Constructor; /** * * CRF (conditional random fields) This class provides support for * training and applying a conditional random field for sequence * labeling problems. * * @author Sunita Sarawagi * */ public class CRF implements Serializable { /** * Comment for <code>serialVersionUID</code> */ private static final long serialVersionUID = 14L; double lambda[]; protected int numY; transient Trainer trainer; FeatureGenerator featureGenerator; EdgeGenerator edgeGen; HistoryManager histMgr; public CrfParams params; transient Viterbi viterbi; /** * @param numLabels is the number of distinct class labels or y-labels * @param fgen is the class that is responsible for providing * the features for a particular position on the sequence. * @param arg is a string that can be used to control various * parameters of the CRF, these are space separated name-value pairs * described in * @see iitb.CRF.CrfParams */ public CRF(int numLabels, FeatureGenerator fgen, String arg) { this(numLabels, fgen, CrfParams.stringToOptions(arg)); } public CRF(int numLabels, FeatureGenerator fgen, java.util.Properties configOptions) { this(numLabels,1,fgen,configOptions); } public CRF(int numLabels, int histsize, FeatureGenerator fgen, java.util.Properties configOptions) { histMgr = new HistoryManager(histsize,numLabels); featureGenerator = histMgr.getFeatureGen(fgen); numY = histMgr.numY; params = new CrfParams(configOptions); edgeGen = histMgr.getEdgeGenerator(); viterbi = getViterbi(1); } public FeatureGenerator getFeatureGenerator() {return featureGenerator;} /* * useful for resetting Viterbi options after loading a saved model. */ public void reinitOptions(java.util.Properties configOptions) { params = new CrfParams(configOptions); viterbi = null; } /** * write the trained parameters of the CRF to the file */ public void write(String fileName) throws IOException { PrintWriter out=new PrintWriter(new FileOutputStream(fileName)); out.println(lambda.length); for (int i = 0; i < lambda.length; i++) out.println(lambda[i]); out.close(); } /** * read the parameters of the CRF from a file */ public void read(String fileName) throws IOException { BufferedReader in=new BufferedReader(new FileReader(fileName)); int numF = Integer.parseInt(in.readLine()); lambda = new double[numF]; int pos = 0; String line; while((line=in.readLine())!=null) { lambda[pos++] = Double.parseDouble(line); } } protected Trainer dynamicallyLoadedTrainer() { if (params.trainerType.startsWith("load=")) { try { Class c = Class.forName(params.trainerType.substring(5)); Constructor constr = c.getConstructor( new Class[] { CrfParams.class } ); return (Trainer)constr.newInstance( new Object[] { params } ); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); return null; } } return null; } protected Trainer getTrainer() { Trainer thisTrainer = dynamicallyLoadedTrainer(); if (thisTrainer != null) return thisTrainer; if (params.trainerType.startsWith("Collins")) return new CollinsTrainer(params); if (params.trainerType.startsWith("Piecewise")) return new PiecewiseTrainer(params); return new Trainer(params); } public Viterbi getViterbi(int beamsize) { return params.miscOptions.getProperty("segmentViterbi", "false").equals("true")? new SegmentViterbi(this,beamsize):new Viterbi(this, beamsize); } /** * Trains the model given the data * @return the learnt parameter value as an array */ public double[] train(DataIter trainData) { return train(trainData,null,null); } public void setInitTrainWeights(double initLambda[]) { lambda = new double[featureGenerator.numFeatures()]; params.miscOptions.setProperty("initValuesUseExisting", "true"); for (int i = 0; i < initLambda.length; lambda[i] = initLambda[i], i++); } /** * Trains the model given the data * @return the learnt parameter value as an array */ public double[] train(DataIter trainData, Evaluator evaluator) { return train(trainData,evaluator,null); } /** * Trains the model given the data with weighted instances. * @return the learnt parameter value as an array */ public double[] train(DataIter trainData, Evaluator evaluator, float instanceWts[]) { if (lambda == null) lambda = new double[featureGenerator.numFeatures()]; trainer = getTrainer(); trainer.train(this, histMgr.mapTrainData(trainData), lambda, evaluator, instanceWts); return lambda; } /** * Same as above except that there is a misclassification cost associated with label pairs. */ public double[] train(DataIter trainData, Evaluator evaluator, float instanceWts[], float misClassCosts[][]) { if (lambda == null) lambda = new double[featureGenerator.numFeatures()]; trainer = getTrainer(); trainer.train(this, histMgr.mapTrainData(trainData), lambda, evaluator, instanceWts, misClassCosts); return lambda; } public double[] learntWeights() { return lambda; } public double apply(DataSequence dataSeq) { if (viterbi==null) viterbi = getViterbi(1); if (params.debugLvl > 1) Util.printDbg("CRF: Applying on " + dataSeq); double score = viterbi.bestLabelSequence(dataSeq,lambda); if (histMgr != null) { for(int i = dataSeq.length()-1; i >= 0; i--) { histMgr.set_y(dataSeq, i, dataSeq.y(i)); } } return score; } public double applyAndScore(DataSequence dataSeq) { double score = apply(dataSeq); double lZx = getLogZx(dataSeq); return Math.exp(score-lZx); } public LabelSequence[] topKLabelSequences(DataSequence dataSeq, int numLabelSeqs, boolean getScores) { if ((viterbi==null) || (viterbi.beamsize < numLabelSeqs)) viterbi = getViterbi(numLabelSeqs); return viterbi.topKLabelSequences(dataSeq,lambda,numLabelSeqs,getScores); } public double score(DataSequence dataSeq) { if (viterbi==null) viterbi = getViterbi(1); return viterbi.viterbiSearch(dataSeq,lambda,true); } public void expectedFeatureValues(DataIter data, double expFVals[], FeatureGenerator fgen) { if (trainer==null) { trainer = getTrainer(); trainer.init(this,data,lambda); } trainer.computeFeatureExpectedValue(data,fgen,lambda,expFVals); } public double getLogZx(DataSequence dataSequence) { if (trainer==null) { trainer = getTrainer(); trainer.init(this,null,lambda); } else { trainer.reInit(); } return -1*trainer.sumProduct(dataSequence,featureGenerator,lambda,null,null,false, -1, null); } public void score(DataSequence dataSeq, double[] featureVec) { if (trainer==null) { trainer = getTrainer(); trainer.init(this,null,lambda); } trainer.addFeatureVector(dataSeq,featureVec); } public int getNumY() {return numY;} };