/**
* Copyright 2007 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.machinelearning;
/**
 * Wrapper class for K-Means clustering training parameters.
 *
 * <p>All parameters are exposed as public mutable fields. Instances can be created with
 * default values, derived from a set of GMM training parameters, or copied from an
 * existing parameter set.
 *
 * @author Oytun Türk
 */
public class KMeansClusteringTrainerParams {
    // Default values for the K-Means training parameters
    public static final int KMEANS_MAX_ITERATIONS_DEFAULT = 200;
    public static final double KMEANS_MIN_CLUSTER_CHANGE_PERCENT_DEFAULT = 0.0001;
    public static final boolean KMEANS_IS_DIAGONAL_COVARIANCE_DEFAULT = true;
    public static final int KMEANS_MIN_SAMPLES_IN_ONE_CLUSTER_DEFAULT = 10;
    private static final double KMEANS_MIN_COVARIANCE_ALLOWED_DEFAULT = 1e-5;

    /** Number of clusters to be trained. */
    public int numClusters;

    /** Maximum iterations to stop K-Means training. */
    public int maxIterations;

    /** Minimum percent change in cluster assignments to stop K-Means iterations. */
    public double minClusterChangePercent;

    /** Whether diagonal cluster covariances are estimated at the end of training. */
    public boolean isDiagonalOutputCovariance;

    /** Minimum number of observations allowed in one cluster. */
    public int minSamplesInOneCluster;

    /** Minimum covariance value allowed for final cluster covariance matrices. */
    public double minCovarianceAllowed;

    /** Global variance vector of the whole data; {@code null} until set. */
    public double[] globalVariances;

    /**
     * Creates a parameter set with zero clusters, the {@code KMEANS_*_DEFAULT} values for all
     * other numeric and boolean parameters, and no global variances.
     */
    public KMeansClusteringTrainerParams() {
        numClusters = 0;
        maxIterations = KMEANS_MAX_ITERATIONS_DEFAULT;
        minClusterChangePercent = KMEANS_MIN_CLUSTER_CHANGE_PERCENT_DEFAULT;
        isDiagonalOutputCovariance = KMEANS_IS_DIAGONAL_COVARIANCE_DEFAULT;
        minSamplesInOneCluster = KMEANS_MIN_SAMPLES_IN_ONE_CLUSTER_DEFAULT;
        minCovarianceAllowed = KMEANS_MIN_COVARIANCE_ALLOWED_DEFAULT;
        globalVariances = null;
    }

    /**
     * Creates a parameter set from the K-Means related fields of a GMM training parameter set.
     * Global variances are left unset ({@code null}).
     *
     * @param gmmParams GMM training parameters to read from; must not be {@code null}
     */
    public KMeansClusteringTrainerParams(GMMTrainerParams gmmParams) {
        numClusters = gmmParams.totalComponents;
        maxIterations = gmmParams.kmeansMaxIterations;
        minClusterChangePercent = gmmParams.kmeansMinClusterChangePercent;
        isDiagonalOutputCovariance = gmmParams.isDiagonalCovariance;
        minSamplesInOneCluster = gmmParams.kmeansMinSamplesInOneCluster;
        minCovarianceAllowed = gmmParams.minCovarianceAllowed;
        globalVariances = null;
    }

    /**
     * Creates an independent copy of an existing parameter set. The global variance vector is
     * duplicated, so later changes to either instance's array do not affect the other.
     *
     * @param existing parameter set to copy; must not be {@code null}
     */
    public KMeansClusteringTrainerParams(KMeansClusteringTrainerParams existing) {
        numClusters = existing.numClusters;
        maxIterations = existing.maxIterations;
        minClusterChangePercent = existing.minClusterChangePercent;
        isDiagonalOutputCovariance = existing.isDiagonalOutputCovariance;
        minSamplesInOneCluster = existing.minSamplesInOneCluster;
        // Previously omitted: without this line a copy silently ended up with a floor of 0.0.
        minCovarianceAllowed = existing.minCovarianceAllowed;
        // Copy the array directly instead of calling the overridable setter from a constructor.
        globalVariances = (existing.globalVariances == null) ? null : existing.globalVariances.clone();
    }

    /**
     * Copies the given global variance values into this parameter set.
     *
     * <p>The internal array is reallocated only when it is missing or its length differs from
     * the input; otherwise the existing array is overwritten in place. A {@code null} argument
     * is ignored and leaves the current global variances untouched.
     *
     * @param globalVariancesIn values to copy, or {@code null} to leave the current values as is
     */
    public void setGlobalVariances(double[] globalVariancesIn) {
        if (globalVariancesIn == null) {
            return;
        }

        if (globalVariances == null || globalVariancesIn.length != globalVariances.length) {
            globalVariances = new double[globalVariancesIn.length];
        }

        System.arraycopy(globalVariancesIn, 0, globalVariances, 0, globalVariancesIn.length);
    }
}