package dr.inference.operators; import dr.evomodel.continuous.OrderedLatentLiabilityLikelihood; import dr.inference.model.DiagonalMatrix; import dr.inference.model.LatentFactorModel; import dr.inference.model.MatrixParameterInterface; import dr.inference.model.Parameter; import dr.math.MathUtils; import dr.math.distributions.NormalDistribution; /** * Created by Max on 9/1/16. */ public class LatentFactorLiabilityGibbsOperator extends SimpleMCMCOperator implements GibbsOperator{ LatentFactorModel lfm; OrderedLatentLiabilityLikelihood liabilityLikelihood; public LatentFactorLiabilityGibbsOperator(double weight, LatentFactorModel lfm, OrderedLatentLiabilityLikelihood liabilityLikelihood) { setWeight(weight); this.lfm = lfm; this.liabilityLikelihood = liabilityLikelihood; } @Override public String getPerformanceSuggestion() { return null; } @Override public String getOperatorName() { return "LatentFactorLiabilityGibbsOperator"; } @Override public double doOperation() { if(liabilityLikelihood.getOrdering()) doUnorderedOperation(); else doOrderedOperation(); return 0; } void doUnorderedOperation() { double[] LxF = lfm.getLxF(); DiagonalMatrix colPrec = (DiagonalMatrix) lfm.getColumnPrecision(); Parameter continuous = lfm.getContinuous(); MatrixParameterInterface lfmData = lfm.getScaledData(); for (int i = 0; i < lfmData.getColumnDimension(); i++) { int LLpointer = 0; int[] data = liabilityLikelihood.getData(i); for (int index = 0; index < data.length; ++index) { int datum = data[index]; Parameter numClasses = liabilityLikelihood.numClasses; int dim = (int) numClasses.getParameterValue(index); if(datum >= dim && continuous.getParameterValue(LLpointer) == 0){ double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + LLpointer], colPrec.getParameterValue(LLpointer, LLpointer), Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); lfmData.setParameterValue(LLpointer, i, draw); } if (dim == 1.0) { if (continuous.getParameterValue(LLpointer) == 0) { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + LLpointer], colPrec.getParameterValue(LLpointer, LLpointer), Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); lfmData.setParameterValue(LLpointer, i, draw); } LLpointer++; } else if (dim == 2.0) { if (datum == 0) { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + LLpointer], colPrec.getParameterValue(LLpointer, LLpointer), Double.NEGATIVE_INFINITY, 0); lfmData.setParameterValue(LLpointer, i, draw); } else { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + LLpointer], colPrec.getParameterValue(LLpointer, LLpointer), 0, Double.POSITIVE_INFINITY); lfmData.setParameterValue(LLpointer, i, draw); } LLpointer++; } else { double[] trait = new double[dim]; trait[0] = 0.0; if(datum == 0){ for (int l = 0; l < dim - 1; l++) { lfmData.setParameterValue(LLpointer + l, i, drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + (LLpointer + l)], colPrec.getParameterValue((LLpointer + l), (LLpointer + l)), Double.NEGATIVE_INFINITY, 0)); } } else { trait[datum] = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + (LLpointer + datum - 1)], colPrec.getParameterValue((LLpointer + datum - 1), (LLpointer + datum - 1)), 0, Double.POSITIVE_INFINITY); lfmData.setParameterValue(LLpointer + datum - 1, i, trait[datum]); for (int l = 1; l < dim; l++) { if(l != datum){ // System.out.println("Free Rolls"); // System.out.println(LxF[i * lfmData.getRowDimension() + (LLpointer + l - 1)]); trait[l] = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + (LLpointer + l - 1)], colPrec.getParameterValue((LLpointer + l - 1), (LLpointer + l - 1)), Double.NEGATIVE_INFINITY, trait[datum]); lfmData.setParameterValue(LLpointer + l - 1, i, trait[l]); } } // double max = Double.NEGATIVE_INFINITY; // for (int j = 0; j < trait.length; j++) { // if(max < trait[j]){ // max = trait[j]; // } // } // System.out.println("Constrained"); // System.out.println(LxF[i * lfmData.getRowDimension() + (LLpointer + datum - 1)]); } // valid = isMax(trait, datum); LLpointer += dim - 1; } } } } void doOrderedOperation(){ double[] LxF = lfm.getLxF(); DiagonalMatrix colPrec = (DiagonalMatrix) lfm.getColumnPrecision(); Parameter continuous = lfm.getContinuous(); MatrixParameterInterface lfmData = lfm.getScaledData(); Parameter threshold = liabilityLikelihood.getThreshold(); for (int i = 0; i < lfmData.getColumnDimension(); i++) { int Thresholdpointer = 0; int[] data = liabilityLikelihood.getData(i); for (int index = 0; index < data.length; ++index) { int datum = data[index]; Parameter numClasses = liabilityLikelihood.numClasses; int dim = (int) numClasses.getParameterValue(index); if(datum >= dim && continuous.getParameterValue(index) == 0){ double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + index], colPrec.getParameterValue(index, index), Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); lfmData.setParameterValue(index, i, draw); } else { if (dim == 1.0) { if (continuous.getParameterValue(index) == 0) { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + index], colPrec.getParameterValue(index, index), Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); lfmData.setParameterValue(index, i, draw); } } else if (dim == 2.0) { if (datum == 0) { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + index], colPrec.getParameterValue(index, index), Double.NEGATIVE_INFINITY, 0); lfmData.setParameterValue(index, i, draw); } else { double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + index], colPrec.getParameterValue(index, index), 0, Double.POSITIVE_INFINITY); lfmData.setParameterValue(index, i, draw); } } else { double[] thresholdList = new double[dim + 1]; thresholdList[0] = Double.NEGATIVE_INFINITY; thresholdList[1] = 0; thresholdList[dim] = Double.POSITIVE_INFINITY; for (int j = 0; j < thresholdList.length - 3; j++) { thresholdList[j + 2] = threshold.getParameterValue(Thresholdpointer + j); } Thresholdpointer += dim - 2; double draw = drawTruncatedNormalDistribution(LxF[i * lfmData.getRowDimension() + index], colPrec.getParameterValue(index, index), thresholdList[datum], thresholdList[datum + 1]); lfmData.setParameterValue(index, i, draw); } } // valid = isMax(trait, datum); } } } double drawTruncatedNormalDistribution(double mean, double precision, double lower, double upper){ double sd = Math.sqrt(1 / precision); NormalDistribution normal = new NormalDistribution(mean, sd); double newLower = normal.cdf(lower); double newUpper = normal.cdf(upper); double cdfDraw = 1.0; int iterator = 0; boolean invalid = true; double draw = 0; while(iterator < 10000 && invalid){ cdfDraw = MathUtils.nextDouble() * (newUpper - newLower) + newLower; draw = normal.quantile(cdfDraw); if(!Double.isNaN(draw) && draw > lower && draw < upper) { invalid = false; } iterator++; } // if(iterator != 1){ // System.out.println(iterator); // System.out.println(draw); // System.out.println(lower); // System.out.println(upper);} if(Double.isNaN(draw) || Double.isInfinite(draw)){ if(Double.isInfinite(lower)){ // System.out.println("upper"); // System.out.println(upper); return upper;} else if(Double.isInfinite(upper)){ // System.out.println("lower"); // System.out.println(lower); return lower;} else return (lower + upper) / 2; } else return draw; } @Override public int getStepCount() { return 0; } }