package pl.edu.fuw.fid.signalanalysis.dtf; import org.apache.commons.math.complex.Complex; import org.apache.commons.math.complex.ComplexField; import org.apache.commons.math.linear.Array2DRowFieldMatrix; import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.apache.commons.math.linear.FieldLUDecompositionImpl; import org.apache.commons.math.linear.FieldMatrix; import org.apache.commons.math.linear.LUDecompositionImpl; import org.apache.commons.math.linear.MatrixUtils; import org.apache.commons.math.linear.RealMatrix; /** * AR model class. Contains a factory method compute, calculating * AR model coefficients from Yule-Walker method (with whitening). * * @author ptr@mimuw.edu.pl */ public class ArModel { private final int C; private final double detV; private final RealMatrix V; private final RealMatrix[] A; private final double freqSampling; /** * Create a new AR model instance from precomputed matrix of coefficients * and error covariance matrix. * * @param C number of channels * @param A coefficient matrix * @param V error covariance matrix * @param freqSampling sampling frequency (Hz) */ public ArModel(int C, RealMatrix[] A, RealMatrix V, double freqSampling) { for (RealMatrix M : A) { if (M.getRowDimension() != C || M.getColumnDimension() != C) { throw new RuntimeException("matrix dimension mismatch"); } } this.C = C; this.A = A; this.V = V; this.detV = new LUDecompositionImpl(V).getDeterminant(); this.freqSampling = freqSampling; } /** * Fit AR model to given multi-channel signal, using the Yule-Walker method. * Signal will be whitened (mean subtracted and divided by standard variation) * prior to calculations. * * @param X multichannel signal data, rows=channels, columns=samples * @param freqSampling sampling frequency (Hz) * @param order order > 0 of the AR model to be fit * @return AR model instance with computed coefficients */ public static ArModel compute(RealMatrix X, double freqSampling, int order) { final int N = X.getColumnDimension(); final int C = X.getRowDimension(); // whitening matrix data for (int c=0; c<C; ++c) { double[] row = X.getRow(c); double sum = 0.0, sum2 = 0.0; for (double x : row) { sum += x; sum2 += x * x; } double EX = sum / N; double EX2 = sum2 / N; double D = Math.sqrt(EX2 - EX*EX); for (int i=0; i<N; ++i) { X.setEntry(c, i, (row[i] - EX) / D); } } // calculating lag correlations RealMatrix[] R = new RealMatrix[1+order]; for (int s=0; s<=order; ++s) { R[s] = new Array2DRowRealMatrix(C, C); for (int i=0; i<C; ++i) { for (int j=0; j<C; ++j) { double sum = 0; for (int t=0; t<N-s; ++t) { // causality i -> j, so t < t+s sum += X.getEntry(i, t) * X.getEntry(j, t+s); } R[s].setEntry(i, j, sum / N); } } } // matrices for Yule-Walker equations RealMatrix bigMatrix = new Array2DRowRealMatrix(order*C, order*C); RealMatrix bigColumn = new Array2DRowRealMatrix(order*C, C); for (int i=0; i<order; ++i) for (int j=0; j<order; ++j) { int s = i - j; RealMatrix block = (s < 0) ? R[-s].transpose() : R[s]; bigMatrix.setSubMatrix(block.getData(), i*C, j*C); } for (int i=0; i<order; ++i) { RealMatrix block = R[i+1]; bigColumn.setSubMatrix(block.getData(), i*C, 0); } // solution of Yule-Walker equations RealMatrix bigMatrixInverse = new LUDecompositionImpl(bigMatrix).getSolver().getInverse(); // TODO what if not invertible? RealMatrix bigSolution = bigMatrixInverse.multiply(bigColumn); RealMatrix[] A = new RealMatrix[1+order]; A[0] = MatrixUtils.createRealIdentityMatrix(C).scalarMultiply(-1); for (int s=1; s<=order; ++s) { A[s] = bigSolution.getSubMatrix((s-1)*C, s*C-1, 0, C-1); } // computing residual error RealMatrix V = new Array2DRowRealMatrix(C, C); for (int s=0; s<=order; ++s) { V = V.subtract(A[s].transpose().multiply(R[s])); } return new ArModel(C, A, V, freqSampling); } public ArModelData[][] computeSpectralData(int spectrumSize, boolean normalized) { ArModelData[][] data = new ArModelData[C][C]; for (int i=0; i<C; ++i) for (int j=0; j<C; ++j) { data[i][j] = new ArModelData(spectrumSize); } final double nyquist = 0.5 * getSamplingFrequency(); for (int f=0; f<spectrumSize; ++f) { double freq = f * nyquist / spectrumSize; RealMatrix H = computeTransferMatrix(freq, normalized); for (int i=0; i<C; ++i) for (int j=0; j<C; ++j) { double value = H.getEntry(i, j); data[i][j].freqcs[f] = freq; data[i][j].values[f] = value; } } return data; } public RealMatrix computeTransferMatrix(double freq, boolean normalize) { FieldMatrix<Complex> S = new Array2DRowFieldMatrix<Complex>(ComplexField.getInstance(), C, C); for (int s=0; s<A.length; ++s) { Complex exp = new Complex(0, -2*Math.PI*s*freq/freqSampling).exp(); for (int i=0; i<C; ++i) for (int j=0; j<C; ++j) { double val = A[s].getEntry(i, j); S.addToEntry(i, j, exp.multiply(val)); } } FieldMatrix<Complex> H = new FieldLUDecompositionImpl(S).getSolver().getInverse(); RealMatrix DTF = new Array2DRowRealMatrix(C, C); for (int i=0; i<C; ++i) { for (int j=0; j<C; ++j) { Complex h = H.getEntry(i, j); double re = h.getReal(); double im = h.getImaginary(); DTF.setEntry(i, j, re*re + im*im); } } if (normalize) { // entry (i, j) represents causality i -> for (int j=0; j<C; ++j) { double norm = 0; for (int i=0; i<C; ++i) { norm += DTF.getEntry(i, j); } norm = 1.0 / norm; for (int i=0; i<C; ++i) { DTF.multiplyEntry(i, j, norm); } } } FieldMatrix<Complex> Hplus = new FieldLUDecompositionImpl(S).getSolver().getInverse(); FieldMatrix<Complex> cV = new Array2DRowFieldMatrix<Complex>(ComplexField.getInstance(), C, C); for (int i=0; i<C; ++i) for (int j=0; j<C; ++j) { cV.setEntry(i, j, new Complex(V.getEntry(i, j), 0)); Hplus.setEntry(i, j, H.getEntry(j, i).conjugate()); } FieldMatrix<Complex> spectrum = H.multiply(cV).multiply(Hplus); for (int i=0; i<C; ++i) { DTF.setEntry(i, i, spectrum.getEntry(i, i).abs()); } return DTF; } private static String exportMatrix(RealMatrix M) { boolean comma = false; String result = "["; for (int r=0; r<M.getRowDimension(); ++r) { if (comma) result += ","; result += exportRow(M.getRow(r)); comma = true; } result += "]"; return result; } private static String exportRow(double[] row) { boolean comma = false; String result = "["; for (double v : row) { if (comma) result += ","; result += v; comma = true; } result += "]"; return result; } public String exportCoefficients() { boolean comma = false; String result = "["; for (int i=1; i<A.length; ++i) { if (comma) result += ","; result += exportMatrix(A[i]); comma = true; } result += "]"; return result; } public int getChannelCount() { return C; } public double getErrorDeterminant() { return detV; } public double getSamplingFrequency() { return freqSampling; } }