package joshua.discriminative.training.risk_annealer.hypergraph; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.logging.Logger; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.BLEU; import joshua.decoder.ff.lm.NgramExtractor; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; import joshua.discriminative.feature_related.feature_template.FeatureTemplate; /**The way we extract features (stored in featureTbl) for each edge is as following: * (1) Feature template will return feature-name (i.e. string) and feature value. * (2) This class will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap * (3) The featureStringToIntegerMap may be also used for feature filtering * */ public class RiskAndFeatureAnnotationOnLMHG { //private SymbolTable symbolTbl; private HashSet<HGNode> processedHGNodesTbl = new HashSet<HGNode>(); // == variables related to BLEU risk private boolean doRiskAnnotation = true;//TODO private double[] linearCorpusGainThetas; //weights in the Goolge linear corpus gain function private NgramExtractor ngramExtractor; private int baselineLMOrder; private static int startNgramOrder=1; private static int endNgramOrder=4; // == variables related to feature annotation private boolean doFeatureAnnotation = true; private HashMap<String, Integer> featureStringToIntegerMap; //this can also be used as feature filter private HashSet<String> restrictedFeatureSet; private List<FeatureTemplate> featTemplates; double scale = 1.0; //TODO Logger logger = Logger.getLogger(RiskAndFeatureAnnotationOnLMHG.class.getSimpleName()); public RiskAndFeatureAnnotationOnLMHG(int baselineLMOrder, int ngramStateID, double[] linearCorpusGainThetas, SymbolTable symbolTbl, HashMap<String, Integer> featureStringToIntegerMap, List<FeatureTemplate> featTemplates, boolean doRiskAnnotation){ this.baselineLMOrder = baselineLMOrder; this.doRiskAnnotation = doRiskAnnotation; if(this.baselineLMOrder<endNgramOrder){ System.out.println("Baseline LM n-gram order is too small"); System.exit(1); } //this.symbolTbl = symbolTbl; this.linearCorpusGainThetas = linearCorpusGainThetas; this.ngramExtractor = new NgramExtractor(symbolTbl, ngramStateID, false, baselineLMOrder); //=== feature related this.featureStringToIntegerMap = featureStringToIntegerMap; this.restrictedFeatureSet = new HashSet<String>( featureStringToIntegerMap.keySet() ); this.featTemplates = featTemplates; //System.out.println("use riskAnnotatorNoEquiv===="); } /**Input a hypergraph, return * a FeatureForest * Note that the input hypergraph has been changed*/ public FeatureForest riskAnnotationOnHG(HyperGraph hg, String[] referenceSentences){ processedHGNodesTbl.clear(); if(doRiskAnnotation){ HashMap<String, Integer> refereceNgramTable = BLEU.constructMaxRefCountTable(referenceSentences, endNgramOrder); annotateNode(hg.goalNode, refereceNgramTable); }else{ annotateNode(hg.goalNode, null); } releaseNodesStateMemroy(); return new FeatureForest(hg); } private void annotateNode(HGNode node, HashMap<String, Integer> refereceNgramTable){ if(processedHGNodesTbl.contains(node)) return; processedHGNodesTbl.add(node); //=== recursive call on each edge for(int i=0; i<node.hyperedges.size(); i++){ HyperEdge oldEdge = node.hyperedges.get(i); HyperEdge newEdge = annotateHyperEdge(oldEdge, refereceNgramTable); node.hyperedges.set(i, newEdge); } //===@todo: release the memory consumed by the state of node, but we have to make sure all parent hyperedges have been processed } private HyperEdge annotateHyperEdge(HyperEdge oldEdge, HashMap<String, Integer> refereceNgramTable){ //=== recursive call on each ant item if(oldEdge.getAntNodes()!=null) for(HGNode antNode : oldEdge.getAntNodes()) annotateNode(antNode, refereceNgramTable); //=== HyperEdge-specific operation return createNewHyperEdge(oldEdge, refereceNgramTable); } protected HyperEdge createNewHyperEdge(HyperEdge oldEdge, HashMap<String, Integer> refereceNgramTable) { //======== risk annotation double transitionRisk = 0; if(doRiskAnnotation) transitionRisk = getTransitionRisk(oldEdge, refereceNgramTable); //System.out.println("tran2=" + riskTransitionCost); //======== feature annotation HashMap<Integer, Double> featureTbl= null; if(doFeatureAnnotation) featureTbl = featureExtraction(oldEdge, null, scale); /**compared wit the original edge, two changes: * (1) add risk at edge (but does not change the orignal model score) * (2) add feature tbl * */ return new FeatureHyperEdge(oldEdge, featureTbl, transitionRisk); } private void releaseNodesStateMemroy(){ for(HGNode node : processedHGNodesTbl){ //System.out.println("releaseNodesStateMemroy"); node.releaseDPStatesMemory(); } } private double getTransitionRisk(HyperEdge dt, HashMap<String, Integer> refereceNgramTable){ double transitionRisk = 0; if(dt.getRule() != null){//note: hyperedges under goal item does not contribute BLEU int hypLength = dt.getRule().getEnglish().length-dt.getRule().getArity(); HashMap<String, Integer> hyperedgeNgramTable = ngramExtractor.getTransitionNgrams(dt, startNgramOrder, endNgramOrder); transitionRisk = - BLEU.computeLinearCorpusGain(linearCorpusGainThetas, hypLength, hyperedgeNgramTable, refereceNgramTable); /* System.out.println("hyp tbl: " + hyperedgeNgramTable); System.out.println("ref tbl: " + refereceNgramTable); System.out.println("hypLength: " + hypLength); System.out.println("risk is " + transitionRisk); System.exit(1);*/ } return transitionRisk; } //TODO: copied from RiskAndFeatureAnnotation, consider merge //============================================================================================ //==================================== feature extraction function ====================================== //============================================================================================ /**The way we extract features (stored in featureTbl) for each edge is as following: * (1) Feature template will return feature-name (i.e. string) and feature value. * (2) This class will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap * (3) The featureStringToIntegerMap (derive restrictedFeatureSet) will be also used for feature filtering * */ private final HashMap<Integer, Double> featureExtraction(HyperEdge dt, HGNode parentItem, double scale){ //=== extract feature counts HashMap<String, Double> activeFeaturesHelper = new HashMap<String, Double>(); for(FeatureTemplate template : featTemplates){ template.getFeatureCounts(dt, activeFeaturesHelper, restrictedFeatureSet, scale); } //=== convert the featureString to featureInteger HashMap<Integer, Double> res = new HashMap<Integer, Double>(); for(Map.Entry<String, Double> feature : activeFeaturesHelper.entrySet()){ Integer featureID = featureStringToIntegerMap.get(feature.getKey()); if(featureID==null){ logger.severe("Null feature ID, featureID="+feature.getKey()); System.exit(1); } res.put(featureID, feature.getValue()); } //System.out.println("Feature extraction res: " + res); return res; } }