package statalign.mcmc;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import statalign.base.Mcmc;
import statalign.base.Tree;
import statalign.base.Utils;
/**
 * Generic class for a group of {@link McmcMove} objects and the integer
 * weights used to choose between them. Each ModelExtension is an instance
 * of this class, as is the CoreMcmcModule.
 *
 * @author herman
 */
public abstract class McmcModule {

    /** Row format used by {@link #getMcmcInfo()}; one line per move. */
    private static final String MOVE_INFO_FORMAT = "%-24s%8s%8d%8d%8.4f%8.4f\n";

    /** The MCMC driver this module is attached to; set via {@link #setMcmc(Mcmc)}. */
    protected Mcmc mcmc;

    /** Name of this module; used as part of the parameter log file name. */
    protected String moduleName = "";

    /** Whether parameter names/values are written to the parameter log. */
    public boolean logParametersToFile = false;

    /** Parameter log writer; {@code null} until {@link #setOutputFile(String)} succeeds. */
    FileWriter parameterLog;

    /** If set, weight changes are reported on standard output. */
    public boolean printExtraInfo = false;

    /** Current log-likelihood contribution */
    public double curLogLike = 0;

    /** The moves of this module; index-aligned with the two weight lists below. */
    protected List<McmcMove> mcmcMoves = new ArrayList<McmcMove>();

    /** Selection weight of each move. */
    protected List<Integer> mcmcMoveWeights = new ArrayList<Integer>();

    /** Amount added to each move's weight by {@link #incrementWeights()}. */
    protected List<Integer> mcmcMoveWeightIncrements = new ArrayList<Integer>();

    /** Creates an unnamed module with parameter logging disabled. */
    public McmcModule() { }

    /**
     * Creates a named module with parameter logging enabled.
     *
     * @param name module name, also used in the parameter log file name
     */
    public McmcModule(String name) {
        moduleName = name;
        logParametersToFile = true;
    }

