package edu.cmu.sphinx.decoder.adaptation;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Locale;
import java.util.Scanner;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;

/**
 * Holds MLLR transforms of the form {@code A*x + B}: one A matrix and one B
 * vector per cluster and per stream. The transforms can be estimated from
 * collected statistics ({@link #update(Stats)}), written to a text file
 * ({@link #store(String, int)}) and read back ({@link #load(String)}).
 */
public class Transform {

    /** A matrices, indexed as [cluster][stream][row][column]. */
    private float[][][][] As;

    /** B vectors, indexed as [cluster][stream][row]. */
    private float[][][] Bs;

    /** Supplies the number of streams and the vector length of each stream. */
    private final Sphinx3Loader loader;

    /** Number of clusters for which {@link #update(Stats)} computes transforms. */
    private final int nrOfClusters;

    /**
     * Creates an empty transform. The matrices stay {@code null} until
     * {@link #update(Stats)} or {@link #load(String)} is called.
     *
     * @param loader
     *            loader queried for the number of streams and vector lengths
     * @param nrOfClusters
     *            number of clusters to compute transforms for
     */
    public Transform(Sphinx3Loader loader, int nrOfClusters) {
        this.loader = loader;
        this.nrOfClusters = nrOfClusters;
    }

    /**
     * Used for access to A matrix.
     *
     * @return A matrix (representing A from A*x + B = C), indexed as
     *         [cluster][stream][row][column]; this is the internal array,
     *         not a copy, and is {@code null} before the first update/load
     */
    public float[][][][] getAs() {
        return As;
    }

    /**
     * Used for access to B matrix.
     *
     * @return B matrix (representing B from A*x + B = C), indexed as
     *         [cluster][stream][row]; this is the internal array, not a
     *         copy, and is {@code null} before the first update/load
     */
    public float[][][] getBs() {
        return Bs;
    }

    /**
     * Writes the transformation to file in a format that could further be used
     * in Sphinx3 and Sphinx4.
     * <p>
     * The file holds the class count (always {@code 1}), the number of
     * streams, and for every stream: its vector length, the rows of A, the B
     * vector and a line of {@code 1.0} values (the variance scale that
     * {@link #load(String)} skips).
     *
     * @param filePath
     *            path to store transform matrix
     * @param index
     *            index of transform to store
     * @throws Exception
     *             if something went wrong
     */
    public void store(String filePath, int index) throws Exception {
        // try-with-resources: the file is closed even when writing fails
        // half-way (e.g. an invalid index).
        try (PrintWriter writer = new PrintWriter(filePath, "UTF-8")) {
            int numStreams = loader.getNumStreams();

            // nMllrClass
            writer.println("1");
            writer.println(numStreams);

            for (int i = 0; i < numStreams; i++) {
                int len = loader.getVectorLength()[i];
                writer.println(len);

                for (int j = 0; j < len; j++) {
                    for (int k = 0; k < len; ++k) {
                        writer.print(As[index][i][j][k]);
                        writer.print(" ");
                    }
                    writer.println();
                }

                for (int j = 0; j < len; j++) {
                    writer.print(Bs[index][i][j]);
                    writer.print(" ");
                }
                writer.println();

                for (int j = 0; j < len; j++) {
                    writer.print("1.0 ");
                }
                writer.println();
            }

            // PrintWriter never throws on I/O failure; ask it explicitly so a
            // truncated file is reported instead of silently accepted.
            if (writer.checkError()) {
                throw new IOException("Failed to write transform to " + filePath);
            }
        }
    }

    /**
     * Used for computing the actual transformations (A and B matrices). These
     * are stored in As and Bs.
     * <p>
     * For every cluster, stream and row {@code j} the linear system
     * {@code regLs[c][i][j] * x = regRs[c][i][j]} is solved with an LU
     * decomposition. The first {@code len} entries of the solution become row
     * {@code j} of A and entry {@code len} becomes element {@code j} of B, so
     * each system must have at least {@code len + 1} unknowns.
     *
     * @param regLs
     *            coefficient matrices, indexed [cluster][stream][row]
     * @param regRs
     *            right-hand-side vectors, indexed [cluster][stream][row]
     */
    private void computeMllrTransforms(double[][][][][] regLs,
            double[][][][] regRs) {
        int numStreams = loader.getNumStreams();

        for (int c = 0; c < nrOfClusters; c++) {
            this.As[c] = new float[numStreams][][];
            this.Bs[c] = new float[numStreams][];

            for (int i = 0; i < numStreams; i++) {
                int len = loader.getVectorLength()[i];
                this.As[c][i] = new float[len][len];
                this.Bs[c][i] = new float[len];

                for (int j = 0; j < len; ++j) {
                    // 'false' wraps the statistics arrays without copying.
                    RealMatrix coef = new Array2DRowRealMatrix(regLs[c][i][j],
                            false);
                    DecompositionSolver solver = new LUDecomposition(coef)
                            .getSolver();
                    RealVector vect = new ArrayRealVector(regRs[c][i][j],
                            false);
                    RealVector ABloc = solver.solve(vect);

                    for (int k = 0; k < len; ++k) {
                        this.As[c][i][j][k] = (float) ABloc.getEntry(k);
                    }
                    // The entry after the A row is the bias term.
                    this.Bs[c][i][j] = (float) ABloc.getEntry(len);
                }
            }
        }
    }

    /**
     * Read the transformation from a file
     * <p>
     * Expects the layout produced by {@link #store(String, int)}. Only the
     * first class (index 0) is read. The current matrices are replaced only
     * when the whole file was parsed successfully.
     *
     * @param filePath
     *            file path to load transform
     * @throws Exception
     *             if something went wrong
     */
    public void load(String filePath) throws Exception {
        // UTF-8 matches the encoding used by store(); try-with-resources
        // closes the file even when parsing fails.
        try (Scanner input = new Scanner(new File(filePath), "UTF-8")) {
            // store() writes numbers with '.' as decimal separator regardless
            // of the platform locale, so parse them the same way.
            input.useLocale(Locale.US);

            int nMllrClass = input.nextInt();
            assert nMllrClass == 1;

            int numStreams = input.nextInt();

            float[][][][] loadedAs = new float[nMllrClass][numStreams][][];
            float[][][] loadedBs = new float[nMllrClass][numStreams][];

            for (int i = 0; i < numStreams; i++) {
                int length = input.nextInt();

                loadedAs[0][i] = new float[length][length];
                loadedBs[0][i] = new float[length];

                for (int j = 0; j < length; j++) {
                    for (int k = 0; k < length; k++) {
                        loadedAs[0][i][j][k] = input.nextFloat();
                    }
                }

                for (int j = 0; j < length; j++) {
                    loadedBs[0][i][j] = input.nextFloat();
                }

                for (int j = 0; j < length; j++) {
                    // Skip MLLR variance scale
                    input.nextFloat();
                }
            }

            // Publish only after a complete, successful parse.
            this.As = loadedAs;
            this.Bs = loadedBs;
        }
    }

    /**
     * Stores in current object a transform generated on the provided stats.
     * <p>
     * Calls {@code stats.fillRegLowerPart()} first, then allocates fresh
     * matrices for all clusters and computes them from the regression
     * statistics.
     *
     * @param stats
     *            provided stats that were previously collected from Result
     *            objects.
     */
    public void update(Stats stats) {
        stats.fillRegLowerPart();
        As = new float[nrOfClusters][][][];
        Bs = new float[nrOfClusters][][];
        this.computeMllrTransforms(stats.getRegLs(), stats.getRegRs());
    }
}