package edu.berkeley.cs.nlp.ocular.eval;
import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2;
import java.io.File;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.berkeley.cs.nlp.ocular.data.Document;
import edu.berkeley.cs.nlp.ocular.eval.Evaluator.EvalSuffStats;
import edu.berkeley.cs.nlp.ocular.font.Font;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphSubstitutionModel;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.main.FonttrainTranscribeShared.OutputFormat;
import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.model.DecodeState;
import edu.berkeley.cs.nlp.ocular.model.DecoderEM;
import edu.berkeley.cs.nlp.ocular.model.em.DenseBigramTransitionModel;
import edu.berkeley.cs.nlp.ocular.train.FontTrainer;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import tberg.murphy.indexer.Indexer;
/**
* Transcribe all documents, write their results to files, and evaluate the results.
*
* @author Dan Garrette (dhgarrette@gmail.com)
*/
public class BasicMultiDocumentTranscriber implements MultiDocumentTranscriber {

  /** Documents to transcribe on every call to {@code transcribe}. */
  private final List<Document> documents;
  /** Input path; when it is a directory, aggregate evaluation files are written. */
  private final String inputDocPath;
  /** Root path under which output is written. */
  private final String outputPath;
  /** Output formats handed through to the per-document printer/evaluator. */
  private final Set<OutputFormat> outputFormats;
  /** Runs the E-step that decodes each document. */
  private final DecoderEM decoderEM;
  /** Prints and evaluates the transcription of a single document. */
  private final SingleDocumentEvaluatorAndOutputPrinter docOutputPrinterAndEvaluator;
  /** Character indexer used when loading the font's templates. */
  private final Indexer<String> charIndexer;
  /** If true, a document that throws a RuntimeException is reported and skipped instead of aborting. */
  private final boolean skipFailedDocs;

  /**
   * @param documents                          documents to transcribe
   * @param inputDocPath                       input path of the documents
   * @param outputPath                         root path for all output
   * @param outputFormats                      formats passed to the per-document printer
   * @param decoderEM                          decoder used to compute the E-step per document
   * @param documentOutputPrinterAndEvaluator  evaluates and prints each document's transcription
   * @param charIndexer                        indexer used to load character templates from the font
   * @param skipFailedDocs                     whether to skip (rather than rethrow on) failing documents
   */
  public BasicMultiDocumentTranscriber(
      List<Document> documents, String inputDocPath, String outputPath, Set<OutputFormat> outputFormats,
      DecoderEM decoderEM,
      SingleDocumentEvaluatorAndOutputPrinter documentOutputPrinterAndEvaluator,
      Indexer<String> charIndexer,
      boolean skipFailedDocs) {
    this.documents = documents;
    this.inputDocPath = inputDocPath;
    this.outputPath = outputPath;
    this.outputFormats = outputFormats;
    this.decoderEM = decoderEM;
    this.docOutputPrinterAndEvaluator = documentOutputPrinterAndEvaluator;
    this.charIndexer = charIndexer;
    this.skipFailedDocs = skipFailedDocs;
  }

  /**
   * Transcribes all documents with iteration and batch id both set to 0, so that
   * neither appears in progress messages or output file names.
   */
  public void transcribe(Font font, CodeSwitchLanguageModel lm, GlyphSubstitutionModel gsm) {
    transcribe(0, 0, font, lm, gsm);
  }

