/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.hypergraph;
import joshua.decoder.hypergraph.HyperGraph;
import java.util.HashMap;
/**
* to use the functions here, one need to extend the class to
* provide a way to calculate the transitionLogP based on feature
* set
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-14 21:11:52 -0600 (Thu, 14 Jan 2010) $
*/
//TODO: currently assume log semiring, need to generalize to other semiring
//already implement both max-product and sum-product algortithms for log-semiring
//Note: this class requires the correctness of transitionLogP of each hyperedge, which itself may require the correctness of bestDerivationLogP at each item
public abstract class DefaultInsideOutside {
/**
* Two operations: add and multi
* add: different hyperedges lead to a specific item
* multi: prob of a derivation is a multi of all constituents
*/
int ADD_MODE=0; //0: sum; 1: viterbi-min, 2: viterbi-max
int LOG_SEMIRING=1;
int SEMIRING=LOG_SEMIRING; //default is in log; or real, or logic
double ZERO_IN_SEMIRING = Double.NEGATIVE_INFINITY;//log-domain
double ONE_IN_SEMIRING = 0;//log-domain
double scaling_factor ; //try to scale the original distribution: smooth or winner-take-all
private HashMap<HGNode,Double> tbl_inside_prob = new HashMap<HGNode,Double>();//remember inside prob of each item:
private HashMap<HGNode,Double> tbl_outside_prob = new HashMap<HGNode,Double>();//remember outside prob of each item
double normalizationConstant = ONE_IN_SEMIRING;
/**
* for each item, remember how many deductions pointering
* to me, this is needed for outside estimation during
* outside estimation, an item will recursive call its
* deductions to do outside-estimation only after it itself
* is done with outside estimation, this is necessary
* because the outside estimation of the items under its
* deductions require the item's outside value
*/
private HashMap<HGNode,Integer> tbl_num_parent_deductions = new HashMap<HGNode,Integer>();
private HashMap<HGNode,Integer> tbl_for_sanity_check = null;
//get feature-set specific **log probability** for each hyperedge
protected abstract double getHyperedgeLogProb(HyperEdge dt, HGNode parent_it);
protected double getHyperedgeLogProb(HyperEdge dt, HGNode parent_it, double scaling_factor){
return getHyperedgeLogProb(dt, parent_it)*scaling_factor;
}
//the results are stored in tbl_inside_prob and tbl_outside_prob
public void runInsideOutside(HyperGraph hg, int add_mode, int semiring, double scaling_factor_){//add_mode||| 0: sum; 1: viterbi-min, 2: viterbi-max
setup_semiring(semiring, add_mode);
scaling_factor = scaling_factor_;
//System.out.println("outside estimation");
inside_estimation_hg(hg);
//System.out.println("inside estimation");
outside_estimation_hg(hg);
normalizationConstant = tbl_inside_prob.get(hg.goalNode);
System.out.println("normalization constant is " + normalizationConstant);
tbl_num_parent_deductions.clear();
sanityCheckHG(hg);
}
//to save memory, external class should call this method
public void clearState(){
tbl_num_parent_deductions.clear();
tbl_inside_prob.clear();
tbl_outside_prob.clear();
}
//######### use of inside-outside probs ##########################
//this is the logZ where Z is the sum[ exp( log prob ) ]
public double getLogNormalizationConstant(){
return normalizationConstant;
}
//this is the log of expected/posterior prob (i.e., LogP, where P is the posterior probability), without normalization
public double getEdgeUnormalizedPosteriorLogProb(HyperEdge dt, HGNode parent){
//### outside of parent
double outside = (Double)tbl_outside_prob.get(parent);
//### get inside prob of all my ant-items
double inside = ONE_IN_SEMIRING;
if(dt.getAntNodes()!=null){
for(HGNode ant_it : dt.getAntNodes())
inside = multi_in_semiring(inside,(Double)tbl_inside_prob.get(ant_it));
}
//### add deduction/rule specific prob
double merit = multi_in_semiring(inside, outside);
merit = multi_in_semiring(merit, getHyperedgeLogProb(dt, parent, this.scaling_factor));
return merit;
}
//normalized probabily in [0,1]
public double getEdgePosteriorProb(HyperEdge dt, HGNode parent ){
if(SEMIRING==LOG_SEMIRING){
double res = Math.exp((getEdgeUnormalizedPosteriorLogProb(dt, parent)-getLogNormalizationConstant()));
if (res < 0.0-1e-2 || res > 1.0+1e-2) {
throw new RuntimeException("res is not within [0,1], must be wrong value: " + res);
}
return res;
} else {
throw new RuntimeException("not implemented");
}
}
// this is the log of expected/posterior prob (i.e., LogP, where P is the posterior probability), without normalization
public double getNodeUnnormalizedPosteriorLogProb(HGNode node){
//### outside of parent
double inside = (Double)tbl_inside_prob.get(node);
double outside = (Double)tbl_outside_prob.get(node);
return multi_in_semiring(inside, outside);
}
// normalized probabily in [0,1]
public double getNodePosteriorProb(HGNode node ){
if(SEMIRING==LOG_SEMIRING){
double res = Math.exp((getNodeUnnormalizedPosteriorLogProb(node)-getLogNormalizationConstant()));
if (res < 0.0-1e-2 || res > 1.0+1e-2) {
throw new RuntimeException("res is not within [0,1], must be wrong value: " + res);
}
return res;
} else {
throw new RuntimeException("not implemented");
}
}
/*Originally, to see if the sum of the posterior probabilities of all the hyperedges sum to one
* However, this won't work! The sum should be greater than 1.
* */
public void sanityCheckHG(HyperGraph hg){
tbl_for_sanity_check = new HashMap<HGNode,Integer>();
//System.out.println("num_dts: " + hg.goal_item.l_deductions.size());
sanity_check_item(hg.goalNode);
System.out.println("survied sanity check!!!!");
}
private void sanity_check_item(HGNode it){
if(tbl_for_sanity_check.containsKey(it))return;
tbl_for_sanity_check.put(it,1);
double prob_sum=0;
//### recursive call on each deduction
for(HyperEdge dt : it.hyperedges){
prob_sum += getEdgePosteriorProb(dt,it);
sanity_check_deduction(dt);//deduction-specifc operation
}
double supposed_sum = getNodePosteriorProb(it);
if (Math.abs(prob_sum-supposed_sum) > 1e-3) {
throw new RuntimeException("prob_sum=" + prob_sum + "; supposed_sum=" + supposed_sum + "; sanity check fail!!!!");
}
//### item-specific operation
}
private void sanity_check_deduction(HyperEdge dt){
//### recursive call on each ant item
if (null != dt.getAntNodes()) {
for (HGNode ant_it : dt.getAntNodes()) {
sanity_check_item(ant_it);
}
}
//### deduction-specific operation
}
//################## end use of inside-outside probs
//############ bottomn-up insdide estimation ##########################
private void inside_estimation_hg(HyperGraph hg) {
tbl_inside_prob.clear();
tbl_num_parent_deductions.clear();
inside_estimation_item(hg.goalNode);
}
private double inside_estimation_item(HGNode it) {
//### get number of deductions that point to me
Integer num_called = (Integer)tbl_num_parent_deductions.get(it);
if (null == num_called) {
tbl_num_parent_deductions.put(it, 1);
} else {
tbl_num_parent_deductions.put(it, num_called+1);
}
if (tbl_inside_prob.containsKey(it)) {
return (Double) tbl_inside_prob.get(it);
}
double inside_prob = ZERO_IN_SEMIRING;
//### recursive call on each deduction
for (HyperEdge dt : it.hyperedges) {
double v_dt = inside_estimation_deduction(dt, it);//deduction-specifc operation
inside_prob = add_in_semiring(inside_prob, v_dt);
}
//### item-specific operation, but all the prob should be factored into each deduction
tbl_inside_prob.put(it,inside_prob);
return inside_prob;
}
private double inside_estimation_deduction(HyperEdge dt, HGNode parent_item){
double inside_prob = ONE_IN_SEMIRING;
//### recursive call on each ant item
if(dt.getAntNodes()!=null)
for(HGNode ant_it : dt.getAntNodes()){
double v_item = inside_estimation_item(ant_it);
inside_prob = multi_in_semiring(inside_prob, v_item);
}
//### deduction operation
double deduct_prob = getHyperedgeLogProb(dt, parent_item, this.scaling_factor);//feature-set specific
inside_prob = multi_in_semiring(inside_prob, deduct_prob);
return inside_prob;
}
//########### end inside estimation
//############ top-downn outside estimation ##########################
private void outside_estimation_hg(HyperGraph hg){
tbl_outside_prob.clear();
tbl_outside_prob.put(hg.goalNode, ONE_IN_SEMIRING);//initialize
for(HyperEdge dt : hg.goalNode.hyperedges)
outside_estimation_deduction(dt, hg.goalNode);
}
private void outside_estimation_item(HGNode cur_it, HGNode upper_item, HyperEdge parent_dt, double parent_deduct_prob){
Integer num_called = (Integer)tbl_num_parent_deductions.get(cur_it);
if (null == num_called || 0 == num_called) {
throw new RuntimeException("un-expected call, must be wrong");
}
tbl_num_parent_deductions.put(cur_it, num_called-1);
double old_outside_prob = ZERO_IN_SEMIRING;
if (tbl_outside_prob.containsKey(cur_it)) {
old_outside_prob = (Double) tbl_outside_prob.get(cur_it);
}
double additional_outside_prob = ONE_IN_SEMIRING;
//### add parent deduction prob
additional_outside_prob = multi_in_semiring(additional_outside_prob, parent_deduct_prob);
//### sibing specifc
if(parent_dt.getAntNodes()!=null && parent_dt.getAntNodes().size()>1)
for(HGNode ant_it : parent_dt.getAntNodes()){
if(ant_it != cur_it){
double inside_prob_item =(Double)tbl_inside_prob.get(ant_it);//inside prob
additional_outside_prob = multi_in_semiring(additional_outside_prob, inside_prob_item);
}
}
//### upper item
double outside_prob_item = (Double)tbl_outside_prob.get(upper_item);//outside prob
additional_outside_prob = multi_in_semiring(additional_outside_prob, outside_prob_item);
//#### add to old prob
additional_outside_prob = add_in_semiring(additional_outside_prob, old_outside_prob);
tbl_outside_prob.put(cur_it, additional_outside_prob);
//### recursive call on each deduction
if( num_called-1<=0){//i am done
for(HyperEdge dt : cur_it.hyperedges){
//TODO: potentially, we can collect the feature expection in each hyperedge here, to avoid another pass of the hypergraph to get the counts
outside_estimation_deduction(dt, cur_it);
}
}
}
private void outside_estimation_deduction(HyperEdge dt, HGNode parent_item){
//we do not need to outside prob if no ant items
if(dt.getAntNodes()!=null){
//### deduction specific prob
double deduction_prob = getHyperedgeLogProb(dt, parent_item, this.scaling_factor);//feature-set specific
//### recursive call on each ant item
for(HGNode ant_it : dt.getAntNodes()){
outside_estimation_item(ant_it, parent_item, dt, deduction_prob);
}
}
}
//########### end outside estimation
//############ common ##########################
// BUG: replace integer pseudo-enum with a real Java enum
// BUG: use a Semiring class instead of all this?
private void setup_semiring(int semiring, int add_mode) {
ADD_MODE = add_mode;
SEMIRING = semiring;
if (SEMIRING == LOG_SEMIRING) {
if (ADD_MODE == 0) { // sum
ZERO_IN_SEMIRING = Double.NEGATIVE_INFINITY;
ONE_IN_SEMIRING = 0;
} else if (ADD_MODE == 1) { // viter-min
ZERO_IN_SEMIRING = Double.POSITIVE_INFINITY;
ONE_IN_SEMIRING = 0;
} else if (ADD_MODE == 2) { // viter-max
ZERO_IN_SEMIRING = Double.NEGATIVE_INFINITY;
ONE_IN_SEMIRING = 0;
} else {
throw new RuntimeException("invalid add mode");
}
} else {
throw new RuntimeException("un-supported semiring");
}
}
private double multi_in_semiring(double x, double y) {
if (SEMIRING == LOG_SEMIRING) {
return multi_in_log_semiring(x,y);
} else {
throw new RuntimeException("un-supported semiring");
}
}
private double add_in_semiring(double x, double y) {
if (SEMIRING == LOG_SEMIRING) {
return add_in_log_semiring(x,y);
} else {
throw new RuntimeException("un-supported semiring");
}
}
//AND
private double multi_in_log_semiring(double x, double y) { // value is Log prob
return x + y;
}
//OR: return Math.log(Math.exp(x) + Math.exp(y));
// BUG: Replace ADD_MODE pseudo-enum with a real Java enum
private double add_in_log_semiring(double x, double y) { // prevent under-flow
if (ADD_MODE == 0) { // sum
if (x == Double.NEGATIVE_INFINITY) { // if y is also n-infinity, then return n-infinity
return y;
}
if (y == Double.NEGATIVE_INFINITY) {
return x;
}
if (y <= x) {
return x + Math.log(1+Math.exp(y-x));
} else {
return y + Math.log(1+Math.exp(x-y));
}
} else if (ADD_MODE == 1) { // viter-min
return (x <= y ? x : y);
} else if (ADD_MODE == 2) { // viter-max
return (x >= y ? x : y);
} else {
throw new RuntimeException("invalid add mode");
}
}
//############ end common #####################
}