package joshua.discriminative.feature_related.feature_function; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.chart_parser.SourcePath; import joshua.decoder.ff.DefaultStatefulFF; import joshua.decoder.ff.lm.NgramExtractor; import joshua.decoder.ff.state_maintenance.DPState; import joshua.decoder.ff.tm.Rule; import joshua.decoder.hypergraph.HGNode; import joshua.util.io.LineReader; import joshua.util.io.UncheckedIOException; @Deprecated public class DiscriminativeNgramModel extends DefaultStatefulFF { private HashMap<String, Double> ngramModel; private int startNgramOrder =1; private int endNgramOrder =3; private SymbolTable symbolTbl = null; private NgramExtractor ngramExtractor; private boolean useIntegerNgram = true; public DiscriminativeNgramModel(int ngramStateID, int featID, SymbolTable symbolTbl, int startNgramOrder, int endNgramOrder, String ngramModelFile, double weight, int baselineLMOrder) { super(ngramStateID, weight, featID); this.startNgramOrder = startNgramOrder; this.endNgramOrder = endNgramOrder; this.symbolTbl = symbolTbl; this.ngramExtractor = new NgramExtractor(symbolTbl, ngramStateID, useIntegerNgram, baselineLMOrder); this.ngramModel = loadModel(ngramModelFile); System.out.println("DiscriminativeNgramModel with size " + ngramModel.size()); } public double estimateLogP(Rule rule, int sentID) { return computeLogP( ngramExtractor.getRuleNgrams(rule, startNgramOrder, endNgramOrder) ); } public double estimateFutureLogP(Rule rule, DPState curDPState, int sentID) { //TODO: should we just return 0? return computeLogP( ngramExtractor.getFutureNgrams(rule, curDPState, startNgramOrder, endNgramOrder) ); } public double transitionLogP(Rule rule, List<HGNode> antNodes, int spanStart, int spanEnd, SourcePath srcPath, int sentID) { return computeLogP( ngramExtractor.getTransitionNgrams(rule, antNodes, startNgramOrder, endNgramOrder) ); } public double finalTransitionLogP(HGNode antNode, int spanStart, int spanEnd, SourcePath srcPath, int sentID) { return computeLogP( ngramExtractor.getFinalTransitionNgrams(antNode, startNgramOrder, endNgramOrder) ); } private double computeLogP(HashMap<String, Integer> ngramTbl){ double transitionLogP = 0; for(Map.Entry<String,Integer> ngram : ngramTbl.entrySet()){ transitionLogP += this.getLogProb(ngram.getKey())*ngram.getValue(); } return transitionLogP; } private double getLogProb(String ngram){ double res = ngramModel.get(ngram); return Math.log(res); } private HashMap<String, Double> loadModel(String file){ try { LineReader reader = new LineReader(file); HashMap<String, Double> res =new HashMap<String, Double>(); while(reader.hasNext()){ String line = reader.readLine(); String[] fds = line.split("\\s+\\|{3}\\s+");// feature_key ||| feature vale; the feature_key itself may contain "|||" StringBuffer featKey = new StringBuffer(); for(int i=0; i<fds.length-1; i++){ featKey.append(fds[i]); if(this.useIntegerNgram){ //TODO??????????????? } if(i<fds.length-2) featKey.append(" ||| "); } double weight = new Double(fds[fds.length-1]);//initial weight res.put(featKey.toString(), weight); } reader.close(); return res; } catch (IOException ioe) { throw new UncheckedIOException(ioe); } } }