/*
 * Copyright 1999-2002 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.trainer;

import edu.cmu.sphinx.linguist.acoustic.AcousticModel;
import edu.cmu.sphinx.linguist.acoustic.UnitManager;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerAcousticModel;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;
import edu.cmu.sphinx.util.Utilities;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Component;
import edu.cmu.sphinx.util.props.S4Boolean;
import edu.cmu.sphinx.util.props.S4ComponentList;

import java.io.IOException;
import java.util.List;

/**
 * This is a dummy implementation of a TrainManager.
 * <p>
 * It is configured with a control file (the source of utterances), two learners (one for initialization, one for
 * training), a unit manager and a list of trainer acoustic models, and drives initialization, training, loading and
 * saving of those models.
 */
public class SimpleTrainManager implements TrainManager {

    /** The property naming the control file component that supplies the utterances. */
    @S4Component(type = ControlFile.class)
    public static final String CONTROL_FILE = "control";
    private ControlFile controlFile;

    private boolean dumpMemoryInfo;

    /** The property naming the learner used by {@link #trainContextIndependentModels(String)}. */
    @S4Component(type = Learner.class)
    public static final String LEARNER = "learner";
    private Learner learner;

    /** The property naming the learner used by {@link #initializeModels(String)}. */
    @S4Component(type = Learner.class)
    public static final String INIT_LEARNER = "initLearner";
    private Learner initLearner;

    /** The property naming the unit manager handed to each utterance graph. */
    @S4Component(type = UnitManager.class)
    public static final String UNIT_MANAGER = "unitManager";
    private UnitManager unitManager;

    /** The property naming the list of acoustic models to initialize, train, load and save. */
    @S4ComponentList(type = TrainerAcousticModel.class)
    public static final String AM_COLLECTION = "models";
    private List<? extends TrainerAcousticModel> acousticModels;

    /**
     * The property for the boolean property that controls whether or not the trainer will display detailed memory
     * information while it is running. The default value is <code>false</code>.
     */
    @S4Boolean(defaultValue = false)
    public final static String DUMP_MEMORY_INFO = "dumpMemoryInfo";

    private int maxIteration;
    private float minimumImprovement;


    /**
     * Reads this component's configuration from the given property sheet.
     *
     * @param ps the property sheet holding the configured components and values
     * @throws PropertyException if a property cannot be resolved
     */
    public void newProperties(PropertySheet ps) throws PropertyException {
        dumpMemoryInfo = ps.getBoolean(DUMP_MEMORY_INFO);
        learner = (Learner) ps.getComponent(LEARNER);
        controlFile = (ControlFile) ps.getComponent(CONTROL_FILE);
        initLearner = (Learner) ps.getComponent(INIT_LEARNER);

        // PROP_MINIMUM_IMPROVEMENT / PROP_MAXIMUM_ITERATION are not declared in this file;
        // presumably they are inherited from TrainManager.
        minimumImprovement = ps.getFloat(PROP_MINIMUM_IMPROVEMENT);
        maxIteration = ps.getInt(PROP_MAXIMUM_ITERATION);

        acousticModels = ps.getComponentList(AM_COLLECTION, TrainerAcousticModel.class);
        unitManager = (UnitManager) ps.getComponent(UNIT_MANAGER);
    }


    /**
     * Do the train.
     * <p>
     * This dummy implementation only walks the control file and prints every utterance followed by each of its
     * transcripts to standard output; no model is touched.
     */
    public void train() {
        assert controlFile != null;

        controlFile.startUtteranceIterator();
        while (controlFile.hasMoreUtterances()) {
            Utterance utterance = controlFile.nextUtterance();
            System.out.println(utterance);

            utterance.startTranscriptIterator();
            while (utterance.hasMoreTranscripts()) {
                System.out.println(utterance.nextTranscript());
            }
        }
    }


    /**
     * Copy the model.
     * <p>
     * This method copies the model set, possibly to a new location and new format. This is useful if one wants to
     * convert from binary to ascii and vice versa, or from a directory structure to a JAR file. It is implemented as
     * a load of every model followed by a save of every model.
     *
     * @param context this TrainManager's context
     * @throws IOException if an error occurs while loading or saving the data
     */
    public void copyModels(String context) throws IOException {
        loadModels(context);
        saveModels(context);
    }


    /**
     * Saves the acoustic models.
     * <p>
     * If exactly one model is configured it is saved with a <code>null</code> name; otherwise each model is saved
     * under its own {@link AcousticModel#getName() name}.
     *
     * @param context the context of this TrainManager (not used by this implementation)
     * @throws IOException if an error occurs while saving the data
     */
    public void saveModels(String context) throws IOException {
        if (acousticModels.size() == 1) {
            acousticModels.get(0).save(null);
            return;
        }
        // The list is already typed as TrainerAcousticModel, so no instanceof/cast is needed.
        for (TrainerAcousticModel model : acousticModels) {
            model.save(model.getName());
        }
    }


