/**
 *
 */
package com.maalaang.omtwitter.ml;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.tsf.OffsetConjunctions;
import cc.mallet.pipe.tsf.SequencePrintingPipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

/**
 * Thin wrapper around a MALLET {@link CRF}: loads a serialized model for
 * labeling, and trains/serializes a new model from corpus files.
 *
 * <p>Note that {@link #train} works on its own local model and pipe; it does
 * not populate the model used by {@link #classify(Instance)}. Call
 * {@link #loadModel(String)} before classifying.
 *
 * @author Sangwon Park
 */
public class CrfClassifier {

    /** Gaussian prior variance applied to both the threaded and single-threaded trainer. */
    private static final double GAUSSIAN_PRIOR_VARIANCE = 10.0;

    /** Model set by {@link #loadModel(String)}; {@code null} until then. */
    private CRF crf = null;

    /** Input pipe of the loaded model, with target processing switched off. */
    private Pipe pipe = null;

    /**
     * Deserializes a CRF model from {@code modelFile} and disables target
     * processing on its input pipe.
     *
     * <p>NOTE(review): this uses Java native deserialization, so
     * {@code modelFile} must come from a trusted source.
     *
     * @param modelFile path of a file containing a serialized {@link CRF}
     * @throws IOException if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public void loadModel(String modelFile) throws IOException, ClassNotFoundException {
        FileInputStream fis = new FileInputStream(modelFile);
        try {
            ObjectInputStream ois = new ObjectInputStream(fis);
            CRF loadedCrf = (CRF) ois.readObject();
            Pipe loadedPipe = loadedCrf.getInputPipe();
            loadedPipe.setTargetProcessing(false);

            // Publish to the fields only once the model is fully configured,
            // so a failed load never leaves a half-initialised classifier.
            crf = loadedCrf;
            pipe = loadedPipe;
        } finally {
            // Closing the underlying stream releases the file handle even
            // when deserialization fails part-way through.
            fis.close();
        }
    }

    /**
     * Wraps {@code data} in a new {@link Instance} (no target, name or source)
     * and labels it with the loaded model.
     *
     * @param data instance data handed to the model unchanged
     * @return the instance returned by {@link CRF#label(Instance)}
     * @throws IllegalStateException if no model has been loaded
     */
    public Instance classify(String[][] data) {
        Instance instance = new Instance(data, null, null, null);
        return requireModel().label(instance);
    }

    /**
     * Labels {@code instance} with the loaded model.
     *
     * @param instance instance to label
     * @return the instance returned by {@link CRF#label(Instance)}
     * @throws IllegalStateException if no model has been loaded
     */
    public Instance classify(Instance instance) {
        return requireModel().label(instance);
    }

    /**
     * Trains a CRF on the given corpus files and, unless a feature dump was
     * requested, serializes the trained model to {@code modelFile}.
     *
     * <p>The trained model is local to this call; it is not installed as the
     * model used by {@link #classify(Instance)}.
     *
     * @param trainingFiles corpus files, each read through a
     *        {@code TweetEntityCorpusLineIterator}
     * @param fieldDelim field delimiter passed to the corpus iterator
     * @param fields field indices passed to the corpus iterator
     * @param modelFile output path for the serialized model; only written when
     *        {@code writeFeatureFile} is {@code false}
     * @param featureDumpFile output path for the feature dump; only written
     *        when {@code writeFeatureFile} is {@code true}
     * @param writeFeatureFile {@code true} to dump generated features for
     *        debugging (in which case the model is NOT saved)
     * @param threadNum number of training threads; values greater than 1
     *        select the threaded trainer
     * @throws IOException if a corpus, dump or model file cannot be accessed
     * @throws ClassNotFoundException declared for interface compatibility
     */
    public void train(String[] trainingFiles, String fieldDelim, int[] fields, String modelFile,
            String featureDumpFile, boolean writeFeatureFile, int threadNum)
            throws IOException, ClassNotFoundException {
        ArrayList<Pipe> pipes = new ArrayList<Pipe>();

        // One single-offset entry per neighbouring position: two before,
        // one before, one after and two after the current token.
        int[][] conjunctions = new int[4][];
        conjunctions[0] = new int[] { -2 };
        conjunctions[1] = new int[] { -1 };
        conjunctions[2] = new int[] { 1 };
        conjunctions[3] = new int[] { 2 };

        pipes.add(new TokenWithPosSequence(true));
        pipes.add(new TweetFeatures());
        pipes.add(new OffsetConjunctions(conjunctions));
        pipes.add(new TokenSequence2FeatureVectorSequence());

        PrintWriter featureOut = null;
        try {
            /* for debugging feature generation */
            if (writeFeatureFile) {
                featureOut = new PrintWriter(featureDumpFile);
                pipes.add(new SequencePrintingPipe(featureOut));
            }

            Pipe trainingPipe = new SerialPipes(pipes);

            InstanceList trainingInstances = new InstanceList(trainingPipe);
            for (String file : trainingFiles) {
                trainingInstances.addThruPipe(
                        new TweetEntityCorpusLineIterator(file, fieldDelim, fields));
            }

            CRF model = new CRF(trainingPipe, null);
            // Alternatives left in earlier commented-out code:
            // addStatesForLabelsConnectedAsIn, addStatesForHalfLabelsConnectedAsIn,
            // CRFTrainerByStochasticGradient and CRFTrainerByL1LabelLikelihood.
            model.addStatesForThreeQuarterLabelsConnectedAsIn(trainingInstances);
            model.addStartState();

            if (threadNum > 1) {
                trainThreaded(model, trainingInstances, threadNum);
            } else {
                CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(model);
                trainer.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);
                trainer.train(trainingInstances);
            }

            // The model is only persisted when no feature dump was requested.
            if (!writeFeatureFile) {
                saveModel(model, modelFile);
            }
        } finally {
            // PrintWriter(String) buffers without auto-flush; without close()
            // the feature dump can end up truncated or empty.
            if (featureOut != null) {
                featureOut.close();
            }
        }
    }

    /** Returns the loaded model, failing with a clear message when none has been loaded. */
    private CRF requireModel() {
        if (crf == null) {
            throw new IllegalStateException("No CRF model loaded; call loadModel() first");
        }
        return crf;
    }

    /** Trains {@code model} with the threaded trainer, always shutting its workers down. */
    private static void trainThreaded(CRF model, InstanceList trainingInstances, int threadNum) {
        CRFTrainerByThreadedLabelLikelihood trainer =
                new CRFTrainerByThreadedLabelLikelihood(model, threadNum);
        try {
            trainer.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);
            trainer.train(trainingInstances);
        } finally {
            // Release the trainer's worker threads even if training fails.
            trainer.shutdown();
        }
    }

    /** Serializes {@code model} to {@code modelFile}, closing the file on every path. */
    private static void saveModel(CRF model, String modelFile) throws IOException {
        FileOutputStream fos = new FileOutputStream(modelFile);
        try {
            ObjectOutputStream oos = new ObjectOutputStream(fos);
            oos.writeObject(model);
            oos.close();
        } finally {
            // Harmless second close on success; releases the handle on failure.
            fos.close();
        }
    }
}