  /**
   * Decodes every document, evaluates and prints each transcription, reports the
   * average joint log probability, and (when the input path is a directory) writes
   * the aggregated diplomatic and normalized evaluations.
   *
   * @param iter     training iteration; only shown in messages and file names when positive
   * @param batchId  batch id; only shown in messages and file names when positive
   * @param font     font from which the character templates are loaded
   * @param lm       language model used for decoding and evaluation output
   * @param gsm      glyph substitution model passed to the decoder
   * @throws RuntimeException if a document fails and {@code skipFailedDocs} is false
   */
  public void transcribe(int iter, int batchId, Font font, CodeSwitchLanguageModel lm, GlyphSubstitutionModel gsm) {
    int numDocs = documents.size();
    CharacterTemplate[] templates = FontTrainer.loadTemplates(font, charIndexer);
    DenseBigramTransitionModel backwardTransitionModel = new DenseBigramTransitionModel(lm);
    // Created once per call instead of once per document; only used on this thread.
    SimpleDateFormat timestampFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
    String progressPrefix = progressPrefix(iter, batchId);

    double totalJointLogProb = 0.0;
    // Number of documents whose log prob was actually added to the total.
    int numDecodedDocs = 0;
    List<Tuple2<String, Map<String, EvalSuffStats>>> allDiplomaticEvals = new ArrayList<Tuple2<String, Map<String, EvalSuffStats>>>();
    List<Tuple2<String, Map<String, EvalSuffStats>>> allNormalizedEvals = new ArrayList<Tuple2<String, Map<String, EvalSuffStats>>>();

    for (int docNum = 0; docNum < numDocs; ++docNum) {
      Document doc = documents.get(docNum);
      System.out.println(progressPrefix + "Transcribing eval document "+(docNum+1)+" of "+numDocs+": "+doc.baseName() + " " + timestampFormat.format(Calendar.getInstance().getTime()));

      try {
        Tuple2<DecodeState[][], Double> decodeResults = decoderEM.computeEStep(doc, false, lm, gsm, templates, backwardTransitionModel);
        final DecodeState[][] decodeStates = decodeResults._1;
        totalJointLogProb += decodeResults._2;
        ++numDecodedDocs;

        Tuple2<Map<String, EvalSuffStats>,Map<String, EvalSuffStats>> evals = docOutputPrinterAndEvaluator.evaluateAndPrintTranscription(iter, batchId, doc, decodeStates, inputDocPath, outputPath, outputFormats, lm);
        // Either evaluation may be absent for a document; only collect the ones present.
        if (evals._1 != null) allDiplomaticEvals.add(Tuple2(doc.baseName(), evals._1));
        if (evals._2 != null) allNormalizedEvals.add(Tuple2(doc.baseName(), evals._2));
      } catch(RuntimeException e) {
        if (!skipFailedDocs) {
          throw e;
        }
        System.err.println("DOCUMENT FAILED! Skipping " + doc.baseName());
        e.printStackTrace();
      }
    }

    // Average over the documents that contributed to the total, so that skipped
    // documents do not count as a log prob of 0. NaN when nothing was decoded.
    double avgLogProb = numDecodedDocs > 0 ? totalJointLogProb / numDecodedDocs : Double.NaN;
    System.out.println("Iteration "+iter+", batch "+batchId+": eval avg joint log prob: " + avgLogProb);

    printAggregateEvaluations(iter, batchId, allDiplomaticEvals, allNormalizedEvals);
  }

  /** Builds the "Training iteration N, batch M, " prefix; each part appears only when its value is positive. */
  private static String progressPrefix(int iter, int batchId) {
    String prefix = "";
    if (iter > 0) prefix += "Training iteration "+iter+", ";
    if (batchId > 0) prefix += "batch "+batchId+", ";
    return prefix;
  }

  /**
   * Writes the collected per-document evaluations to
   * {@code <outputPath>/all_transcriptions/<input dir name>/eval[_iter-N][_batch-M]_{diplomatic,normalized}.txt}.
   * Does nothing unless the input path is a directory; empty evaluation lists are not written.
   */
  private void printAggregateEvaluations(
      int iter, int batchId,
      List<Tuple2<String, Map<String, EvalSuffStats>>> allDiplomaticEvals,
      List<Tuple2<String, Map<String, EvalSuffStats>>> allNormalizedEvals) {
    File inputDoc = new File(inputDocPath);
    if (!inputDoc.isDirectory()) {
      return;
    }

    String preext = "eval";
    String outputFilenameBase = outputPath + "/all_transcriptions/" + inputDoc.getName() + "/" + preext;
    if (iter > 0) outputFilenameBase += "_iter-" + iter;
    if (batchId > 0) outputFilenameBase += "_batch-" + batchId;

    if (!allDiplomaticEvals.isEmpty()) {
      EvalPrinter.printEvaluation(allDiplomaticEvals, outputFilenameBase + "_diplomatic.txt");
    }
    if (!allNormalizedEvals.isEmpty()) {
      EvalPrinter.printEvaluation(allNormalizedEvals, outputFilenameBase + "_normalized.txt");
    }
  }
}