package dr.inference.operators; import dr.evomodel.continuous.GibbsSampleFromTreeInterface; import dr.inference.model.LatentFactorModel; import dr.inference.model.MatrixParameterInterface; import dr.inference.model.Parameter; import dr.math.MathUtils; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.SymmetricMatrix; /** * Created by max on 5/16/16. */ public class FactorTreeGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { private final LatentFactorModel lfm; private double pathParameter = 1; private final GibbsSampleFromTreeInterface tree; private final GibbsSampleFromTreeInterface workingTree; private final MatrixParameterInterface factors; private final MatrixParameterInterface errorPrec; private final boolean randomScan; private final Parameter missingIndicator; public FactorTreeGibbsOperator(double weight, LatentFactorModel lfm, GibbsSampleFromTreeInterface tree, Boolean randomScan){ setWeight(weight); this.tree = tree; this.lfm = lfm; this.factors = lfm.getFactors(); errorPrec = lfm.getColumnPrecision(); this.randomScan = randomScan; this.workingTree = null; missingIndicator = lfm.getMissingIndicator(); } @Override public int getStepCount() { return 0; } @Override public String getPerformanceSuggestion() { return null; } @Override public String getOperatorName() { return "Factor Tree Gibbs Operator"; } @Override public double doOperation() { if(randomScan){ int column = MathUtils.nextInt(factors.getColumnDimension()); MultivariateNormalDistribution mvn = getMVN(column); double[] draw = (double[]) mvn.nextRandom(); for (int i = 0; i < factors.getRowDimension(); i++) { factors.setParameterValue(i, column, draw[i]); } } else{ for (int i = 0; i < factors.getColumnDimension(); i++) { MultivariateNormalDistribution mvn = getMVN(i); double[] draw = (double[]) mvn.nextRandom(); for (int j = 0; j < factors.getRowDimension(); j++) { factors.setParameterValue(j, i, draw[j]); } } } return 0; } MultivariateNormalDistribution getMVN(int column){ double[][] precision = getPrecision(column); double[] mean = getMean(column, precision); return new MultivariateNormalDistribution(mean, precision); } double[][] getPrecision(int column){ double [][] treePrec = getTreePrec(column); for (int i = 0; i < lfm.getLoadings().getColumnDimension(); i++) { for (int j = i; j < lfm.getLoadings().getColumnDimension(); j++) { for (int k = 0; k < lfm.getLoadings().getRowDimension(); k++) { treePrec[i][j] += lfm.getLoadings().getParameterValue(k, i) * errorPrec.getParameterValue(k, k) * lfm.getLoadings().getParameterValue(k, j) * pathParameter; treePrec[j][i] = treePrec[i][j]; } } } return treePrec; } double[] getMean(int column, double[][] precision){ Matrix variance = (new SymmetricMatrix(precision)).inverse(); double[] midMean = new double[lfm.getLoadings().getColumnDimension()]; double[] condMean = getTreeMean(column); double[][] condPrec = getTreePrec(column); for (int i = 0; i < midMean.length; i++) { // for (int j = 0; j < midMean.length; j++) { midMean [i] += condPrec[i][i] * condMean[i]; // } } for (int i = 0; i < lfm.getLoadings().getRowDimension(); i++) { for (int j = 0; j < lfm.getLoadings().getColumnDimension(); j++) { if(missingIndicator == null || missingIndicator.getParameterValue(column * lfm.getScaledData().getRowDimension() + i) != 1) midMean[j] += lfm.getScaledData().getParameterValue(i, column) * errorPrec.getParameterValue(i,i) * lfm.getLoadings().getParameterValue(i, j) * pathParameter; } } double[] mean = new double[midMean.length]; for (int i = 0; i < mean.length; i++) { for (int j = 0; j < mean.length; j++) { mean[i] += variance.component(i, j) * midMean[j]; } } return mean; } public double[][] getTreePrec(int column){ double answerFactor = tree.getPrecisionFactor(column); double[][] answer = new double[factors.getRowDimension()][factors.getRowDimension()]; for (int i = 0; i < factors.getRowDimension(); i++) { answer[i][i] = answerFactor; } if (workingTree != null) { double[][] temp = workingTree.getConditionalPrecision(column); for (int i = 0; i < answer.length; i++) { for (int j = 0; j < answer.length; j++) { answer[i][j] = answer[i][j] * pathParameter + temp[i][j] * (1 - pathParameter); } } } return answer; } public double[] getTreeMean(int column){ double[] answer = tree.getConditionalMean(column); if (workingTree != null) { double[] temp = workingTree.getConditionalMean(column); for (int i = 0; i < answer.length; i++) { answer[i] = answer[i] * pathParameter + temp[i] * (1 - pathParameter); } } return answer; } @Override public void setPathParameter(double beta) { pathParameter = beta; } }