/**
* Copyright 2007 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.machinelearning;
import java.io.IOException;
import java.util.Iterator;
import java.util.Set;
import marytts.modules.phonemiser.Allophone;
import marytts.modules.phonemiser.AllophoneSet;
import marytts.util.io.MaryRandomAccessFile;
import marytts.util.math.MathUtils;
import marytts.util.signal.SignalProcUtils;
import marytts.util.string.StringUtils;
/**
*
* Wrapper for contextual parameters for GMM training - includes various phone identity or class based groups
*
* @author Oytun Türk
*/
public class ContextualGMMParams {
// Phonological feature flags. Each constant occupies its own bit, so several
// features can be combined in a single int (getPhonologyClasses adds up the
// flags that apply to an allophone).
public static final int FRICATIVE = Integer.parseInt("000000000001", 2);
public static final int GLIDE = Integer.parseInt("000000000010", 2);
public static final int LIQUID = Integer.parseInt("000000000100", 2);
public static final int NASAL = Integer.parseInt("000000001000", 2);
public static final int PAUSE = Integer.parseInt("000000010000", 2);
public static final int PLOSIVE = Integer.parseInt("000000100000", 2);
public static final int SONORANT = Integer.parseInt("000001000000", 2);
public static final int SYLLABIC = Integer.parseInt("000010000000", 2);
public static final int VOICED = Integer.parseInt("000100000000", 2);
public static final int VOWEL = Integer.parseInt("001000000000", 2);
// Grouping scheme consulted by setClasses(Allophone[], GMMTrainerParams[]);
// expected to hold one of the classification type constants below.
public int contextClassificationType;
// Classification types:
// NO_PHONEME_CLASS - all phones share a single class
// SILENCE_SPEECH - pauses versus everything else
// VOWEL_SILENCE_CONSONANT - vowels, pauses, and everything else
// PHONOLOGY_CLASS - one class per distinct combination of feature flags
// FRICATIVE_GLIDELIQUID_NASAL_PLOSIVE_VOWEL_OTHER - the six groups named in the constant
// PHONEME_IDENTITY - one class per distinct phone name
public static final int NO_PHONEME_CLASS = -1;
public static final int SILENCE_SPEECH = 1;
public static final int VOWEL_SILENCE_CONSONANT = 2;
public static final int PHONOLOGY_CLASS = 3;
public static final int FRICATIVE_GLIDELIQUID_NASAL_PLOSIVE_VOWEL_OTHER = 4;
public static final int PHONEME_IDENTITY = 5;
// Factors by which setClasses(Allophone[], GMMTrainerParams[]) multiplies the
// totalComponents of the trainer parameters of the corresponding group.
public static final int FRICATIVE_MULTIPLIER = 1;
public static final int GLIDELIQUID_MULTIPLIER = 1;
public static final int NASAL_MULTIPLIER = 1;
public static final int PLOSIVE_MULTIPLIER = 1;
public static final int VOWEL_MULTIPLIER = 8;
public static final int OTHER_MULTIPLIER = 1;
public static final int CONSONANT_MULTIPLIER = 4;
public static final int SILENCE_MULTIPLIER = 1;
public static final int SPEECH_MULTIPLIER = 8;
public String[][] phoneClasses; // Each row corresponds to a String array of phones that are grouped in the same class
public GMMTrainerParams[] classTrainerParams; // Training parameters for each context class
/**
 * Creates an instance without any phone classes: delegates to the
 * (AllophoneSet, GMMTrainerParams) constructor with both arguments null.
 */
public ContextualGMMParams() {
this(null, null);
}
/**
 * Creates an instance using the NO_PHONEME_CLASS classification type.
 *
 * @param allophoneSet
 *            source of the phones; may be null
 * @param commonParams
 *            trainer parameters passed on to the three-argument constructor
 */
public ContextualGMMParams(AllophoneSet allophoneSet, GMMTrainerParams commonParams) {
this(allophoneSet, commonParams, NO_PHONEME_CLASS);
}
/**
 * Builds the phone classes for all allophones of a set, using one common set of
 * trainer parameters as the template for every class.
 *
 * @param allophoneSet
 *            source of the allophones; when null no classes are allocated
 * @param commonParams
 *            trainer parameters shared by all classes
 * @param contextClassificationTypeIn
 *            classification type constant stored in contextClassificationType
 */
public ContextualGMMParams(AllophoneSet allophoneSet, GMMTrainerParams commonParams, int contextClassificationTypeIn) {
    // To do: Use contextClassificationType to actually create classes here
    contextClassificationType = contextClassificationTypeIn;
    if (allophoneSet == null) {
        allocate(0);
        return;
    }
    Allophone[] allophones = getAllophones(allophoneSet);
    allocate(allophones.length);
    setClasses(allophones, commonParams);
}
/**
 * Collects every allophone of the given set into an array, in the iteration
 * order of the set's allophone names.
 *
 * @param allophoneSet
 *            set to read; must not be null
 * @return one Allophone per allophone name of the set
 */
public static Allophone[] getAllophones(AllophoneSet allophoneSet) {
    Set<String> names = allophoneSet.getAllophoneNames();
    Allophone[] allophones = new Allophone[names.size()];
    int next = 0;
    for (String name : names) {
        if (next >= allophones.length)
            break;
        allophones[next++] = allophoneSet.getAllophone(name);
    }
    return allophones;
}
/**
 * Builds the phone classes for all allophones of a set, with trainer parameters
 * supplied per class.
 *
 * @param allophoneSet
 *            source of the allophones; when null no classes are allocated
 * @param params
 *            trainer parameters handed to setClasses
 * @param contextClassificationTypeIn
 *            classification type constant stored in contextClassificationType
 */
public ContextualGMMParams(AllophoneSet allophoneSet, GMMTrainerParams[] params, int contextClassificationTypeIn) {
    // To do: Use contextClassificationType to actually create classes here
    contextClassificationType = contextClassificationTypeIn;
    if (allophoneSet == null) {
        allocate(0);
        return;
    }
    Allophone[] allophones = getAllophones(allophoneSet);
    allocate(allophones.length);
    setClasses(allophones, params);
}
/**
 * Creates an instance with room for the given number of classes by calling
 * allocate. contextClassificationType is not assigned here and therefore keeps
 * the Java default value 0.
 *
 * @param numPhonemeClasses
 *            number of classes to allocate
 */
public ContextualGMMParams(int numPhonemeClasses) {
allocate(numPhonemeClasses);
}
public ContextualGMMParams(ContextualGMMParams existing) {
if (existing != null) {
if (existing.phoneClasses != null && existing.classTrainerParams != null
&& existing.phoneClasses.length == existing.classTrainerParams.length) {
allocate(existing.phoneClasses.length);
setClasses(existing.phoneClasses, existing.classTrainerParams);
} else
allocate(0);
} else {
allocate(0);
}
}
/**
 * (Re)allocates storage for the given number of classes. Every class starts
 * without phones and with default-constructed trainer parameters. For a
 * non-positive count both arrays are set to null.
 *
 * @param numPhonemeClasses
 *            number of classes to allocate
 */
public void allocate(int numPhonemeClasses) {
    if (numPhonemeClasses <= 0) {
        phoneClasses = null;
        classTrainerParams = null;
        return;
    }
    phoneClasses = new String[numPhonemeClasses][];
    classTrainerParams = new GMMTrainerParams[numPhonemeClasses];
    int slot = 0;
    while (slot < numPhonemeClasses) {
        classTrainerParams[slot] = new GMMTrainerParams();
        slot++;
    }
}
/**
 * Makes the given class consist of exactly one phone, passing null as the
 * trainer parameters to the three-argument overload.
 *
 * @param classIndex
 *            zero based index of the class to set
 * @param phone
 *            the only phone of the class
 */
public void setClassFromSinglePhoneme(int classIndex, String phone) {
setClassFromSinglePhoneme(classIndex, phone, null);
}
/**
 * Makes the given class consist of exactly one phone.
 *
 * @param classIndex
 *            zero based index of the class to set
 * @param phone
 *            the only phone of the class
 * @param currentClassTrainerParams
 *            trainer parameters forwarded to setClass
 */
public void setClassFromSinglePhoneme(int classIndex, String phone, GMMTrainerParams currentClassTrainerParams) {
    setClass(classIndex, new String[] { phone }, currentClassTrainerParams);
}
/**
 * Sets class i from row i of the given array, passing null as the trainer
 * parameters of every class. Does nothing when the array is null.
 *
 * @param phoneClassesIn
 *            phones of each class, one row per class
 */
public void setClasses(String[][] phoneClassesIn) {
    if (phoneClassesIn == null)
        return;
    int classIndex = 0;
    for (String[] phones : phoneClassesIn) {
        setClass(classIndex, phones, null);
        classIndex++;
    }
}
/**
 * Sets class i from row i of the phone array and entry i of the parameter array,
 * for as many classes as both arrays provide. Does nothing when either array is
 * null.
 *
 * @param phoneClassesIn
 *            phones of each class, one row per class
 * @param classTrainerParamsIn
 *            trainer parameters of each class
 */
public void setClasses(String[][] phoneClassesIn, GMMTrainerParams[] classTrainerParamsIn) {
    if (phoneClassesIn == null || classTrainerParamsIn == null)
        return;
    int numClasses = Math.min(phoneClassesIn.length, classTrainerParamsIn.length);
    for (int classIndex = 0; classIndex < numClasses; classIndex++)
        setClass(classIndex, phoneClassesIn[classIndex], classTrainerParamsIn[classIndex]);
}
/**
 * Groups the allophones according to contextClassificationType, using a copy of
 * one common set of trainer parameters as the only parameter template.
 *
 * @param phns
 *            allophones to group
 * @param commonParams
 *            trainer parameters to copy
 */
public void setClasses(Allophone[] phns, GMMTrainerParams commonParams) {
    GMMTrainerParams[] singleTemplate = { new GMMTrainerParams(commonParams) };
    setClasses(phns, singleTemplate);
}
/**
 * Groups the given allophones into phone classes according to
 * contextClassificationType and creates the trainer parameters of every class.
 * <p>
 * Trainer parameters of group i are a copy of params[i] when that entry exists
 * and of params[0] otherwise; for the group based types the copy's
 * totalComponents is multiplied by the group's *_MULTIPLIER constant. Groups
 * that contain no phone are left out, so the resulting arrays may be shorter
 * than the number of groups of the chosen type. When params is null or empty,
 * the phone classes are still built and classTrainerParams is set to null.
 * <p>
 * For an unknown classification type both arrays are set to null. A null phns
 * leaves the object untouched.
 *
 * @param phns
 *            allophones to group; elements must not be null
 * @param params
 *            trainer parameter templates, indexed by group; may be null
 */
public void setClasses(Allophone[] phns, GMMTrainerParams[] params) {
    if (phns == null)
        return;
    boolean hasParams = params != null && params.length > 0;

    if (contextClassificationType == NO_PHONEME_CLASS) {
        // All phones go to the same class, this is identical to non-contextual GMM training
        phoneClasses = new String[1][phns.length];
        for (int i = 0; i < phns.length; i++)
            phoneClasses[0][i] = phns[i].name();
        classTrainerParams = hasParams ? new GMMTrainerParams[] { copyTrainerParams(params, 0) } : null;
    } else if (contextClassificationType == SILENCE_SPEECH) {
        int[] phonologyClasses = getPhonologyClasses(phns);
        int[][] inds = new int[2][];
        inds[0] = findIndices(phonologyClasses, PAUSE); // Silences
        inds[1] = findUnassignedIndices(inds, phonologyClasses.length); // Speech
        setClassesFromGroups(phns, params, inds, new int[] { SILENCE_MULTIPLIER, SPEECH_MULTIPLIER });
    } else if (contextClassificationType == VOWEL_SILENCE_CONSONANT) {
        int[] phonologyClasses = getPhonologyClasses(phns);
        int[][] inds = new int[3][];
        inds[0] = findIndices(phonologyClasses, VOWEL); // Vowels
        inds[1] = findIndices(phonologyClasses, PAUSE); // Silences
        inds[2] = findUnassignedIndices(inds, phonologyClasses.length); // Consonants
        setClassesFromGroups(phns, params, inds,
                new int[] { VOWEL_MULTIPLIER, SILENCE_MULTIPLIER, CONSONANT_MULTIPLIER });
    } else if (contextClassificationType == PHONOLOGY_CLASS) {
        // Each phonology class goes into a separate class, however this cannot handle phone
        // replications since labels do not have phonology information that could be used in
        // transformation phase
        int[] phonologyClasses = getPhonologyClasses(phns);
        int[] differentPhonologyClasses = StringUtils.getDifferentItemsList(phonologyClasses);
        phoneClasses = new String[differentPhonologyClasses.length][];
        classTrainerParams = hasParams ? new GMMTrainerParams[differentPhonologyClasses.length] : null;
        for (int i = 0; i < differentPhonologyClasses.length; i++) {
            int[] indices = MathUtils.find(phonologyClasses, MathUtils.EQUALS, differentPhonologyClasses[i]);
            phoneClasses[i] = new String[indices.length];
            for (int j = 0; j < indices.length; j++)
                phoneClasses[i][j] = phns[indices[j]].name();
            if (hasParams)
                classTrainerParams[i] = copyTrainerParams(params, i);
        }
    } else if (contextClassificationType == FRICATIVE_GLIDELIQUID_NASAL_PLOSIVE_VOWEL_OTHER) {
        int[] phonologyClasses = getPhonologyClasses(phns);
        int[][] inds = new int[6][];
        inds[0] = findIndices(phonologyClasses, FRICATIVE); // Fricatives
        // Glides or liquids: sorted union of both index lists without duplicates
        boolean[] glideOrLiquid = new boolean[phonologyClasses.length];
        markIndices(glideOrLiquid, findIndices(phonologyClasses, GLIDE));
        markIndices(glideOrLiquid, findIndices(phonologyClasses, LIQUID));
        inds[1] = collectIndices(glideOrLiquid, true);
        inds[2] = findIndices(phonologyClasses, NASAL); // Nasals
        inds[3] = findIndices(phonologyClasses, PLOSIVE); // Plosives
        inds[4] = findIndices(phonologyClasses, VOWEL); // Vowels
        inds[5] = findUnassignedIndices(inds, phonologyClasses.length); // Other
        setClassesFromGroups(phns, params, inds, new int[] { FRICATIVE_MULTIPLIER, GLIDELIQUID_MULTIPLIER,
                NASAL_MULTIPLIER, PLOSIVE_MULTIPLIER, VOWEL_MULTIPLIER, OTHER_MULTIPLIER });
    } else if (contextClassificationType == PHONEME_IDENTITY) {
        // Each phone goes into a separate class, phone replications are taken care of
        String[] allPhonemes = new String[phns.length];
        for (int i = 0; i < phns.length; i++)
            allPhonemes[i] = phns[i].name();
        String[] differentPhonemes = StringUtils.getDifferentItemsList(allPhonemes);
        phoneClasses = new String[differentPhonemes.length][1];
        classTrainerParams = hasParams ? new GMMTrainerParams[differentPhonemes.length] : null;
        for (int i = 0; i < differentPhonemes.length; i++) {
            phoneClasses[i][0] = differentPhonemes[i];
            if (hasParams)
                classTrainerParams[i] = copyTrainerParams(params, i);
        }
    } else {
        phoneClasses = null;
        classTrainerParams = null;
    }
}

/**
 * Fills phoneClasses and classTrainerParams from index groups, skipping groups that are null.
 * The trainer parameter template and the multiplier are selected by the group's position in
 * inds, not by its position in the (possibly shorter) result.
 */
private void setClassesFromGroups(Allophone[] phns, GMMTrainerParams[] params, int[][] inds, int[] multipliers) {
    int total = 0;
    for (int[] group : inds) {
        if (group != null)
            total++;
    }
    boolean hasParams = params != null && params.length > 0;
    phoneClasses = new String[total][];
    classTrainerParams = hasParams ? new GMMTrainerParams[total] : null;
    int count = 0;
    // Visit every group: stopping after 'total' groups would miss non-empty groups that
    // follow an empty one.
    for (int i = 0; i < inds.length; i++) {
        if (inds[i] == null)
            continue;
        phoneClasses[count] = new String[inds[i].length];
        for (int j = 0; j < inds[i].length; j++)
            phoneClasses[count][j] = phns[inds[i][j]].name();
        if (hasParams) {
            classTrainerParams[count] = copyTrainerParams(params, i);
            classTrainerParams[count].totalComponents *= multipliers[i];
        }
        count++;
    }
}

/** Returns a copy of params[groupIndex], or of params[0] when there is no such entry. */
private static GMMTrainerParams copyTrainerParams(GMMTrainerParams[] params, int groupIndex) {
    return new GMMTrainerParams(params[groupIndex < params.length ? groupIndex : 0]);
}

/**
 * Returns, in ascending order, the indices in [0, numItems) that occur in none of the groups
 * except the last one (which is the slot the result is meant for), or null if there are none.
 */
private static int[] findUnassignedIndices(int[][] inds, int numItems) {
    boolean[] assigned = new boolean[numItems];
    for (int g = 0; g < inds.length - 1; g++)
        markIndices(assigned, inds[g]);
    return collectIndices(assigned, false);
}

/** Sets mask[k] to true for every k in indices; a null index array marks nothing. */
private static void markIndices(boolean[] mask, int[] indices) {
    if (indices == null)
        return;
    for (int index : indices)
        mask[index] = true;
}

/** Returns the ascending positions at which mask equals wanted, or null if there are none. */
private static int[] collectIndices(boolean[] mask, boolean wanted) {
    int total = 0;
    for (boolean flag : mask) {
        if (flag == wanted)
            total++;
    }
    if (total == 0)
        return null;
    int[] result = new int[total];
    int count = 0;
    for (int i = 0; i < mask.length; i++) {
        if (mask[i] == wanted)
            result[count++] = i;
    }
    return result;
}
/**
 * Computes the combined phonological feature flags of every allophone.
 *
 * @param phns
 *            allophones to inspect; may be null
 * @return for each allophone the combination of all feature flag constants
 *         (FRICATIVE ... VOWEL) that apply to it, or null if phns is null
 */
public static int[] getPhonologyClasses(Allophone[] phns) {
    if (phns == null)
        return null;
    int[] flagsPerPhone = new int[phns.length];
    int slot = 0;
    for (Allophone phn : phns) {
        // Every constant is a distinct single bit, so OR-ing equals adding.
        int flags = 0;
        if (phn.isFricative())
            flags |= FRICATIVE;
        if (phn.isGlide())
            flags |= GLIDE;
        if (phn.isLiquid())
            flags |= LIQUID;
        if (phn.isNasal())
            flags |= NASAL;
        if (phn.isPause())
            flags |= PAUSE;
        if (phn.isPlosive())
            flags |= PLOSIVE;
        if (phn.isSonorant())
            flags |= SONORANT;
        if (phn.isSyllabic())
            flags |= SYLLABIC;
        if (phn.isVoiced())
            flags |= VOICED;
        if (phn.isVowel())
            flags |= VOWEL;
        flagsPerPhone[slot++] = flags;
    }
    return flagsPerPhone;
}
/**
 * Finds the positions whose phonology flags are accepted by
 * StringUtils.isDesired for the given desired classes.
 *
 * @param phonologyClasses
 *            feature flags per phone; must not be null
 * @param desiredClasses
 *            flag value handed to StringUtils.isDesired
 * @return the accepted positions in ascending order, or null if there is none
 */
public static int[] findIndices(int[] phonologyClasses, int desiredClasses) {
    int[] candidates = new int[phonologyClasses.length];
    int numFound = 0;
    for (int pos = 0; pos < phonologyClasses.length; pos++) {
        if (StringUtils.isDesired(phonologyClasses[pos], desiredClasses))
            candidates[numFound++] = pos;
    }
    if (numFound == 0)
        return null;
    int[] found = new int[numFound];
    System.arraycopy(candidates, 0, found, 0, numFound);
    return found;
}
/**
 * Sets the phones of one class, passing null as the trainer parameters to the
 * three-argument overload.
 *
 * @param classIndex
 *            zero based index of the class to set
 * @param phones
 *            phones of the class
 */
public void setClass(int classIndex, String[] phones) {
setClass(classIndex, phones, null);
}
/**
 * Replaces the phones and the trainer parameters of one class. The call is
 * ignored unless both class arrays exist, have equal length and classIndex is
 * a valid index into them.
 *
 * @param classIndex
 *            zero based index of the class to set
 * @param phones
 *            phones of the class (copied); null clears the class's phone list
 * @param currentClassTrainerParams
 *            trainer parameters handed to the GMMTrainerParams copy constructor
 */
public void setClass(int classIndex, String[] phones, GMMTrainerParams currentClassTrainerParams) {
    if (phoneClasses == null || classTrainerParams == null)
        return;
    if (classIndex < 0 || classIndex >= phoneClasses.length)
        return;
    if (phoneClasses.length != classTrainerParams.length)
        return;
    phoneClasses[classIndex] = (phones == null) ? null : phones.clone();
    // NOTE(review): overloads in this class pass null here; confirm that the
    // GMMTrainerParams copy constructor accepts a null argument.
    classTrainerParams[classIndex] = new GMMTrainerParams(currentClassTrainerParams);
}
/**
 * Returns the zero based index of the first class the phone belongs to.
 *
 * @param phone
 *            phone name to look up; null is never a member of any class
 * @return index of the first class containing the phone, or -1 if it is not an
 *         element of any of the existing classes
 */
public int getClassIndex(String phone) {
    if (phone == null || phoneClasses == null)
        return -1;
    for (int i = 0; i < phoneClasses.length; i++) {
        String[] members = phoneClasses[i];
        if (members == null)
            continue;
        for (String member : members) {
            // equals() tolerates null entries, unlike compareTo()
            if (phone.equals(member))
                return i;
        }
    }
    return -1;
}
/**
 * Writes the phone classes and the trainer parameters to the stream in the
 * layout that read(MaryRandomAccessFile) consumes:
 *
 * <pre>
 * int   number of classes (0 ends the record)
 * per class:
 *   int   number of phones
 *   per phone: int length, followed by that many chars when length &gt; 0
 * int   number of trainer parameter sets, followed by each set
 * </pre>
 *
 * Null class rows and null or empty phone names are written with length 0. On an
 * IOException the stack trace is printed and writing stops, because everything
 * written after a failed write would be misaligned. Nothing is written when the
 * stream is null.
 *
 * @param stream
 *            destination; may be null
 */
public void write(MaryRandomAccessFile stream) {
    if (stream == null)
        return;
    try {
        // read() stops after a class count of 0, so nothing may follow it.
        if (phoneClasses == null || phoneClasses.length == 0) {
            stream.writeInt(0);
            return;
        }
        stream.writeInt(phoneClasses.length);
        for (String[] phoneClass : phoneClasses) {
            if (phoneClass == null) {
                stream.writeInt(0);
                continue;
            }
            stream.writeInt(phoneClass.length);
            for (String phone : phoneClass) {
                if (phone == null || phone.length() == 0) {
                    stream.writeInt(0);
                } else {
                    stream.writeInt(phone.length());
                    stream.writeChar(phone.toCharArray());
                }
            }
        }
        if (classTrainerParams == null) {
            stream.writeInt(0);
        } else {
            stream.writeInt(classTrainerParams.length);
            for (int i = 0; i < classTrainerParams.length; i++)
                classTrainerParams[i].write(stream);
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Reads phone classes and trainer parameters from the stream, expecting the
 * layout produced by write(MaryRandomAccessFile): the number of classes; for
 * every class its number of phones and, per phone, a length followed by that
 * many chars; then the number of trainer parameter sets followed by each set.
 * <p>
 * A stored class count of 0 (or less) clears both arrays. A class with no
 * phones is read as an empty array and a phone of length 0 as the empty string.
 * The fields are only replaced once the whole record was read; on an
 * IOException the stack trace is printed and the object keeps its previous
 * contents. Nothing happens when the stream is null.
 *
 * @param stream
 *            source; may be null
 */
public void read(MaryRandomAccessFile stream) {
    if (stream == null)
        return;
    try {
        int numClasses = stream.readInt();
        if (numClasses <= 0) {
            // Do not keep (and re-read into) classes left over from earlier use.
            phoneClasses = null;
            classTrainerParams = null;
            return;
        }
        String[][] classes = new String[numClasses][];
        for (int i = 0; i < numClasses; i++) {
            int numPhones = stream.readInt();
            classes[i] = new String[Math.max(numPhones, 0)];
            for (int j = 0; j < classes[i].length; j++) {
                int phoneLength = stream.readInt();
                classes[i][j] = phoneLength > 0 ? String.copyValueOf(stream.readChar(phoneLength)) : "";
            }
        }
        int numParams = stream.readInt();
        GMMTrainerParams[] trainerParams = null;
        if (numParams > 0) {
            trainerParams = new GMMTrainerParams[numParams];
            for (int i = 0; i < numParams; i++)
                trainerParams[i] = new GMMTrainerParams(stream);
        }
        phoneClasses = classes;
        classTrainerParams = trainerParams;
    } catch (IOException e) {
        e.printStackTrace();
    }
}
}