package edu.stanford.nlp.sequences; import edu.stanford.nlp.math.ArrayMath; /** * @author grenager * Date: Dec 14, 2004 * @author nmramesh * Date: May 12, 2010 */ public class FactoredSequenceModel implements SequenceModel { // todo: The current version has variables for a 2 model version and arrays for an n-model version. Unify. private SequenceModel model1; private SequenceModel model2; private double model1Wt = 1.0; private double model2Wt = 1.0; private SequenceModel[] models = null; private double[] wts = null; /** {@inheritDoc} */ @Override public double[] scoresOf(int[] sequence, int pos) { if(models != null){ double[] dist = ArrayMath.multiply(models[0].scoresOf(sequence, pos),wts[0]); for(int i = 1; i < models.length; i++){ double[] dist_i = models[i].scoresOf(sequence, pos); ArrayMath.addMultInPlace(dist,dist_i,wts[i]); } return dist; } double[] dist1 = model1.scoresOf(sequence, pos); double[] dist2 = model2.scoresOf(sequence, pos); double[] dist = new double[dist1.length]; for(int i = 0; i < dist1.length; i++) dist[i] = model1Wt*dist1[i] + model2Wt*dist2[i]; return dist; } /** {@inheritDoc} */ @Override public double scoreOf(int[] sequence, int pos) { return scoresOf(sequence, pos)[sequence[pos]]; } /** {@inheritDoc} */ @Override public double scoreOf(int[] sequence) { if(models != null){ double score = 0; for(int i = 0; i < models.length; i++) score+= wts[i]*models[i].scoreOf(sequence); return score; } //return model1.scoreOf(sequence); return model1Wt*model1.scoreOf(sequence) + model2Wt*model2.scoreOf(sequence); } /** {@inheritDoc} */ @Override public int length() { if(models != null) return models[0].length(); return model1.length(); } /** {@inheritDoc} */ @Override public int leftWindow() { if(models != null) return models[0].leftWindow(); return model1.leftWindow(); } /** {@inheritDoc} */ @Override public int rightWindow() { if(models != null) return models[0].rightWindow(); return model1.rightWindow(); } /** {@inheritDoc} */ @Override public int[] getPossibleValues(int position) { if(models != null) return models[0].getPossibleValues(position); return model1.getPossibleValues(position); } /** * using this constructor results in a weighted addition of the two models' scores. * @param model1 * @param model2 * @param wt1 weight of model1 * @param wt2 weight of model2 */ public FactoredSequenceModel(SequenceModel model1, SequenceModel model2, double wt1, double wt2){ this(model1,model2); this.model1Wt = wt1; this.model2Wt = wt2; } public FactoredSequenceModel(SequenceModel model1, SequenceModel model2) { //if (model1.leftWindow() != model2.leftWindow()) throw new RuntimeException("Two models must have same window size"); if (model1.getPossibleValues(0).length != model2.getPossibleValues(0).length) throw new RuntimeException("Two models must have the same number of classes"); if (model1.length() != model2.length()) throw new RuntimeException("Two models must have the same sequence length"); this.model1 = model1; this.model2 = model2; } public FactoredSequenceModel(SequenceModel[] models, double[] weights){ this.models = models; this.wts = weights; /* for(int i = 1; i < models.length; i++){ if (models[0].getPossibleValues(0).length != models[i].getPossibleValues(0).length) throw new RuntimeException("All models must have the same number of classes"); if(models[0].length() != models[i].length()) throw new RuntimeException("All models must have the same sequence length"); } */ } }