package joshua.discriminative.training.risk_annealer.hypergraph.parallel;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;

import joshua.discriminative.semiring_parsing.MinRiskDAFuncValSemiringParser;
import joshua.discriminative.semiring_parsing.MinRiskDAGradientSemiringParser;
import joshua.discriminative.semiring_parsingv2.applications.min_risk_da.MinRiskDADenseFeaturesSemiringParser;
import joshua.discriminative.training.parallel.Consumer;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureForest;
import joshua.discriminative.training.risk_annealer.hypergraph.HGAndReferences;
import joshua.discriminative.training.risk_annealer.hypergraph.HGRiskGradientComputer;
import joshua.discriminative.training.risk_annealer.hypergraph.RiskAndFeatureAnnotationOnLMHG;

/**
 * Consumer in a producer-consumer setup: takes hypergraphs (together with their
 * reference translations) off a shared queue, annotates them with risk and feature
 * information, runs semiring parsing to obtain the gradient and the function value
 * (risk - T * entropy), and accumulates both into a shared gradient computer.
 */
public class GradientConsumer extends Consumer<HGAndReferences> {

    private final HGRiskGradientComputer gradientComputer;

    MinRiskDADenseFeaturesSemiringParser gradientSemiringParserV2;
    MinRiskDAGradientSemiringParser gradientSemiringParserV1;
    MinRiskDAFuncValSemiringParser funcValSemiringParserV1;

    RiskAndFeatureAnnotationOnLMHG riskAnnotator;
    double[] weightsForTheta;
    double scalingFactor;
    boolean shouldComputeGradientForScalingFactor;
    double temperature;
    boolean useSemiringV2 = true;

    //static private Logger logger = Logger.getLogger(GradientConsumer.class.getSimpleName());

    public GradientConsumer(boolean useSemiringV2, HGRiskGradientComputer gradientComputer,
            BlockingQueue<HGAndReferences> q, double[] weightsForTheta,
            RiskAndFeatureAnnotationOnLMHG riskAnnotator, double temperature,
            double scalingFactor, boolean shouldComputeGradientForScalingFactor) {
        super(q);
        this.useSemiringV2 = useSemiringV2;
        this.gradientComputer = gradientComputer;
        this.weightsForTheta = weightsForTheta;
        this.riskAnnotator = riskAnnotator;
        this.temperature = temperature;
        this.scalingFactor = scalingFactor;
        this.shouldComputeGradientForScalingFactor = shouldComputeGradientForScalingFactor;

        if (useSemiringV2) {
            //System.out.println("----------------useSemiringV2");
            this.gradientSemiringParserV2 = new MinRiskDADenseFeaturesSemiringParser(this.temperature);
        } else {
            //System.out.println("----------------useSemiringV1");
            this.gradientSemiringParserV1 = new MinRiskDAGradientSemiringParser(1, 0, scalingFactor, temperature);
            this.funcValSemiringParserV1 = new MinRiskDAFuncValSemiringParser(1, 0, scalingFactor, temperature);
        }
    }

    /**
     * Based on a model and a test hypergraph (which provides the topology and
     * feature/risk annotation), compute the gradient and function value.
     */
    @Override
    public void consume(HGAndReferences hgAndRefs) {
        FeatureForest fForest = riskAnnotator.riskAnnotationOnHG(hgAndRefs.hg, hgAndRefs.referenceSentences);
        fForest.setFeatureWeights(weightsForTheta);
        fForest.setScale(scalingFactor);

        if (this.useSemiringV2) {
            consumeHelperV2(fForest);
        } else {
            consumeHelperV1(fForest);
        }
    }

    private void consumeHelperV1(FeatureForest fForest) {
        gradientSemiringParserV1.setHyperGraph(fForest);
        HashMap<Integer, Double> gradients = gradientSemiringParserV1.computeGradientForTheta();

        double gradientForScalingFactor = 0;
        if (shouldComputeGradientForScalingFactor)
            gradientForScalingFactor -= computeGradientForScalingFactor(gradients, weightsForTheta, scalingFactor); //we are maximizing, instead of minimizing

        //== compute function value
        funcValSemiringParserV1.setHyperGraph(fForest);
        double funcVal = funcValSemiringParserV1.computeFunctionVal(); //risk - T * entropy
        double risk = funcValSemiringParserV1.getRisk();
        double entropy = funcValSemiringParserV1.getEntropy();

        //== accumulate gradient and function value (risk - T * entropy)
        this.gradientComputer.accumulateGradient(gradients, gradientForScalingFactor, funcVal, risk, entropy);
        //logger.info("=====consumed one sentence ");
    }

    private void consumeHelperV2(FeatureForest fForest) {
        //TODO: we should check whether the hypergraph passed in is really a FeatureForest
        gradientSemiringParserV2.setHyperGraph(fForest);

        //== compute gradient and function value
        HashMap<Integer, Double> gradients = gradientSemiringParserV2.computeGradientForTheta();

        double gradientForScalingFactor = 0;
        if (this.shouldComputeGradientForScalingFactor)
            gradientForScalingFactor = computeGradientForScalingFactor(gradients, weightsForTheta, scalingFactor);

        double funcVal = gradientSemiringParserV2.getFuncVal(); //risk - T * entropy
        double risk = gradientSemiringParserV2.getRisk();
        double entropy = gradientSemiringParserV2.getEntropy();

        //== accumulate gradient and function value (risk - T * entropy)
        this.gradientComputer.accumulateGradient(gradients, gradientForScalingFactor, funcVal, risk, entropy);
        //logger.info("=====consumed one sentence ");
    }

    /** A null hypergraph marks the poison object that tells this consumer to stop. */
    @Override
    public boolean isPoisonObject(HGAndReferences x) {
        return (x.hg == null);
    }

    /**
     * Derives the gradient with respect to the scaling factor from the gradients with
     * respect to the feature weights. Assuming the model score is
     * scale * (weights . features), the chain rule gives
     * dL/d(scale) = (1/scale) * sum_i weights[i] * dL/d(weights[i]).
     */
    private double computeGradientForScalingFactor(HashMap<Integer, Double> gradientForTheta,
            double[] weightsForTheta, double scale) {
        double gradientForScale = 0;
        for (Map.Entry<Integer, Double> feature : gradientForTheta.entrySet()) {
            gradientForScale += weightsForTheta[feature.getKey()] * feature.getValue();
            //System.out.println("**featureWeights[i]: " + featureWeights[i] + "; gradientForTheta[i]: " + gradientForTheta[i] + "; gradientForScale" + gradientForScale);
        }
        gradientForScale /= scale;
        //System.out.println("****gradientForScale" + gradientForScale + "; scale: " + scale );

        if (Double.isNaN(gradientForScale)) {
            System.out.println("gradient value for scaling is NaN");
            System.exit(1);
        }
        //System.out.println("Gradient for scale is : " + gradientForScale);
        return gradientForScale;
    }
}
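
/*
 * Illustrative sketch, not part of the original class: GradientConsumer stops when it
 * sees a "poison" element (an HGAndReferences whose hg field is null, see
 * isPoisonObject above). The minimal, self-contained class below demonstrates the
 * same poison-pill shutdown pattern on a plain BlockingQueue of Strings; it does not
 * use the Joshua Producer/Consumer framework, and all names in it are hypothetical.
 */
class PoisonPillSketch {

    // A dedicated sentinel instance plays the role of the poison object.
    private static final String POISON = new String("POISON");

    public static void main(String[] args) throws InterruptedException {
        final java.util.concurrent.BlockingQueue<String> queue =
                new java.util.concurrent.LinkedBlockingQueue<String>();

        // Consumer thread: keep taking items until the poison object shows up.
        Thread consumer = new Thread(new Runnable() {
            public void run() {
                try {
                    while (true) {
                        String item = queue.take();
                        if (item == POISON) { // identity check, analogous to (x.hg == null)
                            break;
                        }
                        System.out.println("consumed: " + item);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        });
        consumer.start();

        // Producer side: enqueue work items, then the poison object, then wait.
        queue.put("sentence-1");
        queue.put("sentence-2");
        queue.put(POISON);
        consumer.join();
    }
}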