/**
* Copyright 2007 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.signalproc.adaptation.codebook;
import java.io.IOException;
import javax.sound.sampled.UnsupportedAudioFileException;
import marytts.signalproc.adaptation.AdaptationUtils;
import marytts.signalproc.adaptation.BaselineAdaptationSet;
import marytts.signalproc.adaptation.BaselineFeatureCollection;
import marytts.signalproc.adaptation.BaselineFeatureExtractor;
import marytts.signalproc.adaptation.BaselinePreprocessor;
import marytts.signalproc.adaptation.BaselineTrainer;
import marytts.signalproc.adaptation.IndexMap;
import marytts.signalproc.adaptation.prosody.PitchMappingFile;
import marytts.signalproc.adaptation.prosody.PitchTrainer;
import marytts.util.io.FileUtils;
import marytts.util.string.StringUtils;
/**
*
* Baseline class for weighted codebook training
*
* @author Oytun Türk
*/
public class WeightedCodebookTrainer extends BaselineTrainer {
    /** Training parameters, built from the parameters handed to the constructor. */
    public WeightedCodebookTrainerParams wcParams;

    /** Outlier eliminator that {@link #train} invokes with {@link #wcParams} after the mapping has been learned. */
    public WeightedCodebookOutlierEliminator outlierEliminator;

    /**
     * Creates a trainer.
     *
     * @param pp
     *            preprocessor, passed on to the superclass
     * @param fe
     *            feature extractor, passed on to the superclass
     * @param pa
     *            training parameters; a new {@link WeightedCodebookTrainerParams} is constructed from them
     */
    public WeightedCodebookTrainer(BaselinePreprocessor pp, BaselineFeatureExtractor fe, WeightedCodebookTrainerParams pa) {
        super(pp, fe);
        wcParams = new WeightedCodebookTrainerParams(pa);
        outlierEliminator = new WeightedCodebookOutlierEliminator();
    }

    /**
     * Performs training. Call this after initializing the trainer. Nothing is trained when {@link #checkParams()} reports
     * invalid parameters.
     *
     * @throws IOException
     *             if training fails on an I/O error
     * @throws UnsupportedAudioFileException
     *             if raised by the preprocessing or feature extraction steps
     */
    public void run() throws IOException, UnsupportedAudioFileException {
        if (checkParams()) {
            BaselineAdaptationSet sourceTrainingSet = new BaselineAdaptationSet(wcParams.sourceTrainingFolder);
            BaselineAdaptationSet targetTrainingSet = new BaselineAdaptationSet(wcParams.targetTrainingFolder);

            int[] map = getIndexedMapping(sourceTrainingSet, targetTrainingSet);

            train(sourceTrainingSet, targetTrainingSet, map);
        }
    }

    /**
     * Validates the training parameters. As side effects, the three folder names are passed through
     * {@code StringUtils.checkLastSlash}, creation of the training base folder is attempted, and the temporary codebook
     * file name is set to the codebook file name plus {@code ".temp"}.
     *
     * @return {@code true} if the training base, source training and target training folders all exist and are directories;
     *         {@code false} otherwise (an error line is printed for every missing folder)
     */
    public boolean checkParams() {
        boolean valid = true;

        wcParams.trainingBaseFolder = StringUtils.checkLastSlash(wcParams.trainingBaseFolder);
        wcParams.sourceTrainingFolder = StringUtils.checkLastSlash(wcParams.sourceTrainingFolder);
        wcParams.targetTrainingFolder = StringUtils.checkLastSlash(wcParams.targetTrainingFolder);

        FileUtils.createDirectory(wcParams.trainingBaseFolder);

        // All three folders are checked unconditionally so that every problem is reported in one pass.
        if (!FileUtils.exists(wcParams.trainingBaseFolder) || !FileUtils.isDirectory(wcParams.trainingBaseFolder)) {
            System.out.println("Error! Training base folder " + wcParams.trainingBaseFolder + " not found.");
            valid = false;
        }

        if (!FileUtils.exists(wcParams.sourceTrainingFolder) || !FileUtils.isDirectory(wcParams.sourceTrainingFolder)) {
            System.out.println("Error! Source training folder " + wcParams.sourceTrainingFolder + " not found.");
            valid = false;
        }

        if (!FileUtils.exists(wcParams.targetTrainingFolder) || !FileUtils.isDirectory(wcParams.targetTrainingFolder)) {
            System.out.println("Error! Target training folder " + wcParams.targetTrainingFolder + " not found.");
            valid = false;
        }

        wcParams.temporaryCodebookFile = wcParams.codebookFile + ".temp";

        return valid;
    }

    /**
     * General purpose training with indexed pairs. Does nothing if either item array or {@code map} is {@code null}.
     *
     * @param sourceTrainingSet
     *            source training items
     * @param targetTrainingSet
     *            target training items
     * @param map
     *            vector of the same length as the source items giving, for each source item, the index of the
     *            corresponding target item. This allows the target files to be specified in any order, i.e. file names
     *            are not required to be in alphabetical order.
     * @throws RuntimeException
     *             if the source items, target items and {@code map} do not all have the same length
     * @throws IOException
     *             if feature collection or mapping fails on an I/O error
     * @throws UnsupportedAudioFileException
     *             if raised by the preprocessing or feature extraction steps
     */
    public void train(BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map)
            throws IOException, UnsupportedAudioFileException {
        if (sourceTrainingSet.items != null && targetTrainingSet.items != null && map != null) {
            if (sourceTrainingSet.items.length != targetTrainingSet.items.length
                    || sourceTrainingSet.items.length != map.length) {
                throw new RuntimeException("Lengths of source, target and map must be the same");
            }

            int numItems = sourceTrainingSet.items.length;

            if (numItems > 0) {
                preprocessor.run(sourceTrainingSet);
                preprocessor.run(targetTrainingSet);

                // The vocal tract feature is extracted together with f0 and energy.
                int desiredFeatures = wcParams.codebookHeader.vocalTractFeature + BaselineFeatureExtractor.F0_FEATURES
                        + BaselineFeatureExtractor.ENERGY_FEATURES;

                featureExtractor.run(sourceTrainingSet, wcParams, desiredFeatures);
                featureExtractor.run(targetTrainingSet, wcParams, desiredFeatures);
            }

            WeightedCodebookFeatureCollection fcol = collectFeatures(sourceTrainingSet, targetTrainingSet, map);

            learnMapping(fcol, sourceTrainingSet, targetTrainingSet, map);

            outlierEliminator.run(wcParams);

            deleteTemporaryFiles(fcol, sourceTrainingSet, targetTrainingSet);
        }
    }

    /**
     * Computes the index maps between source and target items and writes them to the index map files of the returned
     * feature collection. For the FRAMES, FRAME_GROUPS, LABELS and LABEL_GROUPS codebook types one index map is written per
     * entry of {@code map}; for the SPEECH type a single index map is written to the first index map file. Any other
     * codebook type writes nothing.
     * <p>
     * For parallel training, the source and target sets should have at least {@code map.length} items (ensured if this
     * function is called through {@link #train}).
     *
     * @param sourceTrainingSet
     *            source training items
     * @param targetTrainingSet
     *            target training items
     * @param map
     *            for each source item index, the index of the corresponding target item
     * @return the feature collection whose index map files have been written
     * @throws IllegalArgumentException
     *             if an index map is needed but the configured vocal tract feature is neither LSF nor MFCC-from-files
     * @throws IOException
     *             if computing or writing an index map fails
     */
    public WeightedCodebookFeatureCollection collectFeatures(BaselineAdaptationSet sourceTrainingSet,
            BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        WeightedCodebookFeatureCollection fcol = new WeightedCodebookFeatureCollection(wcParams, map.length);

        if (usesPerItemIndexMaps()) {
            for (int i = 0; i < map.length; i++) {
                IndexMap imap = mapItemPair(sourceTrainingSet, targetTrainingSet, i, map[i]);
                // Write failures propagate to the caller: continuing without the index map file would only
                // postpone the failure to the mapping stage.
                imap.writeToFile(fcol.indexMapFiles[i]);
            }
        } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.SPEECH) {
            if (!isSupportedVocalTractFeature()) {
                throw new IllegalArgumentException(unsupportedVocalTractFeatureMessage());
            }
            IndexMap imap = AdaptationUtils.mapSpeechFeatures();
            imap.writeToFile(fcol.indexMapFiles[0]);
        }

        return fcol;
    }

    /** Tells whether the configured codebook type needs one index map per source/target item pair. */
    private boolean usesPerItemIndexMaps() {
        return wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES
                || wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS
                || wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS
                || wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS;
    }

    /** Tells whether the configured vocal tract feature is one this trainer can map (LSF or MFCC-from-files). */
    private boolean isSupportedVocalTractFeature() {
        return wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.LSF_FEATURES
                || wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.MFCC_FEATURES_FROM_FILES;
    }

    /** Builds the error message used when the configured vocal tract feature cannot be mapped. */
    private String unsupportedVocalTractFeatureMessage() {
        return "Unsupported vocal tract feature for index mapping: " + wcParams.codebookHeader.vocalTractFeature;
    }

    /**
     * Computes the index map for one source/target item pair, using the mapping function that matches the configured
     * codebook type and the feature files that match the configured vocal tract feature. Must only be called when
     * {@link #usesPerItemIndexMaps()} holds.
     */
    private IndexMap mapItemPair(BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet,
            int sourceIndex, int targetIndex) throws IOException {
        String sourceFeatureFile;
        String targetFeatureFile;
        if (wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.LSF_FEATURES) {
            sourceFeatureFile = sourceTrainingSet.items[sourceIndex].lsfFile;
            targetFeatureFile = targetTrainingSet.items[targetIndex].lsfFile;
        } else if (wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.MFCC_FEATURES_FROM_FILES) {
            sourceFeatureFile = sourceTrainingSet.items[sourceIndex].mfccFile;
            targetFeatureFile = targetTrainingSet.items[targetIndex].mfccFile;
        } else {
            throw new IllegalArgumentException(unsupportedVocalTractFeatureMessage());
        }

        String sourceLabelFile = sourceTrainingSet.items[sourceIndex].labelFile;
        String targetLabelFile = targetTrainingSet.items[targetIndex].labelFile;

        if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES) {
            return AdaptationUtils.mapFramesFeatures(sourceLabelFile, targetLabelFile, sourceFeatureFile,
                    targetFeatureFile, wcParams.codebookHeader.vocalTractFeature, wcParams.labelsToExcludeFromTraining);
        } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS) {
            return AdaptationUtils.mapFrameGroupsFeatures(sourceLabelFile, targetLabelFile, sourceFeatureFile,
                    targetFeatureFile, wcParams.codebookHeader.numNeighboursInFrameGroups,
                    wcParams.codebookHeader.vocalTractFeature, wcParams.labelsToExcludeFromTraining);
        } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS) {
            return AdaptationUtils.mapLabelsFeatures(sourceLabelFile, targetLabelFile, sourceFeatureFile,
                    targetFeatureFile, wcParams.codebookHeader.vocalTractFeature, wcParams.labelsToExcludeFromTraining);
        } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS) {
            return AdaptationUtils.mapLabelGroupsFeatures(sourceLabelFile, targetLabelFile, sourceFeatureFile,
                    targetFeatureFile, wcParams.codebookHeader.numNeighboursInLabelGroups,
                    wcParams.codebookHeader.vocalTractFeature, wcParams.labelsToExcludeFromTraining);
        }

        throw new IllegalStateException("Codebook type does not use per-item index maps: "
                + wcParams.codebookHeader.codebookType);
    }

    /**
     * Variant of {@link #learnMapping(WeightedCodebookFeatureCollection, BaselineAdaptationSet, BaselineAdaptationSet, int[])}
     * that accepts the baseline collection type and delegates after casting.
     *
     * @param fcol
     *            feature collection; must be a {@link WeightedCodebookFeatureCollection}
     * @param sourceTrainingSet
     *            source training items
     * @param targetTrainingSet
     *            target training items
     * @param map
     *            for each source item index, the index of the corresponding target item
     * @throws IOException
     *             if the delegate fails on an I/O error
     */
    public void learnMapping(BaselineFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet,
            BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        assert fcol instanceof WeightedCodebookFeatureCollection;

        learnMapping((WeightedCodebookFeatureCollection) fcol, sourceTrainingSet, targetTrainingSet, map);
    }

    /**
     * Generates the codebooks from training pairs: the feature mapper matching the configured vocal tract feature writes
     * the temporary codebook file, then a {@link PitchTrainer} writes the pitch mapping file. If no feature mapper exists
     * for the configured vocal tract feature, an error line is printed and no file is opened.
     * <p>
     * Both output files are closed even when mapping or pitch training fails.
     *
     * @param fcol
     *            feature collection produced by {@link #collectFeatures}
     * @param sourceTrainingSet
     *            source training items
     * @param targetTrainingSet
     *            target training items
     * @param map
     *            for each source item index, the index of the corresponding target item
     * @throws IOException
     *             if opening, writing or closing one of the output files fails
     */
    public void learnMapping(WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet,
            BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        WeightedCodebookFeatureMapper featureMapper = null;
        if (wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.LSF_FEATURES) {
            featureMapper = new WeightedCodebookLsfMapper(wcParams);
        } else if (wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.MFCC_FEATURES_FROM_FILES) {
            featureMapper = new WeightedCodebookMfccMapper(wcParams);
        }

        if (featureMapper == null) {
            System.out.println("Error! Specified feature mapper does not exist...");
            return;
        }

        WeightedCodebookFile temporaryCodebookFile = new WeightedCodebookFile(wcParams.temporaryCodebookFile,
                WeightedCodebookFile.OPEN_FOR_WRITE);
        PitchMappingFile pitchMappingFile = null;
        boolean codebookFileCloseAttempted = false;

        try {
            pitchMappingFile = new PitchMappingFile(wcParams.pitchMappingFile, PitchMappingFile.OPEN_FOR_WRITE);

            if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES) {
                featureMapper.learnMappingFrames(temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
            } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS) {
                featureMapper.learnMappingFrameGroups(temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet,
                        map);
            } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS) {
                featureMapper.learnMappingLabels(temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
            } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS) {
                featureMapper.learnMappingLabelGroups(temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet,
                        map);
            } else if (wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.SPEECH) {
                featureMapper.learnMappingSpeech(temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
            }

            // As before, the codebook file is closed before pitch training starts. The flag is set first so
            // that a failing close() is not retried in the finally block.
            codebookFileCloseAttempted = true;
            temporaryCodebookFile.close();

            PitchTrainer ptcTrainer = new PitchTrainer(wcParams);
            ptcTrainer.learnMapping(pitchMappingFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        } finally {
            try {
                if (!codebookFileCloseAttempted) {
                    temporaryCodebookFile.close();
                }
            } finally {
                if (pitchMappingFile != null) {
                    pitchMappingFile.close();
                }
            }
        }
    }

    /**
     * Deletes the index map files of the feature collection and the temporary codebook file. The extracted LSF, f0 and
     * energy files of the training sets are not deleted (the corresponding calls are disabled below).
     *
     * @param fcol
     *            feature collection whose index map files are deleted
     * @param sourceTrainingSet
     *            source training items (currently unused)
     * @param targetTrainingSet
     *            target training items (currently unused)
     */
    public void deleteTemporaryFiles(WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet,
            BaselineAdaptationSet targetTrainingSet) {
        FileUtils.delete(fcol.indexMapFiles, true);
        // FileUtils.delete(sourceTrainingSet.getLsfFiles(), true);
        // FileUtils.delete(targetTrainingSet.getLsfFiles(), true);
        // FileUtils.delete(sourceTrainingSet.getF0Files(), true);
        // FileUtils.delete(targetTrainingSet.getF0Files(), true);
        // FileUtils.delete(sourceTrainingSet.getEnergyFiles(), true);
        // FileUtils.delete(targetTrainingSet.getEnergyFiles(), true);

        FileUtils.delete(wcParams.temporaryCodebookFile);
    }
}