/*
    Copyright (c) 2009-2011
        Speech Group at Informatik 5, Univ. Erlangen-Nuremberg, GERMANY
        Korbinian Riedhammer
        Tobias Bocklet

    This file is part of the Java Speech Toolkit (JSTK).

    The JSTK is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    The JSTK is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with the JSTK. If not, see <http://www.gnu.org/licenses/>.
*/
package de.fau.cs.jstk.trans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.LineNumberReader;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;

import org.apache.log4j.BasicConfigurator;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import de.fau.cs.jstk.io.FrameInputStream;
import de.fau.cs.jstk.io.LabelFrameInputStream;
import de.fau.cs.jstk.io.SampleInputStream;
import de.fau.cs.jstk.stat.Sample;
import de.fau.cs.jstk.util.Arithmetics;
import de.fau.cs.jstk.util.LabelTranslator;
import de.fau.cs.jstk.util.Pair;

/**
 * Linear discriminant analysis. Observations are accumulated per class via the
 * accumulate methods; {@link #estimate(HashMap)} then derives the projection
 * from the eigen decomposition of pinv(Sw) * Sb, where Sw is the prior-weighted
 * sum of the class covariances and Sb the prior-weighted scatter of the class
 * means around the global mean.
 */
public class LDA extends Projection {
    /** global stats for all seen data */
    private Accumulator global = null;

    /** class dependent stats, indexed by class id */
    private HashMap<Short, Accumulator> stats = new HashMap<Short, Accumulator>();

    /** for internal purposes: remember pinv(Sw); null until estimate() was called */
    private double [][] Swi = null;

    /** for internal purposes: remember ev(pinv(Sw) Sb); null until estimate() was called */
    private double [] evals = null;

    /**
     * Allocate a new LDA for the given feature dimension
     * @param fd input feature dimension
     */
    public LDA(int fd) {
        super(fd);
        global = new Accumulator(fd);
    }

    /**
     * Accumulate a List of Samples, using each Sample's class (c) and
     * observation (x).
     * @param list samples to accumulate
     */
    public void accumulate(List<Sample> list) {
        for (Sample s : list)
            accumulate(s.c, s.x);
    }

    /**
     * Accumulate a single Sample, using its class (c) and observation (x).
     * @param s sample to accumulate
     */
    public void accumulate(Sample s) {
        accumulate(s.c, s.x);
    }

    /**
     * Accumulate an observation for given class. The observation is added to
     * both the global and the class dependent statistics; the class
     * accumulator is allocated on first use.
     * @param c class
     * @param x observation
     */
    public void accumulate(short c, double [] x) {
        // build up global stats
        global.accumulate(x);

        // make sure we have an accumulator (single lookup instead of
        // containsKey/put/get)
        Accumulator acc = stats.get(c);
        if (acc == null) {
            acc = new Accumulator(x.length);
            stats.put(c, acc);
        }

        // build up class dependent stats
        acc.accumulate(x);
    }

    /**
     * Estimate the projection matrix. The number of retained eigenvectors is
     * min(number of classes - 1, feature dimension). If desired, specify
     * manual priors for the classes.
     * @param priors HashMap(ClassID->prior) to specify manual priors, or null
     *        to use the relative class frequencies of the accumulated data
     * @throws IllegalStateException if no data has been accumulated
     * @throws IllegalArgumentException if manual priors are given but lack an
     *         entry for one of the accumulated classes
     */
    public void estimate(HashMap<Short, Double> priors) {
        // without any class the number of eigenvectors to keep would be -1
        if (stats.isEmpty())
            throw new IllegalStateException("LDA.estimate(): no data accumulated");

        // compute priors if necessary
        if (priors == null) {
            priors = new HashMap<Short, Double>();
            for (Entry<Short, Accumulator> e : stats.entrySet())
                priors.put(e.getKey(), (double) e.getValue().getCount() / global.getCount());
        }

        fd = global.getFd();

        double [] gm = global.getMean();

        // build up within-class-covariance (lower triangular)
        double [] sw = new double [fd * (fd + 1) / 2];

        // build up between-class-covariance (lower triangular)
        double [] sb = new double [fd * (fd + 1) / 2];

        for (Entry<Short, Accumulator> e : stats.entrySet()) {
            Double prior = priors.get(e.getKey());
            if (prior == null)
                throw new IllegalArgumentException("LDA.estimate(): no prior specified for class " + e.getKey());

            double p = prior;
            double [] m = e.getValue().getMean();
            double [] c = e.getValue().getCovariance();

            // sum_k p_k K_k
            for (int i = 0; i < sw.length; ++i)
                sw[i] += p * c[i];

            // sum_k p_k (m_k - m)(m_k - m)^T
            int k = 0;
            for (int i = 0; i < m.length; ++i)
                for (int j = 0; j <= i; ++j)
                    sb[k++] += p * (m[i] - gm[i]) * (m[j] - gm[j]);
        }

        // build up the (symmetric) matrices for JAMA use
        Matrix Sw = new Matrix(fd, fd);
        Matrix Sb = new Matrix(fd, fd);
        int k = 0;
        for (int i = 0; i < fd; ++i) {
            for (int j = 0; j <= i; ++j) {
                Sw.set(i, j, sw[k]);
                Sw.set(j, i, sw[k]);
                Sb.set(i, j, sb[k]);
                Sb.set(j, i, sb[k]);
                k++;
            }
        }

        // compute pseudo inverse to avoid regularization issues
        Matrix swInverse = new Matrix(Arithmetics.pinv(Sw.getArray(), 1e-12));
        this.Swi = swInverse.getArray();

        // eig(p-inv(Sw) Sb)
        EigenvalueDecomposition eig = new EigenvalueDecomposition(swInverse.times(Sb));

        // save the eigen vectors (use transposed for java convenience)
        double [][] vhelp = eig.getV().transpose().getArray();
        Matrix d = eig.getD();
        LinkedList<Pair<double [], Double>> sortedEV = new LinkedList<Pair<double [], Double>>();
        for (int i = 0; i < fd; ++i)
            sortedEV.add(new Pair<double [], Double>(vhelp[i], d.get(i, i)));

        // sort strongest EV first
        Collections.sort(sortedEV, new Comparator<Pair<double [], Double>>() {
            public int compare(Pair<double[], Double> o1, Pair<double[], Double> o2) {
                return (int) Math.signum(o2.b - o1.b);
            }
        });

        // save the global mean
        mean = global.getMean();

        // keep (num classes - 1) eigenvectors
        int numv = Math.min(stats.size() - 1, sortedEV.size());
        proj = new double [numv][];
        evals = new double [numv];
        Iterator<Pair<double [], Double>> it = sortedEV.iterator();
        for (int i = 0; i < numv; ++i) {
            Pair<double [], Double> p = it.next();
            proj[i] = p.a;
            evals[i] = p.b;
        }
    }

    /**
     * Get the eigenvalues belonging to the retained projection vectors,
     * strongest first.
     * @return the internal eigenvalue array (not a copy), or null if
     *         estimate() has not been called yet
     */
    public double [] getEigenvalues() {
        return evals;
    }

    /**
     * Produce a String representation of the LDA containing both Projection and
     * LDA information. May be called before estimate(); the Swi rows are then
     * omitted and evals is printed as "null".
     */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Projection = \n");
        sb.append(super.toString());
        sb.append("LDA = \n");
        sb.append("Swi = \n");

        // Swi is only available after estimate()
        if (Swi != null) {
            for (double [] d : Swi)
                sb.append(Arrays.toString(d)).append("\n");
        }

        sb.append("evals = ").append(Arrays.toString(evals));
        return sb.toString();
    }

    public static final String SYNOPSIS =
        "sikoried, 2/2/2011\n" +
        "Compute LDA using (regularized) pseudo-inverse (SVD) and save the resulting\n" +
        "transformation y = A * (x-m) to the given projection file.\n" +
        "usage: transformations.LDA proj list1 [list2 ...] indir\n" +
        "  proj  : output file for projection (Frame format)\n" +
        "  list  : file list(s); in case of single list expecting binary sample format instead of frame.\n" +
        "  indir : directory where the input files are located (use . for current dir)\n";

    /**
     * Command line entry point, see {@link #SYNOPSIS}.
     * @param args proj list1 [list2 ...] indir
     * @throws IOException on any read/write error, or on a malformed list line
     */
    public static void main(String[] args) throws IOException {
        BasicConfigurator.configure();

        // we need at least: proj, one list, indir
        if (args.length < 3) {
            System.err.println(SYNOPSIS);
            System.exit(1);
        }

        String outf = args[0];
        String indir = args[args.length-1] + System.getProperty("file.separator");

        // copy list file(s)
        String [] lifs = new String [args.length - 2];
        System.arraycopy(args, 1, lifs, 0, lifs.length);

        LabelTranslator lt = new LabelTranslator("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789");

        LDA lda = null;

        if (lifs.length == 1) {
            // only single list: each line is either a single file in binary
            // sample format, or a frame file followed by its label file
            LineNumberReader br = new LineNumberReader(new FileReader(lifs[0]));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String [] spl = line.split("\\s+");
                    if (spl.length == 1) {
                        // keep a handle on the raw stream so it can be closed
                        FileInputStream in = new FileInputStream(indir + line);
                        try {
                            SampleInputStream sis = new SampleInputStream(in);
                            Sample s;
                            while ((s = sis.read()) != null) {
                                if (lda == null)
                                    lda = new LDA(s.x.length);
                                lda.accumulate(s);
                            }
                        } finally {
                            in.close();
                        }
                    } else if (spl.length == 2) {
                        FrameInputStream fis = new FrameInputStream(new File(indir + spl[0]));
                        try {
                            FileInputStream lis = new FileInputStream(indir + spl[1]);
                            try {
                                LabelFrameInputStream lfis = new LabelFrameInputStream(lis, fis);

                                double [] x = new double [lfis.getFrameSize()];

                                // init LDA
                                if (lda == null)
                                    lda = new LDA(x.length);

                                while (lfis.read(x)) {
                                    int l = lt.labelToId(lfis.getLabel());
                                    lda.accumulate((short) l, x);
                                }
                            } finally {
                                lis.close();
                            }
                        } finally {
                            fis.close();
                        }
                    } else
                        throw new IOException("invalid line " + br.getLineNumber() + " : " + line);
                }
            } finally {
                br.close();
            }
        } else {
            // basic case: individual list file per class; the class id is the
            // position of the list on the command line
            for (int i = 0; i < lifs.length; ++i) {
                BufferedReader br = new BufferedReader(new FileReader(lifs[i]));
                try {
                    String line;
                    while ((line = br.readLine()) != null) {
                        FrameInputStream fis = new FrameInputStream(new File(indir + line));
                        try {
                            double [] buf = new double [fis.getFrameSize()];

                            if (lda == null)
                                lda = new LDA(buf.length);

                            while (fis.read(buf))
                                lda.accumulate((short) i, buf);
                        } finally {
                            fis.close();
                        }
                    }
                } finally {
                    br.close();
                }
            }
        }

        // empty list file(s): nothing to estimate from
        if (lda == null) {
            System.err.println("LDA.main(): no input data found in the given list file(s)");
            System.exit(1);
        }

        lda.estimate(null);
        lda.save(new File(outf));
    }
}