    /**
     * Loads the acoustic models.
     *
     * @param context the context of this TrainManager (not used by this implementation)
     * @throws IOException if an error occurs while loading the data
     */
    private void loadModels(String context) throws IOException {
        dumpMemoryInfo("TrainManager start");

        for (TrainerAcousticModel model : acousticModels) {
            model.load();
        }

        dumpMemoryInfo("acoustic model");
    }


    /**
     * Initializes the acoustic models.
     * <p>
     * For every model, each utterance of the control file is handed to the init learner and every score it produces
     * is accumulated into the model; the model is normalized once all utterances have been processed.
     *
     * @param context the context of this TrainManager (not used by this implementation)
     * @throws IOException if an I/O error occurs
     */
    public void initializeModels(String context) throws IOException {
        dumpMemoryInfo("TrainManager start");

        for (TrainerAcousticModel model : acousticModels) {
            controlFile.startUtteranceIterator();
            while (controlFile.hasMoreUtterances()) {
                Utterance utterance = controlFile.nextUtterance();
                initLearner.setUtterance(utterance);

                TrainerScore[] score;
                while ((score = initLearner.getScore()) != null) {
                    assert score.length == 1;
                    model.accumulate(0, score);
                }
            }

            // normalize() has a return value, but we can ignore it here.
            model.normalize();
        }

        dumpMemoryInfo("acoustic model");
    }


    /**
     * Trains context independent models. If the initialization stage was skipped, it loads models from files,
     * automatically.
     * <p>
     * Each model is re-estimated until either the maximum number of iterations is reached or the relative
     * improvement of the log likelihood no longer exceeds the configured minimum. The models are saved after
     * every iteration.
     *
     * @param context the context of this train manager.
     * @throws IOException if IO went wrong
     */
    public void trainContextIndependentModels(String context) throws IOException {
        // If initialization was performed, then learner should not be
        // null. Otherwise, we need to load the models.
        // NOTE(review): learner is dereferenced unconditionally below, so a null learner would
        // still fail with a NullPointerException after the models are loaded. Confirm whether
        // this check was meant to test a different condition.
        if (learner == null) {
            loadModels(context);
        }

        dumpMemoryInfo("TrainManager start");

        for (TrainerAcousticModel model : acousticModels) {
            float lastLogLikelihood = Float.MAX_VALUE;
            // Start above any sensible threshold so that the first iteration always runs.
            float relativeImprovement = 100.0f;

            for (int iteration = 0;
                 (iteration < maxIteration) && (relativeImprovement > minimumImprovement);
                 iteration++) {
                System.out.println("Iteration: " + iteration);
                model.resetBuffers();

                controlFile.startUtteranceIterator();
                while (controlFile.hasMoreUtterances()) {
                    Utterance utterance = controlFile.nextUtterance();
                    UtteranceGraph uttGraph = new UtteranceHMMGraph(context, utterance, model, unitManager);
                    learner.setUtterance(utterance);
                    learner.setGraph(uttGraph);

                    // Holds the score array returned by the previous getScore() call
                    // (null for the first one of each utterance).
                    TrainerScore[] nextScore = null;
                    TrainerScore[] score;
                    while ((score = learner.getScore()) != null) {
                        for (int i = 0; i < score.length; i++) {
                            if (i > 0) {
                                model.accumulate(i, score, nextScore);
                            } else {
                                model.accumulate(i, score);
                            }
                        }
                        nextScore = score;
                    }
                    model.updateLogLikelihood();
                }

                float logLikelihood = model.normalize();
                System.out.println("Loglikelihood: " + logLikelihood);
                saveModels(context);

                // There is nothing to compare against after the very first pass.
                if (iteration > 0) {
                    relativeImprovement =
                            computeRelativeImprovement(logLikelihood, lastLogLikelihood, relativeImprovement);
                    System.out.println("Finished iteration: " + iteration + " - Improvement: "
                            + relativeImprovement);
                }
                lastLogLikelihood = logLikelihood;
            }
        }
    }


    /**
     * Computes the relative change, in percent, between two successive log likelihoods.
     * <p>
     * The difference is divided by the <em>magnitude</em> of the previous value, so an increase of the log
     * likelihood is reported as a positive improvement whatever the sign of the log likelihoods (dividing by the
     * signed value would flip the sign for negative log likelihoods).
     *
     * @param logLikelihood       the log likelihood of the iteration that just finished
     * @param lastLogLikelihood   the log likelihood of the previous iteration
     * @param previousImprovement the value to keep when the ratio is undefined, i.e. when the previous log
     *                            likelihood is zero and the current one is not
     * @return the relative improvement in percent
     */
    private static float computeRelativeImprovement(float logLikelihood,
                                                    float lastLogLikelihood,
                                                    float previousImprovement) {
        if (lastLogLikelihood != 0) {
            return (logLikelihood - lastLogLikelihood) / Math.abs(lastLogLikelihood) * 100.0f;
        }
        if (lastLogLikelihood == logLikelihood) {
            return 0;
        }
        return previousImprovement;
    }


    /**
     * Conditionally dumps out memory information; does nothing unless memory dumping is enabled.
     *
     * @param what an additional info string
     */
    private void dumpMemoryInfo(String what) {
        if (dumpMemoryInfo) {
            Utilities.dumpMemoryInfo(what);
        }
    }
}