/** SegmentCRF.java * Created on Nov 21, 2004 * * @author Sunita Sarawagi * @since 1.2 * @version 1.3 * * This is a version of the CRF model that applies the semi-markov * model on data where the candidate segments are provided by the dataset. */ package iitb.CRF; import gnu.trove.list.array.TIntArrayList; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.map.hash.TIntFloatHashMap; public class SegmentCRF extends CRF { /** * */ private static final long serialVersionUID = 4846441387460151325L; protected FeatureGeneratorNested featureGenNested; transient SegmentViterbi segmentViterbi; transient SegmentAStar segmentAStar; public SegmentCRF(int numLabels, FeatureGeneratorNested fgen, String arg) { super(numLabels,fgen,arg); featureGenNested = fgen; segmentViterbi = new SegmentViterbi(this,1); segmentAStar = new SegmentAStar(this, 1); } public SegmentCRF(int numLabels, FeatureGeneratorNested fgen, java.util.Properties configOptions) { super(numLabels,fgen,configOptions); featureGenNested = fgen; segmentViterbi = new SegmentViterbi(this,1); segmentAStar = new SegmentAStar(this, 1); } public interface ModelGraph { public int numStates(); public void stateMappingGivenLength(int label, int len, TIntArrayList stateIds) throws Exception; }; protected Trainer getTrainer() { Trainer thisTrainer = dynamicallyLoadedTrainer(); if (thisTrainer != null) return thisTrainer; if (params.trainerType.startsWith("SegmentCollins")) return new NestedCollinsTrainer(params); return new SegmentTrainer(params); } public Viterbi getViterbi(int beamsize) { return new SegmentViterbi(this,beamsize); } public void apply(CandSegDataSequence dataSeq, int rank) { System.out.println("Not implemented yet"); } public double apply(DataSequence dataSeq) { if(params.inferenceType.equalsIgnoreCase("AStar")){ if(segmentAStar == null) segmentAStar = new SegmentAStar(this, 1); return segmentAStar.bestLabelSequence((CandSegDataSequence)dataSeq, lambda); }else{//default return super.apply(dataSeq); } } public void singleSegmentClassScores(CandSegDataSequence dataSeq, TIntFloatHashMap scores) { if (segmentViterbi==null) segmentViterbi = (SegmentViterbi)getViterbi(1); segmentViterbi.singleSegmentClassScores(dataSeq,lambda,scores); } public Segmentation[] segmentSequences(CandSegDataSequence dataSeq, int numLabelSeqs) { return segmentSequences(dataSeq,numLabelSeqs,null); } public Segmentation[] segmentSequences(CandSegDataSequence dataSeq, int numLabelSeqs, double scores[]) { if ((segmentViterbi==null) || (segmentViterbi.beamsize < numLabelSeqs)) segmentViterbi = (SegmentViterbi)getViterbi(numLabelSeqs); return segmentViterbi.segmentSequences(dataSeq,lambda,numLabelSeqs,scores); } public double segmentMarginalProbabilities(DataSequence dataSequence, TIntDoubleHashMap segmentMarginals[][], TIntDoubleHashMap edgeMarginals[][][]) { if (trainer==null) { trainer = getTrainer(); trainer.init(this,null,lambda); } return -1*((SegmentTrainer)trainer).sumProductInner(dataSequence,featureGenerator,lambda,null,false, -1, null,segmentMarginals,edgeMarginals); } public double[] marginalProbsOfSegmentation(DataSequence dataSequence, Segmentation segmentation) { TIntDoubleHashMap segMargs[][] = new TIntDoubleHashMap[numY][dataSequence.length()]; for (int i = segmentation.numSegments()-1; i >= 0; i--) { segMargs[segmentation.segmentLabel(i)][segmentation.segmentStart(i)] = new TIntDoubleHashMap(); segMargs[segmentation.segmentLabel(i)][segmentation.segmentStart(i)].put(segmentation.segmentEnd(i), 0); } segmentMarginalProbabilities(dataSequence, segMargs, null); double margPr[]=new double[segmentation.numSegments()]; for (int i = segmentation.numSegments()-1; i >= 0; i--) { margPr[i] = segMargs[segmentation.segmentLabel(i)][segmentation.segmentStart(i)].get(segmentation.segmentEnd(i)); } return margPr; } @Override public double getLogZx(DataSequence dataSequence) { double logZ = super.getLogZx(dataSequence); if (segmentViterbi==null) segmentViterbi = (SegmentViterbi)getViterbi(20); double violatingScore = segmentViterbi.sumScoreTopKViolators(dataSequence,lambda); assert(violatingScore < logZ+1e-4); // if (Math.exp(violatingScore-logZ) > 0.5) { // System.out.println("Lot of mass in violating labelings "+Math.exp(violatingScore-logZ)); // } return RobustMath.logMinusExp(logZ, violatingScore); } /* public void apply(DataSequence dataSeq) { apply((CandSegDataSequence)dataSeq); } public void apply(CandSegDataSequence dataSeq) { if (params.debugLvl > 2) Util.printDbg("SegmentCRF: Applying on " + dataSeq); if(params.inferenceType.equalsIgnoreCase("AStar")){ if(segmentAStar == null) segmentAStar = new SegmentAStar(this, params.beamSize); segmentAStar.bestLabelSequence(dataSeq, lambda); }else{ if (segmentViterbi==null) segmentViterbi = new SegmentViterbi(this,params.beamSize); segmentViterbi.bestLabelSequence(dataSeq,lambda); } } */ /* public double score(DataSequence dataSeq) { if (segmentViterbi==null) segmentViterbi = new SegmentViterbi(this,1); return segmentViterbi.viterbiSearch(dataSeq,lambda,true); } */ }