// Some algorithms for Hidden Markov Models (Chapter 3): Viterbi, // Forward, Backward, Baum-Welch. We compute with log probabilities. // Notational conventions: // i = 1,...,L indexes x, the observed string, x_0 not a symbol // k,ell = 0,...,hmm.nstate-1 indexes hmm.state(k) a_0 is the start state //Zhenzhen Kou // Notational conventions: // i = 1,...,L indexes x, the observed string, x_0 not a symbol // k,ell = 0,...,hmm.nstate-1 indexes hmm.state(k) a_0 is the start state package edu.cmu.minorthird.classify.sequential; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Enumeration; import java.util.Hashtable; // Some algorithms for Hidden Markov Models (Chapter 3): Viterbi, // Forward, Backward, Baum-Welch. We compute with log probabilities. public class HMM { // State names and state-to-state transition probabilities int nstate; // number of states (incl initial state) String[] state; // names of the states double[][] amat; // transition matrix double[][] loga; // loga[k][ell] = log(P(k -> ell)) // Emission names and emission probabilities int nesym; // number of emission symbols Hashtable<String,String> esym = new Hashtable<String,String>(); // the emission symbols e1,...,eL (characters) Hashtable<String,String> esym_tok2idx; Hashtable<String,String> esym_idx2tok; double[][] emat; // emision matrix double[][] loge; // loge[k][ei] = log(P(emit ei in state k)) // Input: // state = array of state names (except initial state) // amat = matrix of transition probabilities (except initial state) // esym = string of emission names // emat = matrix of emission probabilities public HMM(String[] state, double[][] amat, Hashtable<String,String> esym, double[][] emat) { if (state.length != amat.length) throw new IllegalArgumentException("HMM: state and amat disagree"); if (amat.length != emat.length) throw new IllegalArgumentException("HMM: amat and emat disagree"); for (int i=0; i<amat.length; i++) { if (state.length != amat[i].length) throw new IllegalArgumentException("HMM: amat non-square"); if (esym.size() != emat[i].length) throw new IllegalArgumentException("HMM: esym and emat disagree"); } // Set up the transition matrix nstate = state.length + 1; this.state = new String[nstate]; loga = new double[nstate][nstate]; this.state[0] = "S"; // initial state // P(start -> start) = 0 loga[0][0] = Double.NEGATIVE_INFINITY; // = log(0) // P(start -> other) = 1.0/state.length double fromstart = Math.log(1.0/state.length); for (int j=1; j<nstate; j++) loga[0][j] = fromstart; for (int i=1; i<nstate; i++) { // Reverse state names for efficient backwards concatenation this.state[i] = new StringBuffer(state[i-1]).reverse().toString(); // System.out.println("state["+i+"] is "+this.state[i]); // P(other -> start) = 0 loga[i][0] = Double.NEGATIVE_INFINITY; // = log(0) for (int j=1; j<nstate; j++) loga[i][j] = Math.log(amat[i-1][j-1]); } this.esym = esym; esym_tok2idx = new Hashtable<String,String>(); esym_idx2tok = new Hashtable<String,String>(); int idx=0; for ( Enumeration<String> e_keys = esym.keys(); e_keys.hasMoreElements();){ String key = e_keys.nextElement(); esym_tok2idx.put(key, String.valueOf(idx) ); esym_idx2tok.put(String.valueOf(idx),key ); idx ++; } for ( Enumeration<String> e_keys = esym_tok2idx.keys(); e_keys.hasMoreElements();){ String key = e_keys.nextElement(); String val = esym_tok2idx.get(key ); System.out.println("in esym_tok2idx: "+key+"<--->"+val); } // Set up the emission matrix nesym = esym.size(); loge = new double[nstate][nesym]; for (int b=0; b<nesym; b++) { loge[0][b] = Double.NEGATIVE_INFINITY; // = log(0) for (int k=0; k<emat.length; k++) loge[k+1][b] = Math.log(emat[k][b]); } } /*public void print(Output out) { printa(out); printe(out); } public void printa(Output out) { out.println("Transition probabilities:"); for (int i=1; i<nstate; i++) { for (int j=1; j<nstate; j++) out.print(fmtlog(loga[i][j])); out.println(); } }*/ public String[] convert_Ob_seq( String[] x) { String[] y = new String[x.length]; for( int i=0; i<x.length;i++){ if( esym_tok2idx.containsKey( x[i] ) ){ y[i]= esym_tok2idx.get( x[i] ); }else{ y[i]= esym_tok2idx.get( "UNSEEN" ); } System.out.println("string "+x[i]+" corresponds to state idx "+y[i]); } return(y); } /*public void printe(Output out) { out.println("Emission probabilities:"); for (int b=0; b<esym_idx2tok.size(); b++) out.print((String)esym_idx2tok.get(String.valueOf(b)) + hdrpad); out.println(); for (int i=1; i<loge.length; i++) { for (int b=0; b<nesym; b++) out.print(fmtlog(loge[i][b])); out.println(); } }*/ private static DecimalFormat fmt = new DecimalFormat("0.000000 "); // private static String hdrpad = " "; public static String fmtlog(double x) { if (x == Double.NEGATIVE_INFINITY) return fmt.format(0); else return fmt.format(Math.exp(x)); } // The Baum-Welch algorithm for estimating HMM parameters for a // given model topology and a family of observed sequences. // Often gets stuck at a non-global minimum; depends on initial guess. // xs is the set of training sequences, here one training sequence is the sequence of index reprensenting tokens // state is the set of HMM state names // esym is the set of emissible symbols public static HMM baumwelch(ArrayList<String[]> xs, String[] state, Hashtable<String,String> esym, final double threshold) { int nstate = state.length; int nseqs = xs.size(); int nesym = esym.size(); Forward[] fwds = new Forward[nseqs]; Backward[] bwds = new Backward[nseqs]; double[] logP = new double[nseqs]; double[][] amat = new double[nstate][]; double[][] emat = new double[nstate][]; // Initially use random transition and emission matrices for (int k=0; k<nstate; k++) { amat[k] = randomdiscrete(nstate); emat[k] = randomdiscrete(nesym); } HMM hmm = new HMM(state, amat, esym, emat); double oldloglikelihood; // Compute Forward and Backward tables for the sequences double loglikelihood = fwdbwd(hmm, xs, fwds, bwds, logP); System.out.println("log likelihood = " + loglikelihood); // hmm.print(new SystemOut()); do { oldloglikelihood = loglikelihood; // Compute estimates for A and E double[][] A = new double[nstate][nstate]; double[][] E = new double[nstate][nesym]; for (int s=0; s<nseqs; s++) { String[] x = xs.get(s); Forward fwd = fwds[s]; Backward bwd = bwds[s]; int L = x.length; double P = logP[s]; // NOT exp. Fixed 2001-08-20 for (int i=0; i<L; i++) { for (int k=0; k<nstate; k++) E[k][Integer.parseInt(x[i])] += exp(fwd.f[i+1][k+1] + bwd.b[i+1][k+1] - P); } for (int i=0; i<L-1; i++) for (int k=0; k<nstate; k++) for (int ell=0; ell<nstate; ell++) A[k][ell] += exp(fwd.f[i+1][k+1] + hmm.loga[k+1][ell+1] + hmm.loge[ell+1][Integer.parseInt(x[i+1])] + bwd.b[i+2][ell+1] - P); } // Estimate new model parameters, i.e. normalize for (int k=0; k<nstate; k++) { double Aksum = 0; for (int ell=0; ell<nstate; ell++) Aksum += A[k][ell]; for (int ell=0; ell<nstate; ell++) amat[k][ell] = A[k][ell] / Aksum; double Eksum = 0; for (int b=0; b<nesym; b++) Eksum += E[k][b]; for (int b=0; b<nesym; b++) emat[k][b] = E[k][b] / Eksum; } // Create new model hmm = new HMM(state, amat, esym, emat); loglikelihood = fwdbwd(hmm, xs, fwds, bwds, logP); System.out.println("log likelihood = " + loglikelihood); // hmm.print(new SystemOut()); } while (Math.abs(oldloglikelihood - loglikelihood) > threshold); return hmm; } private static double fwdbwd(HMM hmm, ArrayList<String[]> xs, Forward[] fwds, Backward[] bwds, double[] logP) { double loglikelihood = 0; for (int s=0; s<xs.size(); s++) { fwds[s] = new Forward(hmm, xs.get(s)); bwds[s] = new Backward(hmm, xs.get(s)); logP[s] = fwds[s].logprob(); loglikelihood += logP[s]; } return loglikelihood; } public static double exp(double x) { if (x == Double.NEGATIVE_INFINITY) return 0; else return Math.exp(x); } // private static double[] uniformdiscrete(int n) { // double[] ps = new double[n]; // for (int i=0; i<n; i++) // ps[i] = 1.0/n; // return ps; // } private static double[] randomdiscrete(int n) { double[] ps = new double[n]; double sum = 0; // Generate random numbers for (int i=0; i<n; i++) { ps[i] = Math.random(); sum += ps[i]; } // Scale to obtain a discrete probability distribution for (int i=0; i<n; i++) ps[i] /= sum; return ps; } }