    /**
     * Opens the parameter log file {@code baseFileName + moduleName + ".params"}.
     * A previously opened log is closed first so its file handle is not leaked.
     * I/O failures are reported via a stack trace and otherwise ignored.
     *
     * @param baseFileName prefix of the log file path
     */
    public void setOutputFile(String baseFileName) {
        if (parameterLog != null) {
            try {
                parameterLog.close();
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
        try {
            parameterLog = new FileWriter(baseFileName + moduleName + ".params");
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Attaches this module to an MCMC driver.
     *
     * @param m the driver to delegate to
     */
    public void setMcmc(Mcmc m) {
        mcmc = m;
    }

    /** @return the {@code firstHalfBurnin} flag of the attached {@link Mcmc} */
    public boolean isFirstHalfBurnin() {
        return mcmc.firstHalfBurnin;
    }

    /** @return the {@code burnin} flag of the attached {@link Mcmc} */
    public boolean isBurnin() {
        return mcmc.burnin;
    }

    /** @return the sum of the weights of all moves in this module */
    public int getParamChangeWeight() {
        int total = 0;
        for (int weight : mcmcMoveWeights) {
            total += weight;
        }
        return total;
    }

    /**
     * Sets the weight of every move whose name contains the given string.
     * Note that this is a substring match, so several moves may be affected.
     *
     * @param name   substring to look for in the move names
     * @param weight new weight for the matching moves
     */
    public void setWeight(String name, int weight) {
        for (int i = 0; i < mcmcMoves.size(); i++) {
            McmcMove move = mcmcMoves.get(i);
            if (move.name.contains(name)) {
                mcmcMoveWeights.set(i, weight);
                if (printExtraInfo) {
                    System.out.println("Move \""+move.name+"\" now has weight "+weight);
                }
            }
        }
    }

    /**
     * Registers a move with a fixed weight (weight increment of zero).
     *
     * @param m      the move to add
     * @param weight its selection weight
     */
    public void addMcmcMove(McmcMove m, int weight) {
        mcmcMoves.add(m);
        mcmcMoveWeights.add(weight);
        mcmcMoveWeightIncrements.add(0);
    }

    /**
     * Registers a move with a weight and a weight increment.
     *
     * @param m         the move to add
     * @param weight    its initial selection weight
     * @param increment amount added to the weight by {@link #incrementWeights()}
     */
    public void addMcmcMove(McmcMove m, int weight, int increment) {
        mcmcMoves.add(m);
        mcmcMoveWeights.add(weight);
        mcmcMoveWeightIncrements.add(increment);
    }

    /** @return the live (not copied) list of moves in this module */
    public List<McmcMove> getMcmcMoves() {
        return mcmcMoves;
    }

    /** Clears the {@code moveProposed} flag of every move. */
    public void setAllMovesNotProposed() {
        for (McmcMove mcmcMove : mcmcMoves) {
            mcmcMove.moveProposed = false;
        }
    }

    /** Resets the proposal, acceptance and low-count counters of every move. */
    public void zeroAllMoveCounts() {
        for (McmcMove mcmcMove : mcmcMoves) {
            mcmcMove.proposalCount = 0;
            mcmcMove.acceptanceCount = 0;
            mcmcMove.lowCounts = 0;
        }
    }

    /**
     * Looks up a move by its exact name.
     *
     * @param name name of the move
     * @return the first move with that name
     * @throws RuntimeException if no move has that name
     */
    public McmcMove getMcmcMove(String name) {
        for (McmcMove mcmcMove : mcmcMoves) {
            if (mcmcMove.name.equals(name)) {
                return mcmcMove;
            }
        }
        throw new RuntimeException("McmcMove "+name+" not found.");
    }

    /**
     * Builds a table with one line per move: name, formatted total time,
     * proposal count, time per proposal, acceptance rate and the current
     * proposal width control variable.
     *
     * @return the formatted table (empty if there are no moves)
     */
    public String getMcmcInfo() {
        StringBuilder info = new StringBuilder();
        for (McmcMove mcmcMove : mcmcMoves) {
            info.append(String.format(Locale.US, MOVE_INFO_FORMAT,
                    mcmcMove.name,
                    Utils.convertTime(mcmcMove.getTime()),
                    mcmcMove.proposalCount,
                    // guard against division by zero for moves never proposed
                    mcmcMove.getTime()/(mcmcMove.proposalCount>0 ? mcmcMove.proposalCount : 1),
                    mcmcMove.acceptanceRate(),
                    mcmcMove.proposalWidthControlVariable));
        }
        return info.toString();
    }

    /** @return a one-line summary of the acceptance rate of every move */
    public String getSummaryInfo() {
        StringBuilder info = new StringBuilder("Acceptance rates: ");
        for (McmcMove m : mcmcMoves) {
            info.append(m.name).append(": ")
                .append(String.format(Locale.US, "%f ", m.acceptanceRate()));
        }
        return info.toString();
    }

    /**
     * Writes one comma-separated line holding the parameter string of every
     * printable move to the parameter log. Does nothing if logging is
     * disabled or no log file has been opened.
     */
    public void printParameters() {
        if (!logParametersToFile || parameterLog == null) {
            return;
        }
        StringBuilder params = new StringBuilder();
        boolean first = true;
        for (McmcMove m : mcmcMoves) {
            if (!m.printableParam) {
                continue;
            }
            if (!first) {
                params.append(", ");
            }
            params.append(m.getParameterString());
            first = false;
        }
        writeLogLine(params.toString());
    }

    /**
     * Called before the start of MCMC sampling, but after the initial tree, alignment etc. have been
     * generated. Override to initialise data structures etc.
     * <p>
     * The default implementation writes the comma-separated names of all
     * printable moves as a header line to the parameter log, if logging is
     * enabled. If no log file has been opened, a notice is printed and the
     * header is skipped.
     *
     * @param tree the starting tree
     */
    public void beforeSampling(Tree tree) {
        if (!logParametersToFile) {
            return;
        }
        if (parameterLog == null) {
            System.out.println("null log");
            return;
        }
        StringBuilder paramNames = new StringBuilder();
        boolean first = true;
        for (McmcMove m : mcmcMoves) {
            if (!m.printableParam) {
                continue;
            }
            if (!first) {
                paramNames.append(", ");
            }
            paramNames.append(m.getNameString());
            first = false;
        }
        writeLogLine(paramNames.toString());
    }

    /**
     * Appends a line to the parameter log; I/O errors are reported via a
     * stack trace. Callers must ensure the log has been opened.
     */
    private void writeLogLine(String line) {
        try {
            parameterLog.write(line + "\n");
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Called after sampling: closes the parameter log (if one was opened)
     * and asks every move to print its info.
     */
    public void afterSampling() {
        if (parameterLog != null) {
            try {
                parameterLog.close();
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
        for (McmcMove m : mcmcMoves) {
            m.printInfo();
        }
    }

    /**
     * This should return the log of the model's contribution to the likelihood, it will be added on to
     * the log-likelihood of the current point in the MCMC state space. Normally it will be called once at the
     * initialisation of the MCMC process and from then on once in each MCMC step, when proposing any change.
     * In debug mode, will be called more often (including after proposed changes) to ensure consistency.
     * @param tree current tree
     * @return log of model extension likelihood, conditional on current tree, alignment and params
     */
    public abstract double logLikeFactor(Tree tree);

    /** @return the stored log-likelihood contribution ({@link #curLogLike}) */
    public double getLogLike() {
        return curLogLike;
    }

    /**
     * Stores the log-likelihood contribution.
     *
     * @param ll new value of {@link #curLogLike}
     */
    public void setLogLike(double ll) {
        curLogLike = ll;
    }

    /**
     * This should return the log of the total prior calculated for the model parameters. It is only used
     * in parallel mode when proposing swaps between chains. By default returns 0.
     */
    public double logPrior(Tree tree) {
        return 0;
    }

    /**
     * Picks one move via {@code Utils.weightedChoose} on the current weights
     * and executes it on the given tree.
     *
     * @param tree the current tree
     * @return the {@code lastMoveAccepted} flag of the selected move
     */
    public boolean proposeParamChange(Tree tree) {
        int selectedMoveIndex = Utils.weightedChoose(mcmcMoveWeights);
        McmcMove selectedMove = mcmcMoves.get(selectedMoveIndex);
        selectedMove.move(tree);
        return selectedMove.lastMoveAccepted;
    }

    /**
     * Adjusts the proposal width control variable of every auto-tuned move
     * that has been proposed more than {@code Utils.MIN_SAMPLES_FOR_ACC_ESTIMATE}
     * times. If the acceptance rate is below the move's minimum (and the
     * variable is not below its lower limit) the variable is multiplied by
     * {@code spanMultiplier}; if it is above the maximum (and the variable is
     * not above its upper limit) it is divided by {@code spanMultiplier}. In
     * both cases the proposal and acceptance counters are reset.
     * {@code lowCounts} counts consecutive too-low adjustments and is reset
     * in every other case.
     */
    public void modifyProposalWidths() {
        for (McmcMove m : mcmcMoves) {
            if (!m.autoTune) {
                continue;
            }
            if (m.proposalCount > Utils.MIN_SAMPLES_FOR_ACC_ESTIMATE) {
                if (m.acceptanceRate() < m.minAcceptance &&
                        m.proposalWidthControlVariable >= m.minProposalWidthControlVariable) {
                    m.proposalWidthControlVariable *= m.spanMultiplier;
                    m.proposalCount = 0;
                    m.acceptanceCount = 0;
                    m.lowCounts++;
                }
                else if (m.acceptanceRate() > m.maxAcceptance &&
                        m.proposalWidthControlVariable <= m.maxProposalWidthControlVariable) {
                    m.proposalWidthControlVariable /= m.spanMultiplier;
                    m.proposalCount = 0;
                    m.acceptanceCount = 0;
                    m.lowCounts = 0;
                }
                else {
                    m.lowCounts = 0;
                }
            }
        }
    }

    /**
     * Delegates the acceptance decision to the attached {@link Mcmc}.
     *
     * @param logProposalRatio log of the proposal ratio of the move
     * @param m                the move being evaluated
     * @return whether the proposed change is accepted
     */
    public boolean isParamChangeAccepted(double logProposalRatio,McmcMove m) {
        return mcmc.isParamChangeAccepted(logProposalRatio,m);
    }

    /** Adds each move's non-zero weight increment to its weight. */
    public void incrementWeights() {
        for (int i = 0; i < mcmcMoves.size(); i++) {
            int increment = mcmcMoveWeightIncrements.get(i);
            if (increment != 0) {
                mcmcMoveWeights.set(i, mcmcMoveWeights.get(i) + increment);
                if (printExtraInfo) {
                    System.out.println("Move \""+mcmcMoves.get(i).name+"\" now has weight "+mcmcMoveWeights.get(i));
                }
            }
        }
    }

    /** Forwards the end-of-first-half-of-burn-in notification to every move. */
    public void afterFirstHalfBurnin() {
        for (McmcMove m : mcmcMoves) {
            m.afterFirstHalfBurnin();
        }
    }

    /** Forwards the end-of-burn-in notification to every move. */
    public void afterBurnin() {
        for (McmcMove m : mcmcMoves) {
            m.afterBurnin();
        }
    }
}