package statalign.base.mcmc;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import org.apache.commons.math3.random.Well19937c;
import mpi.MPI;
import statalign.MPIUtils;
import statalign.base.MCMCPars;
import statalign.base.Mcmc;
import statalign.base.State;
import statalign.base.Tree;
import statalign.base.Utils;
import statalign.mcmc.BetaPrior;
import statalign.mcmc.GammaPrior;
import statalign.mcmc.GaussianProposal;
import statalign.mcmc.LogisticProposal;
import statalign.mcmc.McmcCombinationMove;
import statalign.mcmc.McmcMove;
import statalign.mcmc.MultiplicativeProposal;
import statalign.mcmc.UniformProposal;
import statalign.model.ext.ModelExtManager;
import statalign.postprocess.PostprocessManager;
/**
*
* Contains the specifics of the MCMC scheme for StatAlign. Currently this
* includes various hard-coded parameters, and choice of move types.
*
* @author herman
*
*/
public class StatAlignParallelMcmc extends StatAlignMcmc {
/** The number of processes. */
protected int noOfProcesses;
/** The rank of the process. */
protected int rank;
/** The random number generator used for swapping. */
protected Random swapGenerator;
public StatAlignParallelMcmc(Tree tree, MCMCPars pars, PostprocessManager ppm, ModelExtManager modelExtMan,
int noOfProcesses, int rank, double heat) {
super(tree,pars,ppm,modelExtMan);
this.noOfProcesses = noOfProcesses;
this.rank = rank;
this.tree.heat = heat;
isParallel = true;
}
protected boolean isMaster() {
return MPIUtils.isMaster(rank);
}
protected void postprocSample(int no, int total) {
int[] ranks = new int[] { (isColdChain() ? rank : 0) };
int[] coldChainLoc = new int[1];
int coldChainLocation = -1;
MPI.COMM_WORLD.Reduce(ranks, 0, coldChainLoc, 0, 1, MPI.INT, MPI.SUM, 0);
coldChainLocation = coldChainLoc[0];
// TODO: Remove - for debugging purposes
if (MPIUtils.isMaster(rank)) {
MPIUtils.println(rank, "Cold chain is at: " + coldChainLocation);
}
if (isColdChain() && MPIUtils.isMaster(rank)) {
// Sample normally.
postprocMan.newSample(coreModel,getState(), no, total);
} else if (isColdChain() && !MPIUtils.isMaster(rank)) {
// Send state.
State state = getState();
MPIStateSend(state);
} else if (!isColdChain() && MPIUtils.isMaster(rank)) {
// Receive state.
State state = MPIStateReceieve(coldChainLocation);
postprocMan.newSample(coreModel,state, no, total);
}
if (MPIUtils.isMaster(rank)) {
try {
postprocMan.logFile.write("Cold chain location: " + coldChainLocation + "\n");
}
catch (IOException e) {
e.printStackTrace();
}
}
}
protected void beforeMCMC() {
String str = String.format(
"Starting MCMC chain no. %d/%d (heat: %.2f)\n\n",
rank + 1, noOfProcesses, tree.heat);
MPIUtils.println(rank, str);
swapGenerator = new Random(mcmcpars.swapSeed);
Utils.generator = new Well19937c(mcmcpars.seed + rank);
}
protected void doSwap() {
int swapA, swapB;
swapA = swapGenerator.nextInt(noOfProcesses);
do {
swapB = swapGenerator.nextInt(noOfProcesses);
} while (swapA == swapB);
System.out.printf("SwapNo: %d - SwapA: %d - SwapB: %d\n",swapA, swapB);
double swapAccept = swapGenerator.nextDouble();
if (rank == swapA || rank == swapB) {
double[] myStateInfo = new double[3];
myStateInfo[0] = totalLogLike;
myStateInfo[1] = modelExtMan.totalLogPrior(tree);
//myStateInfo[1] = coreModel.totalLogPrior(tree) + modelExtMan.totalLogPrior(tree);
myStateInfo[2] = tree.heat;
double[] partnerStateInfo = new double[3];
mpi.Request send, recieve;
if (rank == swapA) {
send = MPI.COMM_WORLD.Isend(myStateInfo, 0, 3, MPI.DOUBLE,
swapB, 0);
recieve = MPI.COMM_WORLD.Irecv(partnerStateInfo, 0, 3,
MPI.DOUBLE, swapB, 1);
} else {
send = MPI.COMM_WORLD.Isend(myStateInfo, 0, 3, MPI.DOUBLE,
swapA, 1);
recieve = MPI.COMM_WORLD.Irecv(partnerStateInfo, 0, 3,
MPI.DOUBLE, swapA, 0);
}
mpi.Request.Waitall(new mpi.Request[] { send, recieve });
System.out
.printf("[Worker %d] Heat: [%f] - Sent: [%f,%f,%f] - Recv: [%f,%f,%f]\n",
rank, tree.heat, myStateInfo[0], myStateInfo[1],
myStateInfo[2], partnerStateInfo[0],
partnerStateInfo[1], partnerStateInfo[2]);
double myLogLike = myStateInfo[0];
double myLogPrior = myStateInfo[1];
double myTemp = myStateInfo[2];
double hisLogLike = partnerStateInfo[0];
double hisLogPrior = partnerStateInfo[1];
double hisTemp = partnerStateInfo[2];
double acceptance = myTemp * (hisLogLike + hisLogPrior) + hisTemp
* (myLogLike + myLogPrior);
acceptance -= hisTemp * (hisLogLike + hisLogPrior) + myTemp
* (myLogLike + myLogPrior);
MPIUtils.println(rank,
"Math.log(swapAccept): " + Math.log(swapAccept));
MPIUtils.println(rank, "acceptance: "
+ acceptance);
if (acceptance > Math.log(swapAccept)) {
MPIUtils.println(rank,
"Just swapped heat with my partner. New heat: "
+ hisTemp);
tree.heat = hisTemp;
}
// MPI.COMM_WORLD.Send(myStateInfo, 0, 3, MPI.DOUBLE,
// swapB, 0);
// statalign.Utils.printLine(swapA, "Just sent " + swapB
// + " my state.");
}
}
protected boolean isColdChain() {
return tree.heat == 1.0d;
}
protected State MPIStateReceieve(int peer) {
// Creates a new, uninitialized state and initializes the variables.
State state = new State(tree.vertex.length);
// We already know the names
for (int i = 0; i < state.nl; i++) {
state.name[i] = tree.vertex[i].name;
}
int nn = state.nn;
int tag = 0;
// left
MPI.COMM_WORLD.Recv(state.left, 0, nn, MPI.INT, peer, tag++);
// right
MPI.COMM_WORLD.Recv(state.right, 0, nn, MPI.INT, peer, tag++);
// parent
MPI.COMM_WORLD.Recv(state.parent, 0, nn, MPI.INT, peer, tag++);
// edgeLen
MPI.COMM_WORLD.Recv(state.edgeLen, 0, nn, MPI.DOUBLE, peer, tag++);
// sequences
int[] seqLengths = new int[nn];
MPI.COMM_WORLD.Recv(seqLengths, 0, nn, MPI.INT, peer, tag++);
for (int i = 0; i < nn; i++) {
char[] c = new char[seqLengths[i]];
MPI.COMM_WORLD.Recv(c, 0, seqLengths[i], MPI.CHAR, peer, tag++);
state.seq[i] = new String(c);
}
// align
Object[] recvObj = new Object[1];
MPI.COMM_WORLD.Recv(recvObj, 0, 1, MPI.OBJECT, peer, tag++);
state.align = (int[][]) recvObj[0];
// felsen
MPI.COMM_WORLD.Recv(recvObj, 0, 1, MPI.OBJECT, peer, tag++);
state.felsen = (double[][][]) recvObj[0];
// indelParams
final int noOfIndelParameter = 3;
state.indelParams = new double[noOfIndelParameter];
MPI.COMM_WORLD.Recv(state.indelParams, 0, noOfIndelParameter,
MPI.DOUBLE, peer, tag++);
// substParams
int l = tree.substitutionModel.params.length;
state.substParams = new double[l];
MPI.COMM_WORLD.Recv(state.substParams, 0, l, MPI.DOUBLE, peer, tag++);
// log-likelihood
double[] d = new double[1];
MPI.COMM_WORLD.Recv(d, 0, 1, MPI.DOUBLE, peer, tag++);
state.logLike = d[0];
// root
int[] root = new int[1];
MPI.COMM_WORLD.Recv(root, 0, 1, MPI.INT, peer, tag++);
state.root = root[0];
return state;
}
protected void MPIStateSend(State state) {
String[] seq = state.seq;
int[][] align = state.align;
double[][][] felsen = state.felsen;
int nn = state.nn;
int tag = 0;
// left
MPI.COMM_WORLD.Send(state.left, 0, nn, MPI.INT, 0, tag++);
// right
MPI.COMM_WORLD.Send(state.right, 0, nn, MPI.INT, 0, tag++);
// parent
MPI.COMM_WORLD.Send(state.parent, 0, nn, MPI.INT, 0, tag++);
// edgeLen
MPI.COMM_WORLD.Send(state.edgeLen, 0, nn, MPI.DOUBLE, 0, tag++);
// TODO: START OF OPTIMIZATION.
// sequences
int[] seqLength = new int[nn];
char[][] seqChars = new char[nn][];
for (int i = 0; i < nn; i++) {
seqLength[i] = seq[i].length();
seqChars[i] = seq[i].toCharArray();
}
MPI.COMM_WORLD.Send(seqLength, 0, nn, MPI.INT, 0, tag++);
for (int i = 0; i < nn; i++) {
MPI.COMM_WORLD.Send(seqChars[i], 0, seqLength[i], MPI.CHAR, 0, tag++);
}
// align
Object[] alignObj = new Object[1];
alignObj[0] = align;
MPI.COMM_WORLD.Send(alignObj, 0, 1, MPI.OBJECT, 0, tag++);
/*
* int[] alignLength = new int[align.length]; for (int i = 0; i <
* seq.length; i++) { alignLength[i] = align[i].length; }
* MPI.COMM_WORLD.Send(alignLength, 0, nn, MPI.INT, 0, tag++); for (int
* i = 0; i < align.length; i++) { MPI.COMM_WORLD.Send(align[i], 0,
* alignLength[i], MPI.INT, 0, tag++); }
*/
// felsen
Object[] felsenObj = new Object[] { felsen };
MPI.COMM_WORLD.Send(felsenObj, 0, 1, MPI.OBJECT, 0, tag++);
// indelParams
MPI.COMM_WORLD.Send(state.indelParams, 0, 3, MPI.DOUBLE, 0, tag++);
// substParams
MPI.COMM_WORLD.Send(state.substParams, 0, state.substParams.length,
MPI.DOUBLE, 0, tag++);
// loglikelihood
MPI.COMM_WORLD.Send(new double[] { state.logLike }, 0, 1, MPI.DOUBLE,
0, tag++);
// root
MPI.COMM_WORLD.Send(new int[] { state.root }, 0, 1, MPI.INT, 0, tag++);
// TODO: END OF OPTIMIZATION.
}
}