package joshua.discriminative.feature_related.feature_function; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.logging.Logger; import joshua.decoder.chart_parser.SourcePath; import joshua.decoder.ff.DefaultStatelessFF; import joshua.decoder.ff.tm.Rule; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.discriminative.DiscriminativeSupport; import joshua.discriminative.feature_related.feature_template.FeatureTemplate; /**Given a list of featureTemplates and a model, * this model extracts feature and compute the model score * */ //TODO: it is possible some featureTemplate may be stateful, e.g., the ngram baseline feature public class FeatureTemplateBasedFF extends DefaultStatelessFF { //corrective model: this should not have the baseline feature private HashMap<String, Double> model = null; private List<FeatureTemplate> featTemplates=null; private HashSet<String> restrictedFeatSet =null; //feature set private double scale = 1.0; private static Logger logger = Logger.getLogger(FeatureTemplateBasedFF.class.getName()); public FeatureTemplateBasedFF(int featID, double weight, FeatureTemplate featTemplate){ super(weight, -1, featID); this.model = null; this.featTemplates = new ArrayList<FeatureTemplate>(); this.featTemplates.add(featTemplate); this.restrictedFeatSet = null; logger.info("weight="+weight); } public FeatureTemplateBasedFF(int featID, double weight, HashMap<String, Double> correctiveModel, List<FeatureTemplate> featTemplates, HashSet<String> restrictedFeatSet){ super(weight, -1, featID); this.model = correctiveModel; this.featTemplates = featTemplates; this.restrictedFeatSet = restrictedFeatSet; logger.info("weight="+weight); } @Override public double transitionLogP(Rule rule, List<HGNode> antNodes, int spanStart, int spanEnd, SourcePath srcPath, int sentID){ return getTransitionLogP(rule, antNodes); } @Override public double transitionLogP(HyperEdge edge, int spanStart, int spanEnd, int sentID){ return getTransitionLogP(edge); } @Override public double finalTransitionLogP(HGNode antNode, int spanStart, int spanEnd, SourcePath srcPath, int sentID){ List<HGNode> antNodes = new ArrayList<HGNode>(); antNodes.add(antNode); return getTransitionLogP(null, antNodes); } @Override public double finalTransitionLogP(HyperEdge edge, int spanStart, int spanEnd, int sentID){ return getTransitionLogP(edge); } public double estimateLogP(Rule rule, int sentID) { return getEstimateLogP(rule); //return 0; } public void setModel(HashMap<String, Double> correctiveModel){ this.model = correctiveModel; } public HashMap<String, Double> getModel(){ return this.model; } private double getEstimateLogP(Rule rule){ //=== extract features HashMap<String, Double> featTbl = new HashMap<String, Double>(); for(FeatureTemplate template : featTemplates){ template.estimateFeatureCounts(rule, featTbl, restrictedFeatSet, scale); } //=== compute logP double res =0; res = DiscriminativeSupport.computeLinearCombinationLogP(featTbl, model); /*if(res!=0){ System.out.println("getEstimateLogP: " + res); System.out.println(featTbl); }*/ return res; } private double getTransitionLogP(Rule rule, List<HGNode> antNodes){ //=== extract features HashMap<String, Double> featTbl = new HashMap<String, Double>(); for(FeatureTemplate template : featTemplates){ template.getFeatureCounts(rule, antNodes, featTbl, restrictedFeatSet, scale); } //=== compute logP double res =0; res = DiscriminativeSupport.computeLinearCombinationLogP(featTbl, model); //System.out.println("TransitionLogP: " + res); return res; } private double getTransitionLogP(HyperEdge edge){ //=== extract features HashMap<String, Double> featTbl = new HashMap<String, Double>(); for(FeatureTemplate template : featTemplates){ template.getFeatureCounts(edge, featTbl, restrictedFeatSet, scale); } //=== compute logP double res =0; res = DiscriminativeSupport.computeLinearCombinationLogP(featTbl, model); //System.out.println("TransitionLogP: " + res); return res; } }