/**
 * Copyright 2000-2009 DFKI GmbH.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * This file is part of MARY TTS.
 *
 * MARY TTS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
package marytts.machinelearning;

import java.util.ArrayList;
import java.util.List;

/**
 * Discretizes values according to a Gaussian mixture model (GMM). The result of discretization is the mean (truncated to an
 * {@code int}) of the component that contributed the most probability to a point.
 *
 * @author benjaminroth
 *
 */
public class GmmDiscretizer implements Discretizer {

    /** Lower bound on EM iterations used when training a model in {@link #trainDiscretizer}. */
    private static final int EM_MIN_ITERATIONS = 1000;

    /** Upper bound on EM iterations used when training a model in {@link #trainDiscretizer}. */
    private static final int EM_MAX_ITERATIONS = 2000;

    /** The mixture model whose component means are the discretization targets. */
    private final GMM mixtureModel;

    /** If true, the value 0 bypasses the mixture model and is always mapped to 0. */
    private final boolean extraZero;

    /**
     * Trains a Gaussian mixture model on the given values and wraps it in a discretizer.
     * <p>
     * When {@code extraZero} is set, all zeroes are dropped from the training data and the mixture is trained with
     * {@code nrClasses - 1} components, because one class is reserved for the value 0 instead of being trained.
     *
     * @param values
     *            the data the model is trained with; the list itself is not modified
     * @param nrClasses
     *            number of classes the discretizer will have (including the dedicated zero class, if any)
     * @param extraZero
     *            specifies if zeroes are to be treated separately from mixture model training and application
     * @return a discretizer that discretizes according to the trained model
     */
    public static GmmDiscretizer trainDiscretizer(List<Integer> values, int nrClasses, boolean extraZero) {
        Integer zero = Integer.valueOf(0);

        // Single pass over the input: keep everything, except zeroes when they get their own class.
        List<Integer> retained = new ArrayList<Integer>(values.size());
        for (Integer value : values) {
            if (extraZero && zero.equals(value)) {
                continue;
            }
            retained.add(value);
        }

        // The trainer expects one row per observation; the data here is one-dimensional.
        double[][] trainingData = new double[retained.size()][1];
        for (int i = 0; i < retained.size(); i++) {
            trainingData[i][0] = retained.get(i).doubleValue();
        }

        // With a dedicated zero class, one class is assigned rather than trained.
        int trainClasses = extraZero ? nrClasses - 1 : nrClasses;

        GMMTrainerParams gmmParams = new GMMTrainerParams();
        gmmParams.totalComponents = trainClasses;
        gmmParams.emMinIterations = EM_MIN_ITERATIONS;
        gmmParams.emMaxIterations = EM_MAX_ITERATIONS;

        GMM model = new GMMTrainer().train(trainingData, gmmParams);

        return new GmmDiscretizer(model, extraZero);
    }

    /**
     * Constructs a {@link Discretizer} using the specified mixture model.
     * <p>
     * As a debugging aid, the mean, weight and variance of the first dimension of every component are printed to standard
     * output.
     *
     * @param model
     *            GMM to be used
     * @param extraZeroClass
     *            specifies if zeros should be treated independently
     */
    public GmmDiscretizer(GMM model, boolean extraZeroClass) {
        // TODO: debugging output, should go to a logger or be removed
        System.out.println("set model with the following components:");
        for (int i = 0; i < model.totalComponents; i++) {
            System.out.println("component " + i);
            System.out.println(" mean: " + model.components[i].meanVector[0]);
            System.out.println(" weight: " + model.weights[i]);
            System.out.println(" variance: " + model.components[i].covMatrix[0][0]);
        }

        this.mixtureModel = model;
        this.extraZero = extraZeroClass;
    }

    /**
     * Discretizes a value by returning the mean of the Gaussian component that has maximum probability for it.
     * <p>
     * If zeroes are treated separately, 0 is mapped to 0 without consulting the model. Otherwise the first component with
     * the strictly largest probability wins; if no component has a probability greater than zero, component 0 is used. The
     * mean is truncated to an {@code int}.
     *
     * @param value
     *            the value to be discretized
     * @return the discretization the value is mapped to
     *
     */
    public int discretize(int value) {
        // Zero has its own class: no need to evaluate the mixture at all.
        if (this.extraZero && value == 0) {
            return 0;
        }

        double[] x = new double[] { (double) value };
        double[] probs = this.mixtureModel.componentProbabilities(x);

        int maxClass = 0;
        double maxP = 0.0;
        for (int i = 0; i < mixtureModel.totalComponents; i++) {
            if (probs[i] > maxP) {
                maxClass = i;
                maxP = probs[i];
            }
        }

        return (int) mixtureModel.components[maxClass].meanVector[0];
    }

    /**
     * Returns all possible discretizations values can be mapped to.
     * <p>
     * These are the component means truncated to {@code int}, in component order. If zeroes are treated separately, the
     * array additionally starts with 0.
     *
     * @return all possible discretizations values can be mapped to.
     */
    public int[] getPossibleValues() {
        // Reserve the leading slot for the dedicated zero class, if there is one.
        int offset = this.extraZero ? 1 : 0;
        int nrComponents = mixtureModel.components.length;

        // New int arrays are zero-filled, so the leading slot already holds 0.
        int[] possibleValues = new int[nrComponents + offset];
        for (int i = 0; i < nrComponents; i++) {
            possibleValues[i + offset] = (int) mixtureModel.components[i].meanVector[0];
        }
        return possibleValues;
    }

}