package joshua.discriminative.training.learning_algorithm; /** * @author Zhifei Li, <zhifei.work@gmail.com> * @version $LastChangedDate: 2008-10-20 00:12:30 -0400 $ */ public abstract class DefaultEM { //==== stop criterion private int maxNumIter = 100;//run at most 100 iterations (i.e., number of funcion and gradient evaluation) for this particular run private double relativeLikelihoodThreshold = 1e-5;//if the relative change of the function value is smaller than this value, then we terminate private int maxConvergeNum = 1000;//if the number of times that the likelihood does not change, then stop private double lastLikelihood; public abstract void runOneEMStep(int iterNum); public abstract boolean isEMConverged(); public abstract double getLastLikelihood(); public abstract void printStatistics(int iter_num); public DefaultEM(int maxNumIter_, double relativeLikelihoodThreshold_, int maxConvergeNum_){ maxNumIter = maxNumIter_; relativeLikelihoodThreshold = relativeLikelihoodThreshold_; maxConvergeNum = maxConvergeNum_; } public void runEM(){ System.out.println("================ beging to run EM ======================="); int numCalls=0; lastLikelihood=Double.NEGATIVE_INFINITY; int checkConverge=0; while (numCalls==0 || ( isEMConverged() == false ) && ( numCalls <= maxNumIter ) ){ numCalls++; System.out.println("================ run iteration " + numCalls + "======================="); double tLikelihood = getLastLikelihood(); if(tLikelihood<lastLikelihood) { System.out.println("EM returns a bad optimal value; best: " + tLikelihood + "; last: " + lastLikelihood); System.exit(1); } //=== another way to terminate the em if( numCalls!=0 && Math.abs(lastLikelihood-tLikelihood)/lastLikelihood<relativeLikelihoodThreshold){ checkConverge++; if(checkConverge>=maxConvergeNum){//does not change for several consecutive times System.out.println("EM early stops because the likelihood does not change for several iterationss; break at iter " + numCalls); break; } }else{ checkConverge=0; } lastLikelihood = tLikelihood; runOneEMStep(numCalls); //printStatistics(num_calls, last_function_val, gradient_vector, weights_vector); } printStatistics(numCalls); } }