package joshua.discriminative.semiring_parsing; import java.util.HashMap; import java.util.Map; public class VarianceSemiringHashMap implements CompositeSemiring { //TODO: assuming the following are always in *log* semiring private double logProb; private SignedValue factor1; private HashMap<Integer, SignedValue> factor2; private HashMap<Integer, SignedValue> combinedfactor; public VarianceSemiringHashMap(){ factor1 = new SignedValue(); factor2 = new HashMap<Integer, SignedValue>(); combinedfactor = new HashMap<Integer, SignedValue>(); } public VarianceSemiringHashMap(double logProb_, SignedValue factor1_, HashMap<Integer, SignedValue> factor2_, HashMap<Integer, SignedValue> combinedfactor_){ logProb = logProb_; factor1 = factor1_; factor2 = factor2_; combinedfactor = combinedfactor_; } public void setToZero(AtomicSemiring atomic){ logProb = atomic.ATOMIC_ZERO_IN_SEMIRING; factor1.setZero(); for (SignedValue val : factor2.values()) { val.setZero(); } for (SignedValue val : combinedfactor.values()) { val.setZero(); } } public void setToOne(AtomicSemiring atomic){ logProb = atomic.ATOMIC_ONE_IN_SEMIRING; /**Note that factor should always set as zero. * For example, when the factor is expected length, it should start from zero * */ factor1.setZero(); for (SignedValue val : factor2.values()) { val.setZero(); } for (SignedValue val : combinedfactor.values()) { val.setZero(); } } public void add(CompositeSemiring b, AtomicSemiring atomic){ VarianceSemiringHashMap b2 = (VarianceSemiringHashMap)b; this.logProb = atomic.add_in_atomic_semiring(this.logProb, b2.logProb); //this.factor1 = this.factor1 + b2.factor1;//real semiring this.factor1 = SignedValue.add(this.factor1, b2.factor1); for(Map.Entry<Integer, SignedValue> entry : b2.factor2.entrySet()){ Integer key = entry.getKey(); SignedValue valB = entry.getValue(); SignedValue valA = this.factor2.get(key); if(valA!=null){ this.factor2.put(key, SignedValue.add(valA, valB)); }else{ //this.factor2.put(key, valB);//TODO should duplicate valB?? this.factor2.put(key, SignedValue.clone(valB)); } } for(Map.Entry<Integer, SignedValue> entry : b2.combinedfactor.entrySet()){ Integer key = entry.getKey(); SignedValue valB = entry.getValue(); SignedValue valA = this.combinedfactor.get(key); if(valA!=null){ this.combinedfactor.put(key, SignedValue.add(valA, valB)); }else{ //this.combinedfactor.put(key, valB);//TODO should duplicate valB?? this.combinedfactor.put(key, SignedValue.clone(valB)); } } } public void multi(CompositeSemiring b, AtomicSemiring atomic){ VarianceSemiringHashMap b2 = (VarianceSemiringHashMap)b; //first update combinedFactor, then factor2, then factor, and then logProb for(Map.Entry<Integer, SignedValue> entry : b2.combinedfactor.entrySet()){ Integer key = entry.getKey(); SignedValue combinedB = entry.getValue(); SignedValue combinedA = this.combinedfactor.get(key); SignedValue factor2B = b2.factor2.get(key); if(combinedA!=null){ SignedValue factor2A = this.factor2.get(key); SignedValue part1 = SignedValue.add( SignedValue.multi(this.logProb, combinedB), SignedValue.multi(b2.logProb, combinedA) ); SignedValue part2 = SignedValue.add( SignedValue.multi( this.factor1, factor2B), SignedValue.multi(factor2A, b2.factor1) ); this.combinedfactor.put(key, SignedValue.add(part1, part2)); this.factor2.put(key, SignedValue.add( SignedValue.multi(this.logProb, factor2B), SignedValue.multi(b2.logProb, factor2A) )); }else{ SignedValue part1 = SignedValue.add( SignedValue.multi(this.logProb, combinedB), SignedValue.multi( this.factor1, factor2B) ); this.combinedfactor.put(key, part1); this.factor2.put(key, SignedValue.multi(this.logProb, factor2B)); } } //now update entries that are in myself, but not in b for(Map.Entry<Integer, SignedValue> entry : this.combinedfactor.entrySet()){ Integer key = entry.getKey(); SignedValue combinedA = entry.getValue(); SignedValue combinedB = b2.combinedfactor.get(key); SignedValue factor2A = this.factor2.get(key); //we already dealed with the case combinedB!=null above if(combinedB==null){ SignedValue part1 = SignedValue.add( SignedValue.multi(b2.logProb, combinedA), SignedValue.multi(factor2A, b2.factor1) ); this.combinedfactor.put(key, part1); this.factor2.put(key, SignedValue.multi(b2.logProb, factor2A)); } } // this.factor1 = Math.exp(oldLogProb)* b2.factor1 + Math.exp(b2.logProb) * oldFactor1; this.factor1 = SignedValue.add( SignedValue.multi(this.logProb, b2.factor1), SignedValue.multi(b2.logProb, this.factor1) ); this.logProb = atomic.multi_in_atomic_semiring(this.logProb, b2.logProb); } public void normalizeFactors(){ /**we should not normalize the probability at each intermediate node * because our model is a global model??? */ /**originallly, the factor value is \sum_x p(x).v(x), where p(x) is not normalized, meaning \sum_x p(x)!=1; * we need to normalize p(x) by divide out Math.exp(prob) * */ //this.factor1 = factor1/Math.exp(logProb); this.factor1 = SignedValue.multi(-logProb, this.factor1); /*this is wrong!!!! for(FactorAtomicSemiring val : this.factor2.values()){ val = FactorAtomicSemiring.multi(-logProb, val);//TODO put(key,val) }*/ for(Map.Entry<Integer, SignedValue> entry : this.factor2.entrySet()){ entry.setValue(SignedValue.multi(-logProb, entry.getValue())); } for(Map.Entry<Integer, SignedValue> entry : this.combinedfactor.entrySet()){ entry.setValue(SignedValue.multi(-logProb, entry.getValue())); } } public void printInfor(){ System.out.println("prob: " + logProb); System.out.println("factor1: " + factor1.convertRealValue()); System.out.print("factor2:"); for(Map.Entry<Integer, SignedValue> entry : this.factor2.entrySet()){ System.out.print(" " + entry.getKey() + "=" + entry.getValue().convertRealValue()); } System.out.print("\nfactor1*factor2:"); for(Map.Entry<Integer, SignedValue> entry : this.factor2.entrySet()){ System.out.print(" " + entry.getKey() + "=" + entry.getValue().convertRealValue()); System.out.print(" " + (factor1.convertRealValue()*entry.getValue().convertRealValue())); } System.out.print("\ncombinedfactor: "); for(Map.Entry<Integer, SignedValue> entry : this.combinedfactor.entrySet()){ System.out.print(" " + entry.getKey() + "=" + entry.getValue().convertRealValue()); } System.out.print("\n"); } public void printInfor2(){ //do nothing } public double getLogProb(){ return logProb; } public SignedValue getFactor1(){ return factor1; } public HashMap<Integer, SignedValue> getFactor2(){ return factor2; } public HashMap<Integer, SignedValue> getCombinedfactor(){ return combinedfactor; } }