package joshua.discriminative.training.risk_annealer.hypergraph; import java.util.HashMap; import java.util.Map; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; import joshua.discriminative.semiring_parsingv2.SignedValue; import joshua.discriminative.semiring_parsingv2.pmodule.SparseMap; /** * * ideally, we should first process the hypergraph to store risk and feature information, * and then using a feature filter to filter unwanted features. * Each feature should have a unique feature ID. **/ public class FeatureForest extends HyperGraph { //latest model double[] featureWeights; double scale; public FeatureForest(HyperGraph hg) { super(hg.goalNode, hg.numNodes, hg.numEdges, hg.sentID, hg.sentLen); } public void setFeatureWeights(double[] featureWeights){ this.featureWeights = featureWeights; } public void setScale(double scale){ this.scale = scale; } public final HashMap<Integer, Double> featureExtraction(HyperEdge dt, HGNode parentItem){ return ((FeatureHyperEdge)dt).featureTbl;//TODO } public final double getEdgeRisk( HyperEdge dt){ if(dt.getRule() == null){//hyperedges under goal item does not contribute BLEU return 0; }else{ return ((FeatureHyperEdge)dt).transitionRisk;//TODO } } //edge transition log-probability public final double getEdgeLogTransitionProb(HyperEdge edge, HGNode parentItem){ double transitionLogP =0; /**assume all feature are fired **/ HashMap<Integer, Double> features = featureExtraction(edge, parentItem); for(Map.Entry<Integer, Double> feature : features.entrySet()){ int featID = feature.getKey(); transitionLogP += this.featureWeights[featID] * feature.getValue(); } return scale * transitionLogP; } public final SparseMap getGradientSparseMap(HGNode parentItem, HyperEdge dt, double logTransitionProb){ HashMap<Integer, Double> features = featureExtraction(dt, parentItem); HashMap<Integer, SignedValue> gradientsMap = new HashMap<Integer, SignedValue>(); for(Map.Entry<Integer, Double> feature : features.entrySet()){ int featID = feature.getKey(); //P_e * \gamma * \Phi(e) SignedValue logGradient = SignedValue.createSignedValueFromRealNumber( scale*feature.getValue() ); logGradient.multiLogNumber(logTransitionProb); gradientsMap.put(featID, logGradient); } return new SparseMap(gradientsMap); } }