package beast.core;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Formatter;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import beast.core.util.Log;
import beast.util.Randomizer;
@Description("Specify operator selection and optimisation schedule")
public class OperatorSchedule extends BEASTObject {
enum OptimisationTransform {none, log, sqrt}
final public Input<OptimisationTransform> transformInput = new Input<>("transform",
"transform optimisation schedule (default none) This can be "
+ Arrays.toString(OptimisationTransform.values()) + " (default 'none')",
OptimisationTransform.none, OptimisationTransform.values());
final public Input<Boolean> autoOptimiseInput = new Input<>("autoOptimize", "whether to automatically optimise operator settings", true);
final public Input<Boolean> detailedRejectionInput = new Input<>("detailedRejection", "true if detailed rejection statistics should be included. (default=false)", false);
final public Input<Integer> autoOptimizeDelayInput = new Input<>("autoOptimizeDelay", "number of samples to skip before auto optimisation kicks in (default=10000)", 10000);
// the following inputs are for to deal with schedules nested inside other schedules
// this allows operators to be grouped, and a percentage of operator weights to be
// assigned to a group of operators.
final public Input<List<Operator>> operatorsInput = new Input<>("operator", "operator that the schedule can choose from. Any operators "
+ "added by other classes (e.g. MCMC) will be added if there are no duplicates.", new ArrayList<>());
final public Input<List<OperatorSchedule>> subschedulesInput = new Input<>("subschedule", "operator schedule representing a subset of"
+ "the weight of the operators it contains.", new ArrayList<>());
final public Input<Double> weightInput = new Input<>("weight", "weight with which this operator schedule is selected. Only used when "
+ "this operator schedule is nested inside other schedules. This weight is relative to other operators and operator schedules "
+ "of the parent schedule.", 100.0);
final public Input<Boolean> weightIsPercentageInput = new Input<>("weightIsPercentage", "indicates weight is a percentage of total weight instead of a relative weight", false);
final public Input<String> operatorPatternInput = new Input<>("operatorPattern", "Regular expression matching operator IDs of operators of parent schedule");
/**
* list of operators in the schedule *
*/
// temporary for play
public List<Operator> operators = new ArrayList<>();
/**
* sum of weight of operators *
*/
double totalWeight = 0;
/**
* cumulative weights, with unity as max value *
*/
double[] cumulativeProbs;
/**
* name of the file to store operator related info *
*/
String stateFileName;
/**
* Don't start optimisation at the start of the chain, but wait till
* autoOptimizeDelay has been reached.
*/
protected int autoOptimizeDelay = 10000;
protected int autoOptimizeDelayCount = 0;
OptimisationTransform transform = OptimisationTransform.none;
boolean autoOptimise = true;
boolean detailedRejection = false;
private boolean reweighted = false;
@Override
public void initAndValidate() {
transform = transformInput.get();
autoOptimise = autoOptimiseInput.get();
autoOptimizeDelay = autoOptimizeDelayInput.get();
detailedRejection = detailedRejectionInput.get();
operators.addAll(operatorsInput.get());
for (Operator o : operators) {
o.setOperatorSchedule(this);
}
// sanity check: make sure weight percentages add to less than 100%
double sumPercentage = 0;
for (OperatorSchedule o : subschedulesInput.get()) {
if (o.weightIsPercentageInput.get()) {
sumPercentage += o.weightInput.get();
}
}
if (sumPercentage > 100) {
throw new IllegalArgumentException("Sum of percentages of subschedules should not exceed 100%. Reduce the weight of subschedules.");
}
if (Math.abs(sumPercentage - 100) < 1e-6 && operators.size() > 0) {
throw new IllegalArgumentException("Sum of percentages of subschedules add to 100%, so operators in main schedule will be ignored. Reduce the weight of subschedules.");
}
// sanity check: warn if operators appear in multiple schedules
Set<Operator> allOperators = new LinkedHashSet<>();
allOperators.addAll(operators);
for (OperatorSchedule os : subschedulesInput.get()) {
for (Operator o : os.operators) {
if (allOperators.contains(o)) {
Log.warning("WARNING: Operator " + o.getID() + " is contained in multiple operator schedules.\n"
+ "Operator weighting may not work as expected.");
}
allOperators.add(o);
}
}
}
public void setStateFileName(final String name) {
this.stateFileName = name;
}
/**
* add operator to the schedule *
* @param p
*/
public void addOperator(final Operator p) {
// check for duplicates
for (Operator o : operators) {
if (o == p) {
// operator was already added earlier
return;
}
}
operators.add(p);
p.setOperatorSchedule(this);
reweighted = false;
totalWeight += p.getWeight();
}
/** used to add operators to subschedules matching a pattern **/
protected void addOperators(Collection<Operator> ops) {
if (operatorPatternInput.get() == null || operatorPatternInput.get().trim().equals("")) {
return;
}
String operatorPattern = operatorPatternInput.get();
for (Operator o : ops) {
if (o.getID() != null && o.getID().matches(operatorPattern)) {
for (Operator o2 : operators) {
if (o2 == o) {
// operator was already added earlier
return;
}
}
operators.add(o);
}
}
reweighted = false;
}
/**
* randomly select an operator with probability proportional to the weight
* of the operator
* @return
*/
public Operator selectOperator() {
if (!reweighted) {
reweightOperators();
reweighted = true;
}
final int operatorIndex = Randomizer.randomChoice(cumulativeProbs);
return operators.get(operatorIndex);
}
private static final String TUNING = "Tuning";
private static final String NUM_ACCEPT = "#accept";
private static final String NUM_REJECT = "#reject";
private static final String PR_M = "Pr(m)";
private static final String PR_ACCEPT = "Pr(acc|m)";
/**
* report operator statistics *
* @param out
*/
public void showOperatorRates(final PrintStream out) {
Formatter formatter = new Formatter(out);
int longestName = 0;
for (final Operator operator : operators) {
if (operator.getName().length() > longestName) {
longestName = operator.getName().length();
}
}
formatter.format("%-" + longestName + "s", "Operator");
int colWidth = 10;
String headerFormat = " %" + colWidth + "s";
formatter.format(headerFormat, TUNING);
formatter.format(headerFormat, NUM_ACCEPT);
formatter.format(headerFormat, NUM_REJECT);
if (detailedRejection) {
formatter.format(headerFormat, "rej.inv");
formatter.format(headerFormat, "rej.op");
}
formatter.format(headerFormat, PR_M);
formatter.format(headerFormat, PR_ACCEPT);
out.println();
for (final Operator operator : operators) {
out.println(prettyPrintOperator(operator, longestName, colWidth, 4, totalWeight, detailedRejection));
}
out.println();
formatter.format(headerFormat,TUNING);
out.println(": The value of the operator's tuning parameter, or '-' if the operator can't be optimized.");
formatter.format(headerFormat, NUM_ACCEPT);
out.println(": The total number of times a proposal by this operator has been accepted.");
formatter.format(headerFormat, NUM_REJECT);
out.println(": The total number of times a proposal by this operator has been rejected.");
formatter.format(headerFormat, PR_M);
out.println(": The probability this operator is chosen in a step of the MCMC (i.e. the normalized weight).");
formatter.format(headerFormat, PR_ACCEPT);
out.println(": The acceptance probability (" + NUM_ACCEPT + " as a fraction of the total proposals for this operator).");
out.println();
formatter.close();
}
protected static String prettyPrintOperator(
Operator op,
int nameColWidth,
int colWidth,
int dp,
double totalWeight,
boolean detailedRejection) {
double tuning = op.getCoercableParameterValue();
double accRate = (double) op.m_nNrAccepted / (double) (op.m_nNrAccepted + op.m_nNrRejected);
StringBuilder sb = new StringBuilder();
Formatter formatter = new Formatter(sb);
String intFormat = " %" + colWidth + "d";
String doubleFormat = " %" + colWidth + "." + dp + "f";
formatter.format("%-" + nameColWidth + "s", op.getName());
if (!Double.isNaN(tuning)) {
formatter.format(doubleFormat, tuning);
} else {
formatter.format(" %" + colWidth + "s", "-");
}
formatter.format(intFormat, op.m_nNrAccepted);
formatter.format(intFormat, op.m_nNrRejected);
if (detailedRejection) {
formatter.format(doubleFormat, (double) op.m_nNrRejectedInvalid / (double) op.m_nNrRejected);
formatter.format(doubleFormat, (double) op.m_nNrRejectedOperator / (double) op.m_nNrRejected);
}
if (totalWeight > 0.0) {
formatter.format(doubleFormat, op.getWeight() / totalWeight);
}
formatter.format(doubleFormat, accRate);
sb.append(" " + op.getPerformanceSuggestion());
formatter.close();
return sb.toString();
}
/**
* store operator optimisation specific information to file *
* @throws IOException
*/
public void storeToFile() throws IOException {
// appends state of operator set to state file
File file = new File(stateFileName);
PrintWriter out = new PrintWriter(new FileWriter(file, true));
out.println("<!--");
out.println("{\"operators\":[");
int k = 0;
for (Operator operator: operators) {
operator.storeToFile(out);
if (k++ < operators.size() - 1) {
out.println(",");
}
}
out.println("\n]}");
out.println("-->");
out.flush();
out.close();
}
/**
* restore operator optimisation specific information from file *
* @throws IOException
*/
public void restoreFromFile() throws IOException {
// reads state of operator set from state file
String xml = "";
final BufferedReader fin = new BufferedReader(new FileReader(stateFileName));
while (fin.ready()) {
xml += fin.readLine() + "\n";
}
fin.close();
int start = xml.indexOf("</itsabeastystatewerein>") + 25 + 5;
if (start >= xml.length() - 4) {
return;
}
xml = xml.substring(xml.indexOf("</itsabeastystatewerein>") + 25 + 5, xml.length() - 4);
try {
JSONObject o = new JSONObject(xml);
JSONArray operatorlist = o.getJSONArray("operators");
autoOptimizeDelayCount = 0;
for (int i = 0; i < operatorlist.length(); i++) {
JSONObject item = operatorlist.getJSONObject(i);
String id = item.getString("id");
boolean found = false;
if (!id.equals("null")) {
for (Operator operator: operators) {
if (id.equals(operator.getID())) {
operator.restoreFromFile(item);
autoOptimizeDelayCount += operator.m_nNrAccepted + operator.m_nNrRejected;
found = true;
break;
}
}
}
if (!found) {
Log.warning.println("Operator (" + id + ") found in state file that is not in operator list any more");
}
}
for (Operator operator: operators) {
if (operator.getID() == null) {
Log.warning.println("Operator (" + operator.getClass() + ") found in BEAST file that could not be restored because it has not ID");
}
}
} catch (JSONException e) {
// it is not a JSON file -- probably a version 2.0.X state file
String[] strs = xml.split("\n");
autoOptimizeDelayCount = 0;
for (int i = 0; i < operators.size() && i + 2 < strs.length; i++) {
String[] strs2 = strs[i + 1].split(" ");
Operator operator = operators.get(i);
if ((operator.getID() == null && strs2[0].equals("null")) || operator.getID().equals(strs2[0])) {
cumulativeProbs[i] = Double.parseDouble(strs2[1]);
if (!strs2[2].equals("NaN")) {
operator.setCoercableParameterValue(Double.parseDouble(strs2[2]));
}
operator.m_nNrAccepted = Integer.parseInt(strs2[3]);
operator.m_nNrRejected = Integer.parseInt(strs2[4]);
autoOptimizeDelayCount += operator.m_nNrAccepted + operator.m_nNrRejected;
operator.m_nNrAcceptedForCorrection = Integer.parseInt(strs2[5]);
operator.m_nNrRejectedForCorrection = Integer.parseInt(strs2[6]);
} else {
throw new RuntimeException("Cannot resume: operator order or set changed from previous run");
}
}
}
showOperatorRates(System.err);
}
/**
* Calculate change of coerceable parameter for operators that allow
* optimisation
*
* @param operator
* @param logAlpha difference in posterior between previous state & proposed
* state + hasting ratio
* @return change of value of a parameter for MCMC chain optimisation
*/
public double calcDelta(final Operator operator, final double logAlpha) {
// do no optimisation for the first N optimisable operations
if (autoOptimizeDelayCount < autoOptimizeDelay || !autoOptimise) {
autoOptimizeDelayCount++;
return 0;
}
final double target = operator.getTargetAcceptanceProbability();
double count = (operator.m_nNrRejectedForCorrection + operator.m_nNrAcceptedForCorrection + 1.0);
switch (transform) {
case log:
count = Math.log(count + 1.0);
break;
case sqrt:
count = Math.sqrt(count);
break;
case none:
break;
default:
break;
}
final double deltaP = ((1.0 / count) * (Math.exp(Math.min(logAlpha, 0)) - target));
if (deltaP > -Double.MAX_VALUE && deltaP < Double.MAX_VALUE) {
return deltaP;
}
return 0;
}
/**
* collect all operators (both local and from sub schedules) and calculate weight for each of them
* **/
private void reweightOperators() {
Set<Operator> allOperators = new LinkedHashSet<>();
Set<Operator> subOperators = new LinkedHashSet<>();
allOperators.addAll(operators);
for (OperatorSchedule os : subschedulesInput.get()) {
allOperators.addAll(os.operators);
}
for (OperatorSchedule os : subschedulesInput.get()) {
os.addOperators(allOperators);
subOperators.addAll(os.operators);
}
allOperators.addAll(subOperators);
Set<Operator> localOperators = new LinkedHashSet<>();
localOperators.addAll(allOperators);
localOperators.removeAll(subOperators);
// set up operators list and raw weights
operators.clear();
for (Operator o : localOperators) {
operators.add(o);
}
for (OperatorSchedule os : subschedulesInput.get()) {
for (Operator o : os.operators) {
operators.add(o);
}
}
// operatorCount can double count operators that appear in multiple operator schedules
int operatorCount = operators.size();
double [] weights = new double[operatorCount];
int i = 0;
for (Operator o : localOperators) {
weights[i++] = o.getWeight();
}
for (OperatorSchedule os : subschedulesInput.get()) {
for (Operator o : os.operators) {
weights[i++] = o.getWeight();
}
}
// calculate weights per OperatorSchedule
double localWeight = 0;
for (Operator o : localOperators) {
localWeight += o.getWeight();
}
double totalSubSchedulePercentage = 0;
double totalSubScheduleWeight = 0;
for (OperatorSchedule os : subschedulesInput.get()) {
if (os.weightIsPercentageInput.get()) {
totalSubSchedulePercentage += os.weightInput.get();
} else {
totalSubScheduleWeight += os.weightInput.get();
}
}
double totalWeight = totalSubSchedulePercentage >= 100 ? 100 :
(localWeight + totalSubScheduleWeight) * 100 / (100-totalSubSchedulePercentage);
// reweight local operators
double localFactor = (1/totalWeight);
i = 0;
for (Operator o : localOperators) {
weights[i++] *= localFactor;
}
// reweight operators of sub OperatorSchedules
for (OperatorSchedule os : subschedulesInput.get()) {
localWeight = 0;
for (Operator o : os.operators) {
localWeight += o.getWeight();
}
double factor;
if (!os.weightIsPercentageInput.get()) {
factor = (os.weightInput.get() / localWeight) * (1/totalWeight);
} else {
factor = (os.weightInput.get() / 100) * 1.0/localWeight;
}
for (Operator o : os.operators) {
weights[i++] *= factor;
}
}
// calc cumulative probabilities
cumulativeProbs = new double[weights.length];
cumulativeProbs[0] = weights[0];
for (i = 1; i < operators.size(); i++) {
cumulativeProbs[i] = weights[i] + cumulativeProbs[i - 1];
}
// log results
//Log.debug("operator weight cumulativeProbs");
//for (i = 0; i < operatorCount; i++) {
// Log.debug(operators.get(i).getID() + " " + weights[i] + " " + cumulativeProbs[i]);
//}
}
/** handy for unit tests **/
public double [] getCummulativeProbs() {
return cumulativeProbs.clone();
}
} // class OperatorSchedule