/* * ApproximateFactorAnalysisPrecisionMatrix.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.impl.DenseDoubleMatrix2D; import dr.inference.model.*; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.RobustSingularValueDecomposition; import dr.math.matrixAlgebra.Vector; import dr.xml.*; /** * @author Marc A. Suchard * @author Max R. Tolkoff * */ public class ApproximateFactorAnalysisPrecisionMatrix extends Parameter.Abstract implements MatrixParameterInterface, VariableListener { private double[] values; private double[] storedValues; private boolean valuesKnown; private boolean storedValuesKnow; private final MatrixParameterInterface L; private final Parameter gamma; private int dim; private final static boolean DEBUG = false; public ApproximateFactorAnalysisPrecisionMatrix(String name, MatrixParameterInterface L, Parameter gamma) { super(name); this.L = L; this.gamma = gamma; L.addVariableListener(this); gamma.addVariableListener(this); this.dim = L.getRowDimension(); values = new double[dim * dim]; } private void computeValues() { if (!valuesKnown) { computeValuesImp(); valuesKnown = true; } } private void computeValuesImp() { dim = L.getRowDimension(); double[][] matrix = new double[dim][dim]; for (int row = 0; row < L.getRowDimension(); ++row) { for (int col = 0; col < L.getRowDimension(); ++col) { double sum = 0; for (int k = 0; k < L.getColumnDimension(); ++k) { sum += L.getParameterValue(row, k) * L.getParameterValue(col, k); } matrix[row][col] = sum; } } // DoubleMatrix2D LL = new DenseDoubleMatrix2D(matrix); // RobustSingularValueDecomposition SVD = new RobustSingularValueDecomposition(LL); // double[][] U = SVD.getU().toArray(); // double[][] V = SVD.getV().toArray(); // DoubleMatrix2D EvalsTemp= SVD.getS(); // // double[] EVals = new double[EvalsTemp.rows()]; // for (int i = 0; i < EVals.length ; i++) { // EVals[i] = EvalsTemp.get(i, i); // } // // for (int i = 0; i < EVals.length ; i++) { // if(Math.abs(EVals[i]) >= Math.pow(10, -10)){ // EVals[i] = 1 / EVals[i]; // } // else // EVals[i] = 0; // } // // // for (int i = 0; i <U.length; i++) { // for (int j = 0; j <V.length ; j++) { // matrix[i][j] = 0; // for (int k = 0; k < U.length; k++) { // matrix[i][j] += V[i][k] * EVals[k] * U[j][k]; // } // // } // // } for (int row = 0; row < dim; row++) { matrix[row][row] += 1 / gamma.getParameterValue(row); } if (DEBUG) { System.err.println("mult:"); System.err.println(new Matrix(L.getParameterAsMatrix())); System.err.println(new Vector(gamma.getParameterValues()) + "\n"); System.err.println(new Matrix(matrix)); } matrix = new Matrix(matrix).inverse().toComponents(); int index = 0; for (int row = 0; row < dim; ++row) { for (int col = 0; col < dim; ++col) { values[index] = matrix[row][col]; ++index; } } } @Override public String getDimensionName(int index) { int row = index % dim + 1; int col = index / dim + 1; // column-major return getParameterName() + row + col; } @Override public int getDimension() { return dim * dim; } @Override public double getParameterValue(int index) { computeValues(); return values[index]; } @Override public double[][] getParameterAsMatrix() { computeValues(); double[][] matrix = new double[dim][dim]; for (int i = 0; i < dim; ++i) { System.arraycopy(values, i * dim, matrix[i], 0, dim); } if (DEBUG) { System.err.println("vec:"); System.err.println(new Vector(values)); System.err.println(new Matrix(matrix)); System.err.println(""); } return matrix; } @Override public double getParameterValue(int row, int col) { computeValues(); return values[col * dim + row]; // column-major } @Override public Parameter getParameter(int column) { return null; } @Override public double[] getParameterValues() { computeValues(); double[] matrix = new double[values.length]; System.arraycopy(values, 0, matrix, 0, values.length); return matrix; } @Override protected void storeValues() { L.storeParameterValues(); gamma.storeParameterValues(); if (storedValues == null) { storedValues = new double[dim * dim]; } System.arraycopy(values, 0, storedValues, 0, values.length); storedValuesKnow = valuesKnown; } @Override protected void restoreValues() { L.restoreParameterValues(); gamma.restoreParameterValues(); double[] tmp = values; values = storedValues; storedValues = tmp; valuesKnown = storedValuesKnow; } @Override protected void acceptValues() { L.acceptParameterValues(); gamma.acceptParameterValues(); } @Override protected void adoptValues(Parameter source) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValue(int dim, double value) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValueQuietly(int dim, double value) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValueNotifyChangedAll(int dim, double value) { throw new RuntimeException("Not implemented"); } @Override public String getParameterName() { return getId(); } @Override public void addBounds(Bounds<Double> bounds) { } @Override public Bounds<Double> getBounds() { return null; } @Override public void addDimension(int index, double value) { throw new RuntimeException("Not implemented"); } @Override public double removeDimension(int index) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValue(int row, int col, double value) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValueQuietly(int row, int col, double value) { throw new RuntimeException("Not implemented"); } @Override public void setParameterValueNotifyChangedAll(int row, int col, double value) { throw new RuntimeException("Not implemented"); } @Override public double[] getColumnValues(int col) { throw new RuntimeException("Not yet implemented"); } @Override public int getColumnDimension() { return dim; } @Override public int getRowDimension() { return dim; } @Override public int getUniqueParameterCount() { return 2; } @Override public Parameter getUniqueParameter(int index) { return (index == 0) ? L : gamma; } @Override public void copyParameterValues(double[] destination, int offset) { throw new RuntimeException("Not yet implemented"); } @Override public void setAllParameterValuesQuietly(double[] values, int offset) { throw new RuntimeException("Not implemented"); } @Override public String toSymmetricString() { return MatrixParameter.toSymmetricString(this); } @Override public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { valuesKnown = false; fireParameterChangedEvent(); } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public static final String APPROXIMATE_PARAMETER = "approximateFactorAnalysisPrecision"; public static final String L_LABEL = "L"; public static final String GAMMA = "gamma"; public String getParserName() { return APPROXIMATE_PARAMETER; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { MatrixParameterInterface L = (MatrixParameterInterface) xo.getElementFirstChild(L_LABEL); Parameter gamma = (Parameter) xo.getElementFirstChild(GAMMA); String name = xo.hasId() ? xo.getId() : APPROXIMATE_PARAMETER; return new ApproximateFactorAnalysisPrecisionMatrix(name, L, gamma); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "A diffusion approximation to a factor analysis"; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(L_LABEL, new XMLSyntaxRule[] { new ElementRule(MatrixParameterInterface.class), }), new ElementRule(GAMMA, new XMLSyntaxRule[] { new ElementRule(Parameter.class), }), }; public Class getReturnType() { return ApproximateFactorAnalysisPrecisionMatrix.class; } }; } // } // DoubleMatrix2D LL = new DenseDoubleMatrix2D(matrix); // RobustSingularValueDecomposition SVD = new RobustSingularValueDecomposition(LL); // double[][] U = SVD.getU().toArray(); // double[][] V = SVD.getV().toArray(); // double[] EVals = SVD.getSingularValues(); // // for (int i = 0; i <EVals.length ; i++) { // if(EVals[i] != 0){ // EVals[i] = 1 / EVals[i]; // } // } // for (int i = 0; i < EVals.length; i++) { // for (int j = 0; j < EVals.length; j++) { // U[i][j] = U[i][j] * EVals[i]; // } // } // // for (int i = 0; i <U.length; i++) { // for (int j = 0; j <V.length ; j++) { // matrix[i][j] = U[j][i] * V[i][j]; // } // // } // // for (int row = 0; row < dim; row++) {