package edu.berkeley.cs.nlp.ocular.lm; import java.util.List; import java.util.Set; import edu.berkeley.cs.nlp.ocular.util.ArrayHelper; import edu.berkeley.cs.nlp.ocular.util.Tuple2; import tberg.murphy.indexer.Indexer; /** * @author Dan Garrette (dhgarrette@gmail.com) */ public class InterpolatingSingleLanguageModel implements SingleLanguageModel { private static final long serialVersionUID = 1L; private SingleLanguageModel[] subModels; private double[] interpWeights; private int numModels; private Indexer<String> charIndexer = null; private Set<Integer> activeCharacters = null; private int maxOrder = -1; public InterpolatingSingleLanguageModel(List<Tuple2<SingleLanguageModel, Double>> subModelsAndinterpWeights) { numModels = subModelsAndinterpWeights.size(); subModels = new SingleLanguageModel[numModels]; interpWeights = new double[numModels]; double totalInterpWeight = 0.0; for (int i = 0; i < numModels; ++i) { Tuple2<SingleLanguageModel, Double> modelAndWeight = subModelsAndinterpWeights.get(i); subModels[i] = modelAndWeight._1; interpWeights[i] = modelAndWeight._2; totalInterpWeight += interpWeights[i]; if (charIndexer == null) { charIndexer = subModels[i].getCharacterIndexer(); activeCharacters = subModels[i].getActiveCharacters(); int thisMaxOrder = subModels[i].getMaxOrder(); if (thisMaxOrder > maxOrder) maxOrder = thisMaxOrder; } else if (charIndexer != subModels[i].getCharacterIndexer()) { throw new RuntimeException("Sub-models don't all use the same character indexer"); } else if (activeCharacters != subModels[i].getActiveCharacters()) { throw new RuntimeException("Sub-models don't all use the same active-character set"); } } for (int i = 0; i < numModels; ++i) { interpWeights[i] /= totalInterpWeight; } } @Override public double getCharNgramProb(int[] context, int c) { double probSum = 0.0; for (int i = 0; i < numModels; ++i) { int[] shrunkenContext = subModels[i].shrinkContext(context); // context may be different for different submodels probSum += subModels[i].getCharNgramProb(shrunkenContext, c) * interpWeights[i]; } return probSum; } @Override public Indexer<String> getCharacterIndexer() { return charIndexer; } @Override public Set<Integer> getActiveCharacters() { return activeCharacters; } @Override public int getMaxOrder() { return maxOrder; } @Override public int[] shrinkContext(int[] originalContext) { int[] newContext = originalContext; while (!containsContext(newContext) && newContext.length > 0) { newContext = ArrayHelper.takeRight(newContext, newContext.length - 1); } return newContext; } @Override public boolean containsContext(int[] context) { for (SingleLanguageModel slm : subModels) { if (slm.containsContext(context)) { return true; } } return false; } public SingleLanguageModel getSubModel(int i) { return subModels[i]; } }