package edu.berkeley.cs.nlp.ocular.train; import static edu.berkeley.cs.nlp.ocular.train.ModelPathMaker.makeFontPath; import static edu.berkeley.cs.nlp.ocular.train.ModelPathMaker.makeGsmPath; import static edu.berkeley.cs.nlp.ocular.train.ModelPathMaker.makeLmPath; import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2; import static edu.berkeley.cs.nlp.ocular.util.Tuple3.Tuple3; import java.io.File; import edu.berkeley.cs.nlp.ocular.font.Font; import edu.berkeley.cs.nlp.ocular.gsm.GlyphSubstitutionModel; import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel; import edu.berkeley.cs.nlp.ocular.main.InitializeFont; import edu.berkeley.cs.nlp.ocular.main.InitializeGlyphSubstitutionModel; import edu.berkeley.cs.nlp.ocular.main.InitializeLanguageModel; import edu.berkeley.cs.nlp.ocular.util.Tuple2; import edu.berkeley.cs.nlp.ocular.util.Tuple3; /** * @author Dan Garrette (dhgarrette@gmail.com) */ public class TrainingRestarter { /** * If requested, try and pick up where we left off */ public Tuple2<Integer, Tuple3<Font, CodeSwitchLanguageModel, GlyphSubstitutionModel>> getRestartModels( Font inputFont, CodeSwitchLanguageModel inputLm, GlyphSubstitutionModel inputGsm, boolean updateLM, boolean updateGsm, String outputPath, int numEMIters, int numUsableDocs, int updateDocBatchSize, boolean noUpdateIfBatchTooSmall) { int lastCompletedIteration = 0; String fontPath = null; int lastBatchNumOfIteration = getLastBatchNumOfIteration(numUsableDocs, updateDocBatchSize, noUpdateIfBatchTooSmall); for (int iter = 1; iter <= numEMIters; ++iter) { fontPath = makeFontPath(outputPath, iter, lastBatchNumOfIteration); if (new File(fontPath).exists()) { lastCompletedIteration = iter; } } Font newFont = inputFont; CodeSwitchLanguageModel newLm = inputLm; GlyphSubstitutionModel newGsm = inputGsm; if (lastCompletedIteration == numEMIters) { System.out.println("All iterations are already complete!"); } else if (lastCompletedIteration > 0) { System.out.println("Last completed iteration: "+lastCompletedIteration); if (fontPath != null) { String lastFontPath = makeFontPath(outputPath, lastCompletedIteration, lastBatchNumOfIteration); System.out.println(" Loading font of last completed iteration: "+lastFontPath); newFont = InitializeFont.readFont(lastFontPath); } if (updateLM) { String lastLmPath = makeLmPath(outputPath, lastCompletedIteration, lastBatchNumOfIteration); System.out.println(" Loading lm of last completed iteration: "+lastLmPath); newLm = InitializeLanguageModel.readCodeSwitchLM(lastLmPath); } if (updateGsm) { String lastGsmPath = makeGsmPath(outputPath, lastCompletedIteration, lastBatchNumOfIteration); System.out.println(" Loading gsm of last completed iteration: "+lastGsmPath); newGsm = InitializeGlyphSubstitutionModel.readGSM(lastGsmPath); } } else { System.out.println("No completed iterations found"); } return Tuple2(lastCompletedIteration, Tuple3(newFont,newLm,newGsm)); } private int getLastBatchNumOfIteration(int numUsableDocs, int updateDocBatchSize, boolean noUpdateIfBatchTooSmall) { int completedBatchesInIteration = 0; int currentBatchSize = 0; for (int docNum = 0; docNum < numUsableDocs; ++docNum) { ++currentBatchSize; if (FontTrainer.isBatchComplete(numUsableDocs, docNum, currentBatchSize, updateDocBatchSize, noUpdateIfBatchTooSmall)) { ++completedBatchesInIteration; } } return completedBatchesInIteration; } }