package joshua.discriminative.syntax_reorder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import joshua.corpus.vocab.SymbolTable;
/**
 * Collects raw grammar rules in a hash table keyed by rule signature and
 * turns their accumulated weights into negative log10 relative-frequency
 * scores.
 *
 * <p>Usage visible in this class: call {@link #addRawRule(Rule)} once per
 * extracted rule instance (duplicates are merged by summing their feature
 * scores), then call {@link #score_grammar()} to normalize and print every
 * rule.
 */
public class HashtableBasedHieroGrammarScorer {

    /**
     * Number of features per rule: P(e|f), P(f|e), P_lex(e|f), P_lex(f|e).
     *
     * <p>NOTE: this is a mutable static that every constructor call
     * overwrites, so it is shared by all scorer instances.
     * {@link #score_grammar()} unconditionally writes score slots 0..2 of an
     * array of length {@code NUM_FEATS + 1}, so values below 2 make it fail
     * with an ArrayIndexOutOfBoundsException.
     */
    public static int NUM_FEATS = 4;

    /** All distinct rules seen so far, keyed by {@link Rule#get_signature()}. */
    HashMap<String, Rule> p_gram = new HashMap<String, Rule>();

    /**
     * Symbol table used to print rules; nothing in this class assigns it.
     * When it is null, rules are printed with their raw integer ids.
     */
    SymbolTable symbolTable = null; //TODO

    /**
     * Creates a scorer.
     *
     * @param num_feats number of features per rule; stored in the static
     *                  {@link #NUM_FEATS}
     */
    public HashtableBasedHieroGrammarScorer(int num_feats) {
        NUM_FEATS = num_feats;
    }

    /**
     * Adds one raw rule occurrence to the grammar.
     *
     * <p>If a rule with the same signature is already stored, the feature
     * scores of {@code rl} are added element-wise onto the stored rule;
     * otherwise {@code rl} itself is stored (and is mutated by later merges).
     *
     * @param rl the rule to add; its {@code feat_scores} must be set and must
     *           not be longer than those of an already-stored equal rule
     */
    public void addRawRule(Rule rl) {
        String sig = rl.get_signature();
        Rule old_rule = p_gram.get(sig);
        if (old_rule == null) {
            p_gram.put(sig, rl);
            return;
        }
        for (int i = 0; i < rl.feat_scores.length; i++) {
            old_rule.feat_scores[i] += rl.feat_scores[i];
        }
    }

    /**
     * Normalizes the accumulated rule weights and prints every scored rule.
     *
     * <p>{@code feat_scores[0]} of each rule is treated as its weight. The
     * weights are summed per French side, per English side and per
     * left-hand side; each rule's scores are then replaced by
     * {@code -log10} ratios:
     * <ul>
     *   <li>[0] weight / sum over the same lhs -- P(e,f|lhs)</li>
     *   <li>[1] weight / sum over the same English side -- P(f|e)</li>
     *   <li>[2] weight / sum over the same French side -- P(e|f)</li>
     *   <li>[3], [4] (only when {@code NUM_FEATS == 4}) old scores [1] and
     *       [2] divided by the weight -- weighted averages</li>
     * </ul>
     * Rules whose weight is exactly zero are skipped: they keep their raw
     * scores and are not printed. A summary line is printed at the end.
     */
    public void score_grammar() {
        //######first get the sum weights
        HashMap<String, Float> sum_fr = new HashMap<String, Float>();
        HashMap<String, Float> sum_eng = new HashMap<String, Float>();
        HashMap<Integer, Float> xsum = new HashMap<Integer, Float>();//for nonterminals
        for (Rule rl : p_gram.values()) {
            float weight = rl.feat_scores[0];
            addWeight(sum_fr, rl.get_fr_signature(), weight);
            addWeight(sum_eng, rl.get_eng_signature(), weight);
            addWeight(xsum, Integer.valueOf(rl.lhs), weight);
        }

        //######now normalize the rules
        int num_infinite = 0;
        for (Rule rl : p_gram.values()) {
            float weight = rl.feat_scores[0];
            if (weight == 0.0f) {
                continue; // a zero weight cannot be normalized; leave the rule untouched
            }
            // Every rule was visited in the first pass, so these lookups never miss.
            float fr_sum = sum_fr.get(rl.get_fr_signature());
            float eng_sum = sum_eng.get(rl.get_eng_signature());
            float x_sum = xsum.get(Integer.valueOf(rl.lhs));

            float[] new_scores = new float[NUM_FEATS + 1];
            new_scores[0] = -(float) Math.log10(weight / x_sum);//P(e,f|lhs)
            new_scores[1] = -(float) Math.log10(weight / eng_sum);//P(f|e)
            new_scores[2] = -(float) Math.log10(weight / fr_sum);//P(e|f)
            if (NUM_FEATS == 4) {
                new_scores[3] = -(float) Math.log10(rl.feat_scores[1] / weight);//weighted avg
                new_scores[4] = -(float) Math.log10(rl.feat_scores[2] / weight);//weighted avg
                if (Float.isInfinite(new_scores[3]) || Float.isInfinite(new_scores[4])) {
                    num_infinite++;
                }
            }
            rl.feat_scores = new_scores;
            rl.print_info(symbolTable);
        }
        System.out.println("invalid is " + num_infinite + "; number of unique rules are " + p_gram.size());
        //dump the rules??
    }

    /** Adds {@code weight} to the running total stored under {@code key}, creating the entry if absent. */
    private static <K> void addWeight(HashMap<K, Float> sums, K key, float weight) {
        Float previous = sums.get(key);
        if (previous == null) {
            sums.put(key, weight);
        } else {
            sums.put(key, previous + weight);
        }
    }

    /**
     * A single grammar rule.
     *
     * <p>Rule format: [Phrase] ||| french ||| english ||| feature scores
     */
    public static class Rule {
        public int lhs;//tag of this rule, state to upper layer
        public int[] french;
        public int[] english;
        public float[] feat_scores;//the feature scores for this rule
        public ArrayList alignments;

        /**
         * Creates a rule; {@code feat_scores} and {@code alignments} are left
         * unset and must be assigned by the caller.
         *
         * @param lhs_in left-hand-side symbol id
         * @param fr_in  French-side symbol ids (stored, not copied)
         * @param eng_in English-side symbol ids (stored, not copied)
         */
        public Rule(int lhs_in, int[] fr_in, int[] eng_in) {
            lhs = lhs_in;
            french = fr_in;
            english = eng_in;
        }

        /**
         * Returns a key that identifies this rule by lhs, French side and
         * English side.
         *
         * <p>The two sides are separated by {@code "||| "}. Without a
         * separator, rules such as french=[1,2]/english=[3] and
         * french=[1]/english=[2,3] would produce the same key and be merged
         * by mistake.
         */
        public String get_signature() {//lhs, french, and english
            StringBuilder res = new StringBuilder();
            res.append(lhs);
            res.append(" ");
            appendIds(res, french);
            res.append("||| ");
            appendIds(res, english);
            return res.toString();
        }

        /** Returns the French-side ids, each followed by a single space. */
        public String get_fr_signature() {//french only
            StringBuilder res = new StringBuilder();
            appendIds(res, french);
            return res.toString();
        }

        /** Returns the English-side ids, each followed by a single space. */
        public String get_eng_signature() {//english only
            StringBuilder res = new StringBuilder();
            appendIds(res, english);
            return res.toString();
        }

        /** Appends every id in {@code ids} to {@code out}, each followed by a single space. */
        private static void appendIds(StringBuilder out, int[] ids) {
            for (int i = 0; i < ids.length; i++) {
                out.append(ids[i]);
                out.append(" ");
            }
        }

        /**
         * Prints this rule to standard output on one line as
         * {@code Rule is: lhs ||| french ||| english |||  scores}.
         *
         * @param symbolTable table used to turn symbol ids into words; may be
         *                    null, in which case the raw integer ids are
         *                    printed instead of failing
         */
        public void print_info(SymbolTable symbolTable) {
            boolean haveTable = (symbolTable != null);
            StringBuilder line = new StringBuilder("Rule is: ");
            line.append(haveTable ? symbolTable.getWord(lhs) : String.valueOf(lhs));
            line.append(" ||| ");
            for (int i = 0; i < french.length; i++) {
                line.append(haveTable ? symbolTable.getWord(french[i]) : String.valueOf(french[i]));
                line.append(" ");
            }
            line.append("||| ");
            for (int i = 0; i < english.length; i++) {
                line.append(haveTable ? symbolTable.getWord(english[i]) : String.valueOf(english[i]));
                line.append(" ");
            }
            line.append("||| ");
            for (int i = 0; i < feat_scores.length; i++) {
                line.append(" ");
                line.append(feat_scores[i]);
            }
            line.append("\n");
            System.out.print(line);
        }
    }
}