package edu.umd.hooka.alignment.hmm; import java.io.IOException; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapred.OutputCollector; import org.apache.hadoop.mapred.Reporter; import edu.umd.hooka.Alignment; import edu.umd.hooka.AlignmentPosteriorGrid; import edu.umd.hooka.Array2D; import edu.umd.hooka.PhrasePair; import edu.umd.hooka.alignment.PartialCountContainer; import edu.umd.hooka.alignment.PerplexityReporter; import edu.umd.hooka.alignment.CrossEntropyCounters; import edu.umd.hooka.alignment.ZeroProbabilityException; import edu.umd.hooka.alignment.model1.Model1; import edu.umd.hooka.ttables.TTable; /** * Represents an HMM that applies to a single sentence pair, which is * derived from the parameters stored in a TTable and an ATable object. * * @author redpony * */ public class HMM extends Model1 { public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999); static final int MAX_LENGTH = 500; static final float THRESH =0.5f; /** * (s,j) = p(f_j|e(s)) */ Array2D emission = new Array2D(MAX_LENGTH * MAX_LENGTH); /** * (s,j) = i s.t. e(s) = e_i or -1 if n.a. */ IntArray2D e_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH); /** * (s,j) = the english word corresponding to state s */ IntArray2D e_words = new IntArray2D(MAX_LENGTH * MAX_LENGTH); /** * (i',i) = p(i-i') */ Array2D transition = new Array2D(MAX_LENGTH * MAX_LENGTH); IntArray2D transition_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH); Array2D alphas = new Array2D(MAX_LENGTH * MAX_LENGTH); Array2D betas = new Array2D(MAX_LENGTH * MAX_LENGTH); Array2D viterbi = new Array2D(MAX_LENGTH * MAX_LENGTH); IntArray2D backtrace = new IntArray2D(MAX_LENGTH * MAX_LENGTH); ATable amodel; ATable acounts; int l = -1; int m = -1; AlignmentPosteriorGrid m1_post = null; public void setModel1Posteriors(AlignmentPosteriorGrid m1pg) { m1_post = m1pg; } protected HMM(TTable ttable, ATable atable, boolean useNull) { super(ttable, useNull); amodel = atable; acounts = (ATable)amodel.clone(); acounts.clear(); } public HMM(TTable ttable, ATable atable) { super(ttable, false); amodel = atable; acounts = (ATable)amodel.clone(); acounts.clear(); } public void writePartialCounts(OutputCollector<IntWritable,PartialCountContainer> output) throws IOException { super.writePartialCounts(output); PartialCountContainer pcc = new PartialCountContainer(); pcc.setContent(acounts); output.collect(ACOUNT_VOC_ID, pcc); acounts.clear(); } public void buildHMMTables(PhrasePair pp) { int[] es = pp.getE().getWords(); int[] fs = pp.getF().getWords(); l = es.length; m = fs.length; emission.resize(m + 1, l + 1); e_coords.resize(m + 1, l + 1); e_words.resize(m + 1, l + 1); e_words.fill(-1); e_coords.fill(-1); for (int i = 1; i <= l; i++) { int ei = es[i-1]; for (int j = 1; j <= m; j++) { int fj = fs[j-1]; e_coords.set(j, i, i); emission.set(j, i, tmodel.get(ei, fj)); e_words.set(j, i, i - 1); } } //System.out.println("b:\n"+emission); transition.resize(l+1, l+1); transition_coords.resize(l+1, l+1); transition_coords.fill(-1); for (int i_prev = 0; i_prev <= l; i_prev++) { for (int i = 1; i <= l; i++) { transition_coords.set(i_prev, i, amodel.getCoord(i - i_prev, (char)l)); transition.set(i_prev, i, amodel.get(i - i_prev, (char)l)); } } //System.out.println("a:\n"+transition); } public final int getNumStates() { return transition.getSize2(); } public final float getTransitionProb(int s_prev, int s) { return transition.get(s_prev, s); } public final float getEmissionProb(int j, int s) { return emission.get(j, s); } public final void addPartialJumpCountsToATable(ATable ac) { ac.plusEquals(acounts); } @Override public void processTrainingInstance(PhrasePair pp, Reporter r) { if (pp.getE().size() >= amodel.getMaxDist()-1) return; if (pp.getF().size() >= amodel.getMaxDist()-1) return; if (pp.getE().size() == 0) return; if (pp.getF().size() == 0) return; this.buildHMMTables(pp); float totalLogProb = this.baumWelch(pp, null); if (r != null) { r.incrCounter(CrossEntropyCounters.LOGPROB, (long)(-totalLogProb)); r.incrCounter(CrossEntropyCounters.WORDCOUNT, pp.getF().size()); } } /** * @return negative log probability of sentence */ public final float baumWelch(PhrasePair pp, AlignmentPosteriorGrid pg) { initializeCountTableForSentencePair(pp); int[] obs = pp.getF().getWords(); int J = obs.length + 1; int numStates = getNumStates(); int l = pp.getE().getWords().length; float[] anorms = new float[J]; alphas.resize(J + 1, getNumStates()); betas.resize(J + 1, getNumStates()); alphas.set(0, 0, 1.0f); anorms[0]=1.0f; Alignment m1a = null; if (m1_post != null) m1a = m1_post.alignPosteriorThreshold(THRESH); for (int j = 1; j < J; j++) { //System.out.println("J="+j); for (int s = 0; s < numStates; s++) { float alpha = 0.0f; float m1boost = 1.0f; float m1penalty = 0.0f; boolean use_m1 = false; if (m1a != null && m1a.isFAligned(j-1)) { float m1post = 0.0f; use_m1 = true; for (int i=0; i<l; i++) if (m1a.aligned(j-1, i)) m1post = m1_post.getAlignmentPointPosterior(j-1, i+1); //System.out.println(m1post); m1boost = (float)(Math.sqrt(m1post)); m1penalty = 1.0f - m1boost; } for (int s_prev = 0; s_prev < numStates; s_prev++) { float trans = getTransitionProb(s_prev, s); if (use_m1) { if (s <= l && s > 0 && m1a.aligned(j-1, s-1)) trans = m1boost; else trans *= m1penalty; } alpha += alphas.get(j - 1, s_prev) * trans; } alpha *= getEmissionProb(j, s); //System.out.println(" ep:" + hmm.getEmissionProb(s, j)); alphas.set(j, s, alpha); } //anorms[j] = 1.0f; try { anorms[j] = alphas.normalizeColumn(j); } catch (ZeroProbabilityException ex) { this.notifyUnalignablePair(pp, ex.getMessage()); return 0.0f; } } for (int s=1; s<numStates; s++) betas.set(J-1, s, 1.0f); for (int j=J-2; j>=1; j--) { //System.out.println("J="+j); for (int s = 0; s < numStates; s++) { float beta = 0.0f; float m1boost = 1.0f; float m1penalty = 0.0f; boolean use_m1 = false; if (m1a != null && m1a.isFAligned(j-1)) { float m1post = 0.0f; use_m1 = true; for (int i=0; i<l; i++) if (m1a.aligned(j-1, i)) m1post = m1_post.getAlignmentPointPosterior(j-1, i+1); m1boost = (float)(Math.sqrt(m1post)); m1penalty = 1.0f - m1boost; } for (int s_next = 0; s_next < numStates; s_next++) { //System.out.println(" s_next="+s_next + " b(j+1,s_next)="+ betas.get(j+1, s_next) + " * " + // hmm.getTransitionProb(s, s_next) + " * " + hmm.getEmissionProb(s_next, j)); float trans = getTransitionProb(s, s_next); if (use_m1) { if (s <= l && s > 0 && m1a.aligned(j-1, s-1)) trans = m1boost; else trans *= m1penalty; } beta += betas.get(j+1, s_next) * trans * getEmissionProb(j+1, s_next); } beta /= anorms[j]; //System.out.println(" s="+s+ " b:"+beta); betas.set(j, s, beta); } } // PARTIAL COUNTS FOR EMMISSIONS (WORD TRANSLATION) float totalProb[] = new float[J]; for (int j=1; j<J; j++) { float tp = 0.0f; for (int s = 0; s < numStates; s++) { tp += betas.get(j, s) * alphas.get(j, s); } // System.out.println("total prob(" + j + ")=" + tp); totalProb[j] = tp; for (int s = 0; s < numStates; s++) { // j=1 s=14 int iplus1 = e_coords.get(j, s); if (iplus1 == -1) continue; float pc = betas.get(j, s) * alphas.get(j, s) / tp; if (pg != null) { int e = 0; if (s <= l) e = s; if (s != 0) { float p = pg.getAlignmentPointPosterior(j-1, e) + pc; pg.setAlignmentPointPosterior(j-1, e, p); } } else { try { addTranslationCount(iplus1, j-1, pc); } catch (Exception e) { throw new RuntimeException("J=" + J + ", numStates=" + numStates +": Failed to add (" +iplus1+","+(j-1)+") += " + pc + " s=" + s + " pp=" + pp + "\n E:\n"+ e_coords); } } //System.out.println("ec="+ec+" pc="+pc); } } // PARTIAL COUNTS FOR TRANSITIONS if (pg == null) { for (int j=1; j<J-1; j++) { for (int s_prev=0; s_prev < numStates; s_prev++) { for (int s=0; s < numStates; s++) { int tc = transition_coords.get(s_prev, s); if (tc == -1) continue; float m1boost = 1.0f; float m1penalty = 0.0f; boolean use_m1 = false; if (m1a != null && m1a.isFAligned(j-1)) { float m1post = 0.0f; use_m1 = true; for (int i=0; i<l; i++) if (m1a.aligned(j-1, i)) m1post = m1_post.getAlignmentPointPosterior(j-1, i+1); m1boost = (float)(Math.sqrt(m1post)); m1penalty = 1.0f - m1boost; } float trans = getTransitionProb(s_prev, s); if (use_m1) { if (s <= l && s > 0 && m1a.aligned(j-1, s-1)) trans = m1boost; else trans *= m1penalty; } // SKIPPING: REMOVE!!! if (use_m1) continue; float pc = alphas.get(j, s_prev) * trans * emission.get(j+1, s) / anorms[j+1] * betas.get(j+1, s) / totalProb[j+1]; acounts.add(tc, (char)l, pc); //System.out.println("tc="+tc+" pc="+pc); } } } } float tlp = 0.0f; for (float n : anorms) tlp += Math.log(n); return tlp; //System.out.println(acounts); // System.out.println(alphas + "\n" + betas); } @Override public AlignmentPosteriorGrid computeAlignmentPosteriors(PhrasePair pp) { AlignmentPosteriorGrid res = new AlignmentPosteriorGrid(pp); buildHMMTables(pp); baumWelch(pp, res); return res; } @Override public Alignment viterbiAlign(PhrasePair sentence, PerplexityReporter reporter) { this.buildHMMTables(sentence); Alignment res = new Alignment(sentence.getF().size(), sentence.getE().size()); int J = sentence.getF().size() + 1; int numStates = getNumStates(); viterbi.resize(J, getNumStates()); backtrace.resize(J, getNumStates()); viterbi.fill(Float.NEGATIVE_INFINITY); viterbi.set(0, 0, 0.0f); int lene = sentence.getE().getWords().length; Alignment m1a = null; if (m1_post != null) m1a = m1_post.alignPosteriorThreshold(THRESH); //System.out.println(emission); for (int j = 1; j < J; j++) { //System.out.println("J="+j); boolean valid = false; for (int s = 1; s < numStates; s++) { float best = Float.NEGATIVE_INFINITY; int best_s = -1; double emitLogProb = Math.log(emission.get(j, s)); if (emitLogProb == Float.NEGATIVE_INFINITY) { //System.out.println("BAD STATE: " + j + " " + s); continue; } //System.out.println("j="+j + " s="+s+ " ep"+emitLogProb); for (int s_prev = 0; s_prev < numStates; s_prev++) { float m1boost = 1.0f; float m1penalty = 0.0f; boolean use_m1 = false; if (m1a != null && m1a.isFAligned(j-1)) { float m1post = 0.0f; use_m1 = true; for (int i=0; i<lene; i++) { if (m1a.aligned(j-1, i)) m1post = m1_post.getAlignmentPointPosterior(j-1, i+1); } m1boost = (float)Math.sqrt(m1post); m1penalty = 1.0f - m1boost; } float trans = getTransitionProb(s_prev, s); if (use_m1) { if (s <= l && s > 0 && m1a.aligned(j-1, s-1)) trans = m1boost; else trans *= m1penalty; } float cur = (float)(viterbi.get(j - 1, s_prev) + Math.log(trans) + emitLogProb); //System.out.println(" s'="+s_prev + " cur="+cur); if (cur > best) { best = cur; best_s = s_prev; //System.out.println("new best: " + s + " " + best_s); } } //System.out.println(" s_best="+best_s + " cur="+best); viterbi.set(j, s, best); if (best != Float.NEGATIVE_INFINITY) valid = true; backtrace.set(j, s, best_s); } // if we don't know how to generate some column // create a uniform distribution over the states // and assume the previous state was the best if (!valid) { float best = Float.NEGATIVE_INFINITY; int bests = -1; for (int s = 1; s < numStates; s++) { if (viterbi.get(j-1, s) > best) { best = viterbi.get(j-1, s); bests = s; } } for (int s = 1; s < numStates; s++) { viterbi.set(j, s, 0.0f); backtrace.set(j, s, bests); } } } //System.out.println(viterbi); float best = Float.NEGATIVE_INFINITY; int best_s = -1; for (int s = 1; s < numStates; s++) { if (viterbi.get(J-1, s) > best) { best = viterbi.get(J-1,s); best_s = s; } } //System.out.println("vit: " + best + "j-1="+(J-1)); reporter.addFactor(best, J - 1); //System.out.println(viterbi); int e = best_s; for (int f=J-1; f>0; f--) { if (e <= 0) { throw new ZeroProbabilityException(" Error f=" +f+" e="+e+ " sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace); } else { if (viterbi.get(f, e) < 0.0) { // hack to avoid errors try { int af = f-1; int ae = e_words.get(f, e); if (ae >= 0) res.align(af, ae); //else // System.err.println("ALIGN NULL TO " + af); } catch (RuntimeException ex) { throw new RuntimeException("Caught " + ex + "\nvit(f,e)="+viterbi.get(f,e)+" size(f,e)=" + sentence.getF().size() +","+ sentence.getE().size() + " Error f=" +f+" e="+e+ " sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace + "\n" + e_words); } } e = backtrace.get(f, e); } } return res; } }