/** * Copyright 2007 DFKI GmbH. * All Rights Reserved. Use is subject to license terms. * * This file is part of MARY TTS. * * MARY TTS is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation, version 3 of the License. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * */ package marytts.machinelearning; import java.io.IOException; import marytts.util.io.MaryRandomAccessFile; /** * Wrapper class for GMM training parameters * * @author Oytun Türk */ public class GMMTrainerParams { // A set of default values for GMM training parameters public static final int EM_TOTAL_COMPONENTS_DEFAULT = 1; public static final boolean EM_IS_DIAGONAL_COVARIANCE_DEFAULT = true; public static final int EM_MIN_ITERATIONS_DEFAULT = 500; public static final int EM_MAX_ITERATIONS_DEFAULT = 2000; public static final boolean EM_IS_UPDATE_COVARIANCES_DEFAULT = true; public static final double EM_TINY_LOGLIKELIHOOD_CHANGE_PERCENT_DEFAULT = 0.0001; public static final double EM_MIN_COVARIANCE_ALLOWED_DEFAULT = 1e-4; public static final boolean EM_USE_NATIVE_C_LIB_TRAINER_DEFAULT = false; // public int totalComponents; // Total number of Gaussians in the GMM public boolean isDiagonalCovariance; // Estimate diagonal covariance matrices? // Full-covariance training is likely to result in ill-conditioned training due to // insufficient training data public int kmeansMaxIterations; // Minimum number of K-Means iterations to initialize the GMM public double kmeansMinClusterChangePercent; // Maximum number of K-Means iterations to initialize the GMM public int kmeansMinSamplesInOneCluster; // Minimum number of observations in one cluster while initializing the GMM with // K-Means public int emMinIterations; // Minimum number of EM iterations for which the algorithm will not quit // even when the total likelihood does not change much with additional iterations public int emMaxIterations; // Maximum number of EM iterations for which the algorithm will quit // even when total likelihood has not settled yet public boolean isUpdateCovariances; // Update covariance matrices in EM iterations? public double tinyLogLikelihoodChangePercent; // Threshold to compare percent decrease in total log-likelihood to stop // iterations automatically public double minCovarianceAllowed; // Minimum covariance value allowed - should be a small positive number to avoid // ill-conditioned training public boolean useNativeCLibTrainer; // Use native C library trainer (Windows OS only) // Default constructor public GMMTrainerParams() { totalComponents = EM_TOTAL_COMPONENTS_DEFAULT; isDiagonalCovariance = EM_IS_DIAGONAL_COVARIANCE_DEFAULT; kmeansMaxIterations = KMeansClusteringTrainerParams.KMEANS_MAX_ITERATIONS_DEFAULT; kmeansMinClusterChangePercent = KMeansClusteringTrainerParams.KMEANS_MIN_CLUSTER_CHANGE_PERCENT_DEFAULT; kmeansMinSamplesInOneCluster = KMeansClusteringTrainerParams.KMEANS_MIN_SAMPLES_IN_ONE_CLUSTER_DEFAULT; emMinIterations = EM_MIN_ITERATIONS_DEFAULT; emMaxIterations = EM_MAX_ITERATIONS_DEFAULT; isUpdateCovariances = EM_IS_UPDATE_COVARIANCES_DEFAULT; tinyLogLikelihoodChangePercent = EM_TINY_LOGLIKELIHOOD_CHANGE_PERCENT_DEFAULT; minCovarianceAllowed = EM_MIN_COVARIANCE_ALLOWED_DEFAULT; useNativeCLibTrainer = EM_USE_NATIVE_C_LIB_TRAINER_DEFAULT; } // Constructor using an existing parameter set public GMMTrainerParams(GMMTrainerParams existing) { totalComponents = existing.totalComponents; isDiagonalCovariance = existing.isDiagonalCovariance; kmeansMaxIterations = existing.kmeansMaxIterations; kmeansMinClusterChangePercent = existing.kmeansMinClusterChangePercent; kmeansMinSamplesInOneCluster = existing.kmeansMinSamplesInOneCluster; emMinIterations = existing.emMinIterations; emMaxIterations = existing.emMaxIterations; isUpdateCovariances = existing.isUpdateCovariances; tinyLogLikelihoodChangePercent = existing.tinyLogLikelihoodChangePercent; minCovarianceAllowed = existing.minCovarianceAllowed; useNativeCLibTrainer = existing.useNativeCLibTrainer; } // Constructor that reads GMM training parameters from a binary file stream public GMMTrainerParams(MaryRandomAccessFile stream) { read(stream); } // Function to write GMM training parameters to a binary file stream public void write(MaryRandomAccessFile stream) { if (stream != null) { try { stream.writeInt(totalComponents); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeBoolean(isDiagonalCovariance); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeInt(kmeansMaxIterations); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeDouble(kmeansMinClusterChangePercent); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeInt(kmeansMinSamplesInOneCluster); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeInt(emMinIterations); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeInt(emMaxIterations); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeBoolean(isUpdateCovariances); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeDouble(tinyLogLikelihoodChangePercent); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeDouble(minCovarianceAllowed); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { stream.writeBoolean(useNativeCLibTrainer); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } // Function that reads GMM training parameters from a binary file stream public void read(MaryRandomAccessFile stream) { if (stream != null) { try { totalComponents = stream.readInt(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { isDiagonalCovariance = stream.readBoolean(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { kmeansMaxIterations = stream.readInt(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { kmeansMinClusterChangePercent = stream.readDouble(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { kmeansMinSamplesInOneCluster = stream.readInt(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { emMinIterations = stream.readInt(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { emMaxIterations = stream.readInt(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { isUpdateCovariances = stream.readBoolean(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { tinyLogLikelihoodChangePercent = stream.readDouble(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { minCovarianceAllowed = stream.readDouble(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { useNativeCLibTrainer = stream.readBoolean(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } }