/**
* Copyright 2003-2007 DFKI GmbH.
* All Rights Reserved. Use is subject to license terms.
*
* This file is part of MARY TTS.
*
* MARY TTS is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package marytts.fst;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;
/**
* This trains an alignment model between Strings. Applications are for example letter-to-sound rule training (see LTSTrainer) or
* transducer construction/minimization.
* <p>
* The basic idea is to perform a Levenshtein search for the cheapest path and read off an alignment from that. The costs used in
* the distance computation are not uniform but estimated in an iterative process, according to -log of the relative frequencies
* of the respective operations in the previous iteration. Perform several iterations (e.g. 4) of aligning in order to get stable
* estimates of the costs (and a good alignment in turn).
* <p>
* The algorithm, in its essence, is implemented after a description of Levenshtein distance as it can be found in Wikipedia (see
* below); consider the costs used in the pseudo-code:
*
* <pre>
* d[i, j] := minimum
* (
* d[i-1, j] + 1, // deletion
* d[i, j-1] + 1, // insertion
* d[i-1, j-1] + 1 // substitution
* )
* </pre>
*
* In our implementation there are only two operations, corresponding to deletion and insertion. So, if you look at the matrices
* in the wiki article, you can only go down and to the right, but not diagonal. Second, the costs are not 1 but set as explained
* in the following (note that this is a heuristic that seems to work fine but <em>not</em> a derived EM-algorithm).
* <p>
* "insertion" means in our case to insert something for (dependent on) the current input symbol. The cost for this operation is
* lower if the two symbols were already aligned in the preceding iteration; it is set to -log
* P(output-symbol|"insertion",input-symbol).
* <p>
* "deletion" means in our case to go to the next input symbol. If a deletion operation is performed without a preceding
* insertion operation (i.e. two subsequent deletion operations), this is called a "skip" and will produce costs; going to the next
* symbol after an insertion is free (this is to avoid unaligned input symbols). The skip costs are estimated from the preceding
* iteration and set to -log P(skip|"deletion").
* <p>
* In addition, I made the following optimization, described in Wikipedia: <blockquote>We can adapt the algorithm to use less
* space, O(m) instead of O(mn), since it only requires that the previous row and current row be stored at any one time.
* </blockquote> therefore the three arrays for all information and the swapping statements in the align method. (note that what
* are rows in Wikipedia are columns here)
*
* @see <a
* href="http://en.wikipedia.org/w/index.php?title=Levenshtein_distance&oldid=349201802#Computing_Levenshtein_distance">Computing
* Levenshtein distance</a>
* @author benjaminroth
*
*/
public class AlignerTrainer {
    /** Cost of aligning a symbol pair for which no cheaper cost has been estimated. */
    private static final int DEFAULT_COST = 10;

    /** Natural logarithm of 2, used to convert natural logarithms into base-2 logarithms. */
    private static final double LOG_OF_2 = Math.log(2.0);

    // cost of translating first element of the pair into the second
    private final Map<StringPair, Integer> aligncost;

    // cost of leaving an input symbol unaligned ("skip")
    private int skipcost;

    // optional info, eg. part-of-speech; null if the trainer was created without optional info
    protected List<String> optInfo;

    // input side (eg. graphemes) of string pairs, split into symbols
    protected List<String[]> inSplit;

    // output side (eg. phones) of string pairs, split into symbols
    protected List<String[]> outSplit;

    protected Set<String> graphemeSet;

    protected Logger logger;

    private final boolean inIsOut;

    /**
     * Creates a trainer with default costs and no entries.
     *
     * @param inIsOutAlphabet
     *            boolean indicating whether input and output strings should be considered as belonging to the same symbol
     *            set (alignment between identical symbols is then cost-free as long as no cost has been estimated for them)
     * @param hasOptInfo
     *            whether optional info (eg. part of speech) is stored for the entries
     */
    public AlignerTrainer(boolean inIsOutAlphabet, boolean hasOptInfo) {
        this.skipcost = DEFAULT_COST;
        this.aligncost = new HashMap<StringPair, Integer>();
        this.inSplit = new ArrayList<String[]>();
        this.outSplit = new ArrayList<String[]>();
        this.graphemeSet = new HashSet<String>();
        // allow "null" as a dummy value for all phone features
        this.graphemeSet.add("null");
        this.inIsOut = inIsOutAlphabet;
        if (hasOptInfo) {
            this.optInfo = new ArrayList<String>();
        }
        this.logger = MaryUtils.getLogger(this.getClass());
    }

    /**
     * New AlignerTrainer for pairs of different symbol sets with no optional info.
     */
    public AlignerTrainer() {
        this(false, false);
    }

    /**
     * Reads a lexicon where input and output strings are separated by a delimiter that can be specified (splitSym). Strings
     * are taken as they are, no normalization (eg. stress/syllable symbol removal, lower-casing ...) is performed; if space
     * characters are present in the output string, they are used as separators. In a third column additional info (eg. part of
     * speech) can be given; it is only stored if this trainer was created with optional info. Strings are stored split into
     * symbols.
     * <p>
     * Lines that are empty after trimming are skipped.
     *
     * @param lexicon
     *            reader for lexicon
     * @param splitSym
     *            regular expression that separates the columns of the lexicon (passed to {@link String#split(String)})
     * @throws IOException
     *             if reading fails, or if a non-empty line has fewer than two columns
     */
    public void readLexicon(BufferedReader lexicon, String splitSym) throws IOException {
        String line;
        int lineNr = 0;
        while ((line = lexicon.readLine()) != null) {
            lineNr++;
            String trimmed = line.trim();
            if (trimmed.length() == 0) {
                // tolerate blank lines (eg. a trailing newline at the end of the file)
                continue;
            }
            String[] lineParts = trimmed.split(splitSym);
            if (lineParts.length < 2) {
                throw new IOException("Malformed lexicon line " + lineNr + " (expected at least two columns separated by '"
                        + splitSym + "'): " + line);
            }
            this.splitAndAdd(lineParts[0], lineParts[1]);
            if (this.optInfo != null) {
                this.optInfo.add(lineParts.length > 2 ? lineParts[2] : null);
            }
        }
    }

    /**
     * This adds the input and output string in the most simple way: symbols are simply the characters of the strings - no
     * phonemisation/syllabification or whatsoever is performed. If outStr contains space characters, they are used as
     * separators for splitting; every output symbol except the first then keeps a leading space.
     * <p>
     * The characters of inStr are also added to the set of input symbols. No optional info is recorded by this method.
     *
     * @param inStr
     *            input string, split into single characters
     * @param outStr
     *            output string, split at spaces if it contains any, otherwise into single characters
     */
    public void splitAndAdd(String inStr, String outStr) {
        String[] inStrSplit = new String[inStr.length()];
        for (int i = 0; i < inStr.length(); i++) {
            String c = inStr.substring(i, i + 1);
            this.graphemeSet.add(c);
            inStrSplit[i] = c;
        }
        String[] outStrSplit;
        if (outStr.contains(" ")) {
            outStrSplit = outStr.split(" ");
            // preserve space between allophones:
            for (int i = 1, max = outStrSplit.length; i < max; i++) {
                outStrSplit[i] = " " + outStrSplit[i];
            }
        } else { // split into individual characters
            outStrSplit = new String[outStr.length()];
            for (int i = 0; i < outStr.length(); i++) {
                outStrSplit[i] = outStr.substring(i, i + 1);
            }
        }
        this.inSplit.add(inStrSplit);
        this.outSplit.add(outStrSplit);
    }

    /**
     * Adds an entry whose input and output sides are already split into symbols.
     *
     * @param inStr
     *            input symbols
     * @param outStr
     *            output symbols
     */
    public void addAlreadySplit(List<String> inStr, List<String> outStr) {
        this.inSplit.add(inStr.toArray(new String[0]));
        this.outSplit.add(outStr.toArray(new String[0]));
    }

    /**
     * Adds an entry whose input and output sides are already split into symbols. The arrays are stored as they are (no
     * copy is made).
     *
     * @param inStr
     *            input symbols
     * @param outStr
     *            output symbols
     */
    public void addAlreadySplit(String[] inStr, String[] outStr) {
        this.inSplit.add(inStr);
        this.outSplit.add(outStr);
    }

    /**
     * Adds an already split entry together with its optional info.
     *
     * @param inStr
     *            input symbols
     * @param outStr
     *            output symbols
     * @param optionalInfo
     *            optional info (eg. part of speech) for this entry
     * @throws IllegalStateException
     *             if this trainer does not store optional info; nothing is added in that case
     */
    public void addAlreadySplit(List<String> inStr, List<String> outStr, String optionalInfo) {
        this.requireOptInfo();
        this.inSplit.add(inStr.toArray(new String[0]));
        this.outSplit.add(outStr.toArray(new String[0]));
        this.optInfo.add(optionalInfo);
    }

    /**
     * Adds an already split entry together with its optional info. The arrays are stored as they are (no copy is made).
     *
     * @param inStr
     *            input symbols
     * @param outStr
     *            output symbols
     * @param optionalInfo
     *            optional info (eg. part of speech) for this entry
     * @throws IllegalStateException
     *             if this trainer does not store optional info; nothing is added in that case
     */
    public void addAlreadySplit(String[] inStr, String[] outStr, String optionalInfo) {
        this.requireOptInfo();
        this.inSplit.add(inStr);
        this.outSplit.add(outStr);
        this.optInfo.add(optionalInfo);
    }

    /**
     * Fails before any list is modified if no optional info is stored, so that the entry lists cannot get out of sync.
     */
    private void requireOptInfo() {
        if (this.optInfo == null) {
            throw new IllegalStateException("This AlignerTrainer was created without optional info");
        }
    }

    /**
     * One iteration of alignment, using adapted Levenshtein distance. After the iteration, the cost between an input symbol
     * and an output symbol is set to -log2 of the relative frequency with which that output symbol was aligned to that input
     * symbol (truncated to an integer; costs that would not be below the default cost are not stored, so the default cost
     * applies). Analogously, the skip cost is set to -log2 of the fraction of input symbols that were aligned to nothing. In
     * the first iteration, all operations have the default cost.
     * <p>
     * If no input symbol was left unaligned, the skip cost becomes {@link Integer#MAX_VALUE}. If there are no input symbols
     * at all, the skip cost is left unchanged.
     */
    public void alignIteration() {
        // this counts how many times a symbol is mapped to symbols
        Map<String, Integer> symMapCount = new HashMap<String, Integer>();
        // this counts how often particular mappings from one symbol to another occurred
        Map<StringPair, Integer> sym2symCount = new HashMap<StringPair, Integer>();
        // how many symbols are on input side
        int symCount = 0;
        // how many symbols are deleted
        int symDels = 0;
        // for every alignment pair collect counts
        for (int i = 0; i < this.outSplit.size(); i++) {
            String[] in = this.inSplit.get(i);
            String[] out = this.outSplit.get(i);
            int[] alignment = this.align(in, out);
            symCount += in.length;
            int pre = 0;
            // for every input symbol...
            for (int inNr = 0; inNr < in.length; inNr++) {
                int mapped = alignment[inNr] - pre;
                if (mapped == 0) {
                    // is mapped to empty string
                    symDels++;
                } else {
                    // mapped to one or several symbols:
                    // increase count of overall mappings for this symbol
                    Integer c = symMapCount.get(in[inNr]);
                    symMapCount.put(in[inNr], null == c ? mapped : c + mapped);
                    // for every corresponding output symbol
                    for (int outNr = pre; outNr < alignment[inNr]; outNr++) {
                        StringPair key = new StringPair(in[inNr], out[outNr]);
                        Integer mapC = sym2symCount.get(key);
                        sym2symCount.put(key, null == mapC ? 1 : 1 + mapC);
                    }
                }
                pre = alignment[inNr];
            }
        }
        // now build fractions, to estimate the new costs
        // first reset skip costs; without any input symbol the fraction would be 0/0, so keep the old cost
        if (symCount > 0) {
            double delFraction = (double) symDels / symCount;
            // a fraction of 0 yields +Infinity, which the cast turns into Integer.MAX_VALUE
            this.skipcost = (int) -this.log2(delFraction);
        }
        // now reset aligncosts
        this.aligncost.clear();
        for (Map.Entry<StringPair, Integer> entry : sym2symCount.entrySet()) {
            StringPair mapping = entry.getKey();
            double fraction = (double) entry.getValue() / symMapCount.get(mapping.getString1());
            int cost = (int) -this.log2(fraction);
            if (cost < DEFAULT_COST) {
                this.aligncost.put(mapping, cost);
            }
        }
    }

    /**
     * @return the number of entries added so far
     */
    public int lexiconSize() {
        return this.inSplit.size();
    }

    /**
     * Gets an alignment of the graphemes to the phones of an entry. A StringPair array is returned, where every entry
     * contains a grapheme together with the phone sequence it is mapped to. The phone String is just the concatenation of the
     * symbols in the aligned sequence.
     *
     * @param entryNr
     *            nr of the lexicon entry
     * @return one pair per input symbol
     */
    public StringPair[] getAlignment(int entryNr) {
        String[] in = this.inSplit.get(entryNr);
        String[] out = this.outSplit.get(entryNr);
        int[] bounds = this.align(in, out);
        StringPair[] pairs = new StringPair[in.length];
        fillAlignedPairs(in, out, bounds, pairs);
        return pairs;
    }

    /**
     * Gets an alignment of an entry as Strings: for every input symbol, the input symbol followed by the output symbols
     * aligned to it, each preceded by a space character.
     *
     * @param entryNr
     *            nr of the lexicon entry
     * @return one String per input symbol
     */
    public String[] getAlignmentString(int entryNr) {
        String[] in = this.inSplit.get(entryNr);
        String[] out = this.outSplit.get(entryNr);
        int[] bounds = this.align(in, out);
        String[] result = new String[in.length];
        int from = 0;
        for (int pos = 0; pos < in.length; pos++) {
            result[pos] = in[pos] + concatOutput(out, from, bounds[pos], " ");
            from = bounds[pos];
        }
        return result;
    }

    /**
     * Gets an alignment of the graphemes to the phones of an entry. A StringPair array is returned, where every entry
     * contains a grapheme together with the phone sequence it is mapped to. The phone String is just the concatenation of the
     * symbols in the aligned sequence. In addition, the extra info (eg. POS) is appended as one symbol on the input side,
     * paired with an empty output string.
     * <p>
     * If no extra info is available for the entry, the result is the same as that of {@link #getAlignment(int)}.
     *
     * @param entryNr
     *            nr of the lexicon entry
     * @return one pair per input symbol, followed by a pair for the extra info if there is any
     */
    public StringPair[] getInfoAlignment(int entryNr) {
        String info = null;
        if (this.optInfo != null && entryNr >= 0 && entryNr < this.optInfo.size()) {
            info = this.optInfo.get(entryNr);
        }
        if (null == info) {
            return getAlignment(entryNr);
        }
        String[] in = this.inSplit.get(entryNr);
        String[] out = this.outSplit.get(entryNr);
        int[] bounds = this.align(in, out);
        StringPair[] pairs = new StringPair[in.length + 1];
        fillAlignedPairs(in, out, bounds, pairs);
        pairs[in.length] = new StringPair(info, "");
        return pairs;
    }

    /**
     * Stores, for every input symbol, the pair of that symbol and the concatenation of the output symbols aligned to it.
     *
     * @param in
     *            input symbols
     * @param out
     *            output symbols
     * @param bounds
     *            right alignment boundary for every input symbol, as returned by align
     * @param target
     *            array receiving the pairs at positions 0 to in.length - 1
     */
    private static void fillAlignedPairs(String[] in, String[] out, int[] bounds, StringPair[] target) {
        int from = 0;
        for (int pos = 0; pos < in.length; pos++) {
            target[pos] = new StringPair(in[pos], concatOutput(out, from, bounds[pos], ""));
            from = bounds[pos];
        }
    }

    /**
     * Concatenates out[from] to out[to - 1], writing the given prefix before every symbol.
     */
    private static String concatOutput(String[] out, int from, int to, String prefix) {
        StringBuilder sb = new StringBuilder();
        for (int k = from; k < to; k++) {
            sb.append(prefix).append(out[k]);
        }
        return sb.toString();
    }

    /**
     * Returns the set of input symbols. If the set is missing or empty, it is rebuilt from the stored entries.
     *
     * @return the set of input symbols
     */
    public Set<String> getInputSyms() {
        if (this.graphemeSet == null || this.graphemeSet.isEmpty()) {
            return this.collectInputSyms();
        }
        return this.graphemeSet;
    }

    /**
     * Rebuilds the set of input symbols from all stored input sides.
     */
    private Set<String> collectInputSyms() {
        this.graphemeSet = new HashSet<String>();
        // allow "null" as a dummy value for all phone features
        this.graphemeSet.add("null");
        for (String[] is : this.inSplit) {
            for (String sym : is) {
                this.graphemeSet.add(sym);
            }
        }
        return this.graphemeSet;
    }

    private double log2(double d) {
        return Math.log(d) / LOG_OF_2;
    }

