package statalign.base; import java.io.FileWriter; import java.io.IOException; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.HashMap; import java.util.Locale; import java.util.Random; import mpi.MPI; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.Well19937c; import statalign.MPIUtils; import statalign.base.mcmc.AlignmentMove; import statalign.base.mcmc.AllEdgeMove; import statalign.base.mcmc.CoreMcmcModule; import statalign.base.mcmc.EdgeMove; import statalign.base.mcmc.IndelMove; import statalign.base.mcmc.LOCALTopologyMove; import statalign.base.mcmc.LambdaMove; import statalign.base.mcmc.MuMove; import statalign.base.mcmc.PhiMove; import statalign.base.mcmc.RMove; import statalign.base.mcmc.RhoMove; import statalign.base.mcmc.SilentIndelMove; import statalign.base.mcmc.SubstMove; import statalign.base.mcmc.ThetaMove; import statalign.base.mcmc.TopologyMove; import statalign.base.thread.Stoppable; import statalign.base.thread.StoppedException; import statalign.mcmc.BetaPrior; import statalign.mcmc.GammaPrior; import statalign.mcmc.GaussianProposal; import statalign.mcmc.LogisticProposal; import statalign.mcmc.McmcCombinationMove; import statalign.mcmc.McmcModule; import statalign.mcmc.McmcMove; import statalign.mcmc.MultiplicativeProposal; import statalign.mcmc.UniformProposal; import statalign.model.ext.ModelExtManager; import statalign.postprocess.PostprocessManager; import statalign.postprocess.plugins.contree.CNetwork; import statalign.ui.ErrorMessage; import statalign.ui.MainFrame; import statalign.utils.BetaDistribution; /** * * This class handles an MCMC run. * * The class extends <tt>Stoppable</tt>, it may be terminated/suspended in * graphical mode. * * @author miklos, novak, herman * */ public abstract class Mcmc extends Stoppable { /** Is this a parallel chain? By-default false. */ protected boolean isParallel = false; private boolean simulatedAnnealing = false; public CNetwork network; /** Current tree in the MCMC chain. */ public Tree tree; /** Total log-likelihood of the current state, cached for speed */ protected double totalLogLike; /** * If this variable is true, all proposed moves are accepted. * The purpose of this is to allow: * a) an initial run of moves to randomise the starting configuration * b) a series of moves to be proposed before choosing whether to * accept or reject the final configuration, with Hastings ratio * equal to the product of the Hastings ratios along the way. */ boolean acceptAllMoves = false; /** * Number of steps in which the chain will be allowed to move randomly * with all moves accepted, in order to create a random starting * configuration. */ int randomisationPeriod = 0; /** * To be used in order to allow a series of moves to be proposed before * deciding whether to accept or reject the final state, with Hastings * ratio given by this cumulative value. */ double cumulativeLogProposalRatio = 0.0; /** * MCMC parameters including the number of burn-in steps, the total number * of steps in the MCMC and the sampling rate. */ public MCMCPars mcmcpars; public McmcStep mcmcStep = new McmcStep(); /** PostprocessManager that handles the postprocessing modules. */ public PostprocessManager postprocMan; /** Manager that handles model extension plugins */ public ModelExtManager modelExtMan; /** McmcModule containing the moves for the core components of * the model, i.e. the indel parameters, substitution model parameters, * alignment, topology and edge lengths. * The coreModel also decides whether to execute MCMC moves from the * ModelExtension modules. */ public McmcModule coreModel; /** * Interval (in terms of number of samples) * at which current postprocessing information is flushed to file * and MCMC info is printed to stdout. */ int LOG_INTERVAL = 100; /** True while the MCMC is in the burn-in phase. */ public boolean burnin; /** True while the MCMC is in the first half of burn-in phase. */ public boolean firstHalfBurnin; public Mcmc(Tree tree, MCMCPars pars, PostprocessManager ppm, ModelExtManager modelExtMan) { postprocMan = ppm; this.modelExtMan = modelExtMan; ppm.mcmc = this; this.modelExtMan.setMcmc(this); this.tree = tree; tree.owner = this; mcmcpars = pars; this.tree.heat = 1.0d; randomisationPeriod = mcmcpars.randomisationPeriod; } private static final DecimalFormat df = new DecimalFormat("0.0000"); /** * Initialises the core McmcModule. This method is to be implemented by * specific instances of the Mcmc class. * @param tree */ protected abstract void initCoreModel(Tree tree); /** * Triggers a call to <tt>coreModel</tt> to propose a move from one * of the McmcMove objects, selected according to its weight. * If there are active ModelExtension plugins, then <tt>coreModel</tt> * will delegate the sampling to the ModelExtensionManager with * probability proportional to the sum of the weights of the * active plugins. * * @param samplingMethod Currently unused. * @throws StoppedException */ private void sample(int samplingMethod) throws StoppedException { stoppable(); if(Utils.DEBUG) { System.out.println("tree.getLogLike() (BEFORE) = "+tree.getLogLike()); tree.recomputeCheckLogLike(); if(Math.abs(modelExtMan.totalLogLike(tree)-totalLogLike) > 1e-5) { System.out.println("\nBefore: "+modelExtMan.totalLogLike(tree)+" "+totalLogLike); throw new Error("Log-likelihood inconsistency at start of sample()"); } } boolean accepted = coreModel.proposeParamChange(tree); if (accepted) { if (Utils.DEBUG) System.out.println("\t\tMove accepted."); totalLogLike = coreModel.curLogLike; } else { if (Utils.DEBUG) System.out.println("Move rejected."); coreModel.setLogLike(totalLogLike); } if(Utils.DEBUG) { tree.recomputeCheckLogLike(); tree.checkPointers(); if(Math.abs(modelExtMan.totalLogLike(tree)-totalLogLike) > 1e-5) { System.out.println("After: "+modelExtMan.totalLogLike(tree)+" "+totalLogLike); throw new Error("Log-likelihood inconsistency at end of sample()"); } } } /** * * @param aligned Vector indicating which characters are aligned to the current * column in the subtrees below. * @return Logarithm of emission probability for subtrees */ double calcEm(int[] aligned) { return modelExtMan.calcLogEm(aligned); } /** * This function is called by the McmcMove objects in order to determine whether * the proposed moves are to be accepted. * * @param logProposalRatio This also includes the contribution from the prior densities. * It is assumed that any dependencies between the priors and other parameters will be * handled inside the McmcMove objects. * @return true if the move is accepted */ public boolean isParamChangeAccepted(double logProposalRatio,McmcMove m) { double newLogLike = coreModel.curLogLike; if (Utils.SHAKE_IF_STUCK && firstHalfBurnin && (m.lowCounts > Utils.LOW_COUNT_THRESHOLD)) { newLogLike *= Math.pow(Utils.LOW_COUNT_MULTIPLIER,m.lowCounts-Utils.LOW_COUNT_THRESHOLD); } acceptAllMoves = firstHalfBurnin && m.acceptAllDuringFirstHalfBurnin; boolean accept = acceptanceDecision(totalLogLike,newLogLike,logProposalRatio,acceptAllMoves); if (accept) m.lowCounts = 0; return accept; } public boolean acceptanceDecision(double oldLogLikelihood, double newLogLikelihood, double logProposalRatio, boolean acceptMoveIfPossible) { if (Utils.DEBUG) System.out.print("logLikelihoodRatio = "+(newLogLikelihood-oldLogLikelihood)); if (logProposalRatio > Double.NEGATIVE_INFINITY) { cumulativeLogProposalRatio += logProposalRatio; } else { return false; } if (Utils.DEBUG) System.out.println("\tlogProposalRatio = "+logProposalRatio); if (acceptMoveIfPossible) { return (newLogLikelihood > Double.NEGATIVE_INFINITY); } return (Math.log(Utils.generator.nextDouble()) < (cumulativeLogProposalRatio + tree.heat*(newLogLikelihood - oldLogLikelihood)) + (cumulativeLogProposalRatio=0)); } /** * Returns a {@link State} object that describes the current state of the * MCMC. This can then be passed on to other classes such as postprocessing * plugins. */ public State getState() { return tree.getState(); } protected void beginRandomisationPeriod() { } protected void endRandomisationPeriod() { } /** * Additional initialisation routines before starting MCMC, if required. */ protected void beforeMCMC() { System.out.println("Starting MCMC...\n"); Utils.generator = new Well19937c(mcmcpars.seed); } /** * @return Always <code>true</true> if running non-parallel version; * when running parallel version, <code>true</code> if this is the * master chain. */ protected boolean isMaster() { return true; } /** * Starts an MCMC run. * * If <tt>AutomateParameters.shouldAutomateProposalVariances() = true</tt> * then the proposal distributions will be automatically adjusted during the * burnin. * * If <tt>AutomateParameters.shouldAutomateNumberOfSamples() = true</tt> * or <tt>AutomateParameters.shouldAutomateStepRate() = true</tt> * or <tt>AutomateParameters.shouldAutomateBurnin() = true</tt> * then these parameters will be adjusted automatically, although this * approach may affect the theoretical convergence properties of the MCMC * chain, so this type of automation should be regarded more as a quick way * of getting some initial results without tweaking the parameters. * * This function also calls the appropriate functions of the PostpocessManager * <tt>postprocMan</tt> to trigger data transfer to postprocessing modules * when necessary */ public int doMCMC() { beforeMCMC(); MainFrame frame = postprocMan.mainManager.frame; //edgeWeight *= tree.vertex.length; coreModel = new CoreMcmcModule(this,modelExtMan); initCoreModel(tree); // notifies MCMC modules (including plugins) of start of MCMC sampling coreModel.beforeSampling(tree); // Triggers a /before first sample/ of the plugins. if (isMaster()) { postprocMan.beforeFirstSample(); } long currentTime, start = System.currentTimeMillis(); // calculates initial log-likelihood (includes coreModel likelihood) totalLogLike = modelExtMan.totalLogLike(tree); ArrayList<Double> logLikeList = new ArrayList<Double>(); int errorCode = 0; try { stoppable(); // Recompute progressive alignment now that everything is initialised. //TreeAlgo treeAlgo = new TreeAlgo(); //treeAlgo.alignSeqsRec(tree.root); //only to use if AutomateParameters.shouldAutomate() == true // final int SAMPLE_RATE_WHEN_DETERMINING_THE_SPACE = 100; // final int BURNIN_TO_CALCULATE_THE_SPACE = 25000; // ArrayList<String[]> alignmentsFromSamples = new ArrayList<String[]>(); int burnIn = mcmcpars.burnIn; // boolean stopBurnIn = false; // if(AutomateParameters.shouldAutomateBurnIn()){ // burnIn = 10000000; // } // if(AutomateParameters.shouldAutomateStepRate()){ // burnIn += BURNIN_TO_CALCULATE_THE_SPACE; // } // Randomise the initial starting configuration // by accepting all moves for a period. tree.root.recomputeLogLike(); // For testing totalLogLike = modelExtMan.totalLogLike(tree); if (randomisationPeriod > 0) { System.out.println("Randomising initial configuration for "+randomisationPeriod+" steps."); acceptAllMoves = true; beginRandomisationPeriod(); for (int i = 0; i < randomisationPeriod; i++) { sample(0); } endRandomisationPeriod(); acceptAllMoves = false; } burnin = true; firstHalfBurnin = true; tree.root.recomputeLogLike(); totalLogLike = modelExtMan.totalLogLike(tree); for (int i = 0; i < burnIn; i++) { if (firstHalfBurnin && i > burnIn / 2) { firstHalfBurnin = false; coreModel.afterFirstHalfBurnin(); modelExtMan.afterFirstHalfBurnin(); coreModel.incrementWeights(); modelExtMan.incrementWeights(); if (simulatedAnnealing) { tree.heat = 1; } } else { if (simulatedAnnealing) { tree.heat = Math.log(i) / Math.log(burnIn / 2); } } // Perform an MCMC move sample(0); // Triggers a /new step/ and a /new peek/ (if appropriate) of // the plugins. if (isMaster()) { // TODO do above inside sample() and add more info mcmcStep.newLogLike = modelExtMan.totalLogLike(tree); mcmcStep.burnIn = burnin; postprocMan.newStep(mcmcStep); if (i % mcmcpars.sampRate == 0) { postprocMan.newPeek(); } } if (i>0 && mcmcpars.doReportDuringBurnin && (i % mcmcpars.sampRate == 0)) { report(i, mcmcpars.cycles / mcmcpars.sampRate); } // if(AutomateParameters.shouldAutomateBurnIn() && i % 50 == 0){ // // every 50 steps, add the current loglikelihood to a list // // and check if we find a major decline in that list // logLikeList.add(getState().logLike); // if(!stopBurnIn){ // stopBurnIn = AutomateParameters.shouldStopBurnIn(logLikeList); // if(AutomateParameters.shouldAutomateStepRate() && stopBurnIn){ // burnIn = i + BURNIN_TO_CALCULATE_THE_SPACE; // }else if (stopBurnIn){ // burnIn = i; // } // } // } currentTime = System.currentTimeMillis(); // int realBurnIn = burnIn - BURNIN_TO_CALCULATE_THE_SPACE; if (frame != null) { String text = ""; // if((i > realBurnIn ) && AutomateParameters.shouldAutomateStepRate()){ // text = "Burn-in to aid automation of MCMC parameters: " + (i-realBurnIn + 1) ; // }else{ text = "Burn-in: " + (i + 1); // } frame.statusText.setText(text); } else if (i % 1000 == 999) { if (Utils.DEBUG) { System.err.println("Burn in: " + (i + 1)); } else { System.out.println("Burn in: " + (i + 1)); } } // if( AutomateParameters.shouldAutomateStepRate() && (i >= realBurnIn) && i % SAMPLE_RATE_WHEN_DETERMINING_THE_SPACE == 0) { // String[] align = getState().getLeafAlign(); // alignmentsFromSamples.add(align); // } if (AutomateParameters.shouldAutomateProposalVariances() && i % mcmcpars.sampRate == 0) { coreModel.modifyProposalWidths(); modelExtMan.modifyProposalWidths(); } } //both real burn-in and the one to determine the sampling rate have now been completed. burnin = false; coreModel.afterBurnin(); modelExtMan.afterBurnin(); coreModel.zeroAllMoveCounts(); modelExtMan.zeroAllMoveCounts(); //Utils.DEBUG = true; int period; // if(AutomateParameters.shouldAutomateNumberOfSamples()){ // period = 1000000; // }else{ period = mcmcpars.cycles / mcmcpars.sampRate; // } int sampRate; // if(AutomateParameters.shouldAutomateStepRate()){ // if(frame != null) // { // frame.statusText.setText("Calculating the sample rate"); // } // else // { // System.out.println("Calculating the sample rate"); // } // ArrayList<Double> theSpace = Distance.spaceAMA(alignmentsFromSamples); // sampRate = AutomateParameters.getSampleRateOfTheSpace(theSpace,SAMPLE_RATE_WHEN_DETERMINING_THE_SPACE); // // }else{ sampRate = mcmcpars.sampRate; // } // AlignmentData alignment = new AlignmentData(getState().getLeafAlign()); // ArrayList<AlignmentData> allAlignments = new ArrayList<AlignmentData>(); // ArrayList<Double> distances = new ArrayList<Double>(); boolean shouldStop = false; // double currScore = 0; for (int i = 0; i < period && !shouldStop; i++) { if (i > 0 && (i % LOG_INTERVAL == 0)) { postprocMan.flushAll(); if (coreModel.printExtraInfo) printMcmcInfo(); } for (int j = 0; j < sampRate; j++) { // Perform an MCMC move sample(0); // Proposes a swap. if (isParallel && ((i*sampRate + j) % mcmcpars.swapRate == 0)) { doSwap(); } // Triggers a /new step/ and a /new peek/ (if appropriate) // of the plugins. if (isMaster()) { mcmcStep.newLogLike = totalLogLike; mcmcStep.burnIn = burnin; postprocMan.newStep(mcmcStep); if (burnIn + i * period + j % mcmcpars.sampRate == 0) { postprocMan.newPeek(); } } currentTime = System.currentTimeMillis(); if (frame != null) { String text = "Samples taken: " + Integer.toString(i); // //remainingTime((currentTime - start) // // * ((period - i - 1) * sampRate // // + sampRate - j - 1) // // / (burnIn + i * sampRate + j + 1)) // text += " The sampling rate: " + sampRate; // if(AutomateParameters.shouldAutomateNumberOfSamples()){ // text += ", Similarity(alignment n-1, alignment n): " + df.format(currScore) + " < " + df.format(AutomateParameters.PERCENT_CONST); // } frame.statusText.setText(text ); } } if (frame == null && !isParallel) { if (Utils.DEBUG) { System.err.println("Sample: " + (i + 1)); } else { System.out.println("Sample: " + (i + 1)); } } // if(AutomateParameters.shouldAutomateNumberOfSamples()){ // alignment = new AlignmentData(getState().getLeafAlign()); // allAlignments.add(alignment); // if (allAlignments.size() >1){ // FuzzyAlignment Fa = FuzzyAlignment.getFuzzyAlignmentAndProject(allAlignments.subList(0, allAlignments.size()-1), 0); // FuzzyAlignment Fb = FuzzyAlignment.getFuzzyAlignmentAndProject(allAlignments, 0); // currScore = FuzzyAlignment.AMA(Fa, Fb); // System.out.println(currScore); // distances.add(currScore); // if (allAlignments.size() >5){ // shouldStop = AutomateParameters.shouldStopSampling(distances); // } // // } // } // Report the results of the sample. report(i, period); } } catch (StoppedException ex) { errorCode = 1; // stopped: report and save state } //if(Utils.DEBUG) { printMcmcInfo(); //} // Triggers a /after first sample/ of the plugins. if (isMaster()) { postprocMan.afterLastSample(); } // notifies model extension plugins of the end of sampling modelExtMan.afterSampling(); System.out.println(coreModel.getSummaryInfo()); coreModel.afterSampling(); if (frame != null) { frame.statusText.setText(MainFrame.IDLE_STATUS_MESSAGE); } return errorCode; } public String getInfoString() { return coreModel.getSummaryInfo(); } private void printMcmcInfo() { String info = "\n"+Utils.repeatedString("#",64)+"\n"; info += String.format("%-24s","Move name")+ String.format("%8s","t")+ String.format("%8s","nMoves")+ String.format("%8s","t/move")+ String.format("%8s", "acc")+ String.format("%8s\n", "propVar"); info += coreModel.getMcmcInfo(); info += modelExtMan.getMcmcInfo(); info += Utils.repeatedString("#",64)+"\n"; System.out.println(info); } /** * Triggers <tt>postProcMan</tt> to print out a report of the current * state of the chain. * * @param no * @param total */ private void report(int no, int total) { report(no,total,true); } private void reportDuringBurnin(int no, int total) { report(no,total,false); } protected void report(int no, int total, boolean useSample) { if (useSample) postprocSample(no,total); // Log the accept ratios/params to the (.log) file. TODO: move to a plugin. try { if (isMaster()) { postprocMan.logFile.write(coreModel.getSummaryInfo() + "\n"); coreModel.printParameters(); } } catch (IOException e) { if (postprocMan.mainManager.frame != null) { ErrorMessage.showPane(postprocMan.mainManager.frame, e, true); } else { e.printStackTrace(System.out); } } } protected void postprocSample(int no, int total) { postprocMan.newSample(coreModel,getState(), no, total); } // This function is used (and defined) only by the parallel version protected void doSwap() { } }