    /**
     * Cost of aligning the second symbol of the pair to the first one: the estimated cost if there is one, otherwise the
     * default cost (or zero for identical symbols if input and output share one alphabet).
     */
    private int symDist(StringPair key) {
        Integer cost = aligncost.get(key);
        if (null != cost) {
            return cost;
        }
        if (this.inIsOut && key.getString1().equals(key.getString2())) {
            return 0;
        }
        return DEFAULT_COST;
    }

    /**
     * This computes the alignment that has the lowest distance between two Strings.
     * <p>
     * There are three differences to the normal Levenshtein-distance:
     * <ol>
     * <li>Only insertions and deletions are allowed, no replacements (i.e. no "diagonal" transitions)</li>
     * <li>insertion costs are dependent on a particular symbol on the input side (the one they are aligned to)</li>
     * <li>deletion is equivalent to a symbol on the input side that is not aligned. There are costs associated with that.</li>
     * </ol>
     * The method returns for each input symbol the index of the right alignment boundary. eg. for input ['a','b'] and output
     * ['a','a','b'] a correct alignment would be: [2,3]
     * <p>
     * Distances are accumulated as long values, so that a very large skip cost (up to {@link Integer#MAX_VALUE}) cannot
     * overflow.
     *
     * @param istr
     *            the input string
     * @param ostr
     *            the output string
     * @return for every input symbol the (exclusive) index in ostr up to which output symbols are aligned to it; an empty
     *         array if istr is empty
     */
    public int[] align(String[] istr, String[] ostr) {
        if (istr.length == 0) {
            // nothing to align to: there are no boundaries to report
            return new int[0];
        }
        final int outLen = ostr.length;
        // reusable lookup key; it is never stored in a map
        StringPair key = new StringPair(null, null);
        // distances: previous column in matrix, and current column
        long[] prevDist = new long[outLen + 1];
        long[] dist = new long[outLen + 1];
        // indicating if a skip was performed (= if current input symbol has not been aligned)
        boolean[] prevSkipped = new boolean[outLen + 1];
        boolean[] skipped = new boolean[outLen + 1];
        // alignment boundaries
        int[][] prevBounds = new int[outLen + 1][istr.length];
        int[][] bounds = new int[outLen + 1][istr.length];

        // first column: the only possibility is to align the first input symbol to everything
        prevDist[0] = 0;
        prevSkipped[0] = true;
        key.setString1(istr[0]);
        for (int j = 1; j <= outLen; j++) {
            prevBounds[j][0] = j;
            key.setString2(ostr[j - 1]);
            prevDist[j] = prevDist[j - 1] + symDist(key);
            prevSkipped[j] = false;
        }

        // constant penalty for not aligning an input symbol
        final long skipPenalty = this.skipcost;

        // can start at 1, since 0 has been treated in initialization
        for (int i = 1; i < istr.length; i++) {
            // row 0 stands for skipping from the beginning on
            dist[0] = prevDist[0] + skipPenalty;
            skipped[0] = true;
            key.setString1(istr[i]);
            for (int j = 1; j <= outLen; j++) {
                // j-1, because row 0 was inserted for not aligning at the beginning
                key.setString2(ostr[j - 1]);
                long alignCost = symDist(key) + dist[j - 1];
                // moving on to the next input symbol only costs something if the previous one is still unaligned
                long skipCost = (prevSkipped[j] ? skipPenalty : 0L) + prevDist[j];
                if (skipCost < alignCost) {
                    // skipping cheaper: continue from the previous input symbol. The row is shared with the
                    // previous column on purpose; that column does not read it again.
                    dist[j] = skipCost;
                    bounds[j] = prevBounds[j];
                    bounds[j][i] = j;
                    skipped[j] = true;
                } else {
                    // aligning cheaper: continue from the previously aligned output symbol
                    dist[j] = alignCost;
                    System.arraycopy(bounds[j - 1], 0, bounds[j], 0, i);
                    bounds[j][i] = j;
                    skipped[j] = false;
                }
            }
            // the current column becomes the previous one
            long[] tmpDist = prevDist;
            prevDist = dist;
            dist = tmpDist;
            boolean[] tmpSkipped = prevSkipped;
            prevSkipped = skipped;
            skipped = tmpSkipped;
            int[][] tmpBounds = prevBounds;
            prevBounds = bounds;
            bounds = tmpBounds;
        }
        return prevBounds[outLen];
    }
}