package beast.evolution.operators;

import java.text.DecimalFormat;

import beast.core.Description;
import beast.core.Distribution;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.Operator;
import beast.core.parameter.RealParameter;
import beast.core.util.Evaluator;
import beast.core.util.Log;
import beast.util.Randomizer;

/**
 * Univariate slice sampler acting on the first element of a {@link RealParameter}.
 *
 * <p>Each proposal draws a slice level below the current density, finds an interval around the
 * current value by "stepping out" in steps of {@code windowSize} (clamped to the parameter's
 * bounds), then samples uniformly from that interval, shrinking it towards the current value on
 * every rejection. The resulting state is always accepted (the proposal returns +infinity as the
 * log Hastings ratio).
 *
 * <p>During the first {@code n_learning_iterations} proposals the window size is adapted from the
 * average jump distance.
 */
@Description("A slice sampling operator that resamples the first dimension of the real parameter, "
        + "finding the slice by stepping out with steps of size windowSize and sampling from it by shrinkage.")
public class SliceOperator extends Operator {
    final public Input<RealParameter> parameterInput = new Input<>("parameter", "the parameter to operate a random walk on.", Validate.REQUIRED);
    final public Input<Double> windowSizeInput = new Input<>("windowSize", "the size of the step for finding the slice boundaries", Input.Validate.REQUIRED);
    final public Input<Distribution> sliceDensityInput = new Input<>("sliceDensity", "The density to sample from using slice sampling.", Input.Validate.REQUIRED);

    /** Sum of |x_new - x0| over the learning-phase proposals made so far. */
    Double totalDelta;
    /** Number of learning-phase proposals accumulated in {@link #totalDelta}. */
    int totalNumber;
    /** Remaining number of proposals during which the window size is adapted. */
    int n_learning_iterations;
    /** Exponential moving average of the predicted window size (starts at 0). */
    double W;
    /** Current step size used when stepping out to find the slice. */
    double windowSize = 1;
    /** Density the slice is taken from; also exposed via {@link #getEvaluatorDistribution()}. */
    Distribution sliceDensity;

    /** Resets the adaptation statistics and reads the configured window size and density. */
    @Override
    public void initAndValidate() {
        totalDelta = 0.0;
        totalNumber = 0;
        n_learning_iterations = 100;
        W = 0.0;
        windowSize = windowSizeInput.get();
        sliceDensity = sliceDensityInput.get();
    }

    /**
     * Returns true when {@code x} lies within the bounds of {@code X}, bounds inclusive. This is
     * consistent with {@link #below_lower_bound} and {@link #above_upper_bound}, which accept a
     * value equal to a bound.
     */
    boolean in_range(RealParameter X, double x) {
        return (X.getLower() <= x && x <= X.getUpper());
    }

    /** Returns true when {@code x} is strictly below the lower bound of {@code X}. */
    boolean below_lower_bound(RealParameter X, double x) {
        return (x < X.getLower());
    }

    /** Returns true when {@code x} is strictly above the upper bound of {@code X}. */
    boolean above_upper_bound(RealParameter X, double x) {
        return (x > X.getUpper());
    }

    /** Returns the slice density, which the evaluator passed to {@link #proposal(Evaluator)} scores. */
    @Override
    public Distribution getEvaluatorDistribution() {
        return sliceDensity;
    }

    /** Evaluates {@code E} at the current state. */
    Double evaluate(Evaluator E) {
        return E.evaluate();
    }

    /**
     * Sets the first element of {@code X} to {@code x} and evaluates {@code E} there.
     * Note: this leaves {@code X} holding {@code x}.
     */
    Double evaluate(Evaluator E, RealParameter X, double x) {
        X.setValue(0, x);
        return evaluate(E);
    }

    /**
     * Finds an interval [L, R] around the current value of X's first element by stepping out.
     *
     * <p>An interval of width {@code w} is placed at a uniformly random offset around x0. Each end
     * is then pushed outwards in steps of {@code w} while it is within bounds and the density there
     * is still above {@code logy}. When {@code m > 1} the total number of steps is limited to
     * {@code m - 1}, split randomly between the two ends; otherwise stepping is unlimited. Ends
     * that pass a parameter bound are clamped to that bound.
     *
     * @param E    evaluator for the log density
     * @param X    parameter whose first element is sampled (its value is overwritten while probing)
     * @param logy log slice level
     * @param w    step size
     * @param m    step budget; values {@code <= 1} mean unlimited stepping
     * @return {@code {L, R}}
     */
    Double[] find_slice_boundaries_stepping_out(Evaluator E, RealParameter X, double logy, double w, int m) {
        double x0 = X.getValue(0);

        assert in_range(X, x0);

        double u = Randomizer.nextDouble() * w;
        Double L = x0 - u;
        Double R = x0 + (w - u);

        // Expand the interval until its ends are outside the slice, or until
        // the limit on steps is reached.
        if (m > 1) {
            // Split the m-1 step budget randomly between the left (J) and right (K) ends.
            int J = (int) Math.floor(Randomizer.nextDouble() * m);
            int K = (m - 1) - J;

            while (J > 0 && (!below_lower_bound(X, L)) && evaluate(E, X, L) > logy) {
                L -= w;
                J--;
            }

            while (K > 0 && (!above_upper_bound(X, R)) && evaluate(E, X, R) > logy) {
                R += w;
                K--;
            }
        } else {
            while ((!below_lower_bound(X, L)) && evaluate(E, X, L) > logy)
                L -= w;

            while ((!above_upper_bound(X, R)) && evaluate(E, X, R) > logy)
                R += w;
        }

        // Shrink interval to lower and upper bounds.
        if (below_lower_bound(X, L)) L = X.getLower();
        if (above_upper_bound(X, R)) R = X.getUpper();

        assert L < R;

        Double[] range = {L, R};

        return range;
    }

    /**
     * Samples a point from [L, R] whose log density is at least {@code logy}, shrinking the
     * interval on each rejection.
     *
     * <p>A rejected point replaces R if it lies to the right of {@code x0} and replaces L otherwise.
     * {@code x0} is only used to decide which end to shrink. After 200 rejections a warning is
     * logged and {@code x0} is returned.
     *
     * @param E    evaluator for the log density
     * @param x0   current value (must lie in [L, R] with density above {@code logy})
     * @param X    parameter whose first element is sampled (its value is overwritten while probing)
     * @param L    left end of the interval
     * @param R    right end of the interval
     * @param logy log slice level
     * @return the accepted point, or {@code x0} if none was found
     */
    double search_interval(Evaluator E, double x0, RealParameter X, Double L, Double R, double logy) {
        // assert evaluate(E,x0) > evaluate(E,L) && evaluate(E,x0) > evaluate(E,R);
        assert evaluate(E, X, x0) >= logy;
        assert L < R;
        assert L <= x0 && x0 <= R;

        double L0 = L, R0 = R;

        double gx0 = evaluate(E, X, x0);
        assert logy < gx0;

        double x1 = x0;
        for (int i = 0; i < 200; i++) {
            x1 = L + Randomizer.nextDouble() * (R - L);
            double gx1 = evaluate(E, X, x1);

            if (gx1 >= logy) return x1;

            // Shrink towards x0: the rejected point becomes the new end on its side.
            if (x1 > x0)
                R = x1;
            else
                L = x1;
        }
        Log.warning.println("Warning!  Is size of the interval really ZERO?");
        Log.warning.println("   L0 = " + L0 + "   x0 = " + x0 + "   R0 = " + R0 + "   gx0 = " + gx0);
        // Report the shrunk interval; X still holds x1, so evaluate(E) gives the density at x1.
        Log.warning.println("   L  = " + L + "   x1 = " + x1 + "   R  = " + R + "   gx1 = " + evaluate(E));

        return x0;
    }

    /**
     * Not used by this operator. Sampling happens in {@link #proposal(Evaluator)}; this override
     * only satisfies the abstract {@link Operator} contract.
     */
    @Override
    public double proposal() {
        return 0;
    }

    /**
     * override this for proposals,
     * returns log of hastingRatio, or Double.NEGATIVE_INFINITY if proposal should not be accepted
     *
     * <p>Performs one slice-sampling update of the first element of the parameter and returns
     * {@code Double.POSITIVE_INFINITY} so that the new state is always accepted.
     */
    @Override
    public double proposal(Evaluator E) {
        // Maximum number of stepping-out steps.
        int m = 100;

        RealParameter X = parameterInput.get();

        // Find the density at the current point
        Double gx0 = evaluate(E);

        // Get the 1st element
        Double x0 = X.getValue(0);

        // Determine the slice level, in log terms: log(U * g(x0)) with U ~ Uniform(0,1)
        // equals log g(x0) - Exp(1).
        double logy = gx0 - Randomizer.nextExponential(1);

        // Find the initial interval to sample from.
        Double[] range = find_slice_boundaries_stepping_out(E, X, logy, windowSize, m);
        Double L = range[0];
        Double R = range[1];

        // Sample from the interval, shrinking it on each rejection
        double x_new = search_interval(E, x0, X, L, R, logy);
        X.setValue(x_new);

        if (n_learning_iterations > 0) {
            n_learning_iterations--;
            totalDelta += Math.abs(x_new - x0);
            totalNumber++;

            // Heuristic target: four times the mean absolute jump so far, smoothed by an
            // exponential moving average once a few samples are available.
            double W_predicted = totalDelta / totalNumber * 4.0;
            if (totalNumber > 3) {
                W = 0.95 * W + 0.05 * W_predicted;
                windowSize = W;
            }
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Returns the window size, which is the value tuned by operator optimisation. */
    @Override
    public double getCoercableParameterValue() {
        return windowSize;
    }

    /** Sets the window size. */
    @Override
    public void setCoercableParameterValue(double value) {
        windowSize = value;
    }

    /**
     * called after every invocation of this operator to see whether
     * a parameter can be optimised for better acceptance hence faster
     * mixing
     *
     * <p>Adjusts the window size on the log scale by {@code calcDelta(logAlpha)}.
     *
     * @param logAlpha difference in posterior between previous state & proposed state + hasting ratio
     */
    @Override
    public void optimize(double logAlpha) {
        // must be overridden by operator implementation to have an effect
        double delta = calcDelta(logAlpha);
        delta += Math.log(windowSize);
        windowSize = Math.exp(delta);
    }

    /**
     * Suggests a window size of four times the mean absolute jump observed during learning, if it
     * differs from the current window size by more than 20%. Returns "" otherwise, including when
     * no learning proposals have been made yet.
     */
    @Override
    public final String getPerformanceSuggestion() {
        if (totalNumber == 0) {
            // No statistics yet; avoid the 0/0 division (which previously yielded NaN -> "").
            return "";
        }
        // new scale factor
        double newWindowSize = totalDelta / totalNumber * 4;
        if (newWindowSize / windowSize < 0.8 || newWindowSize / windowSize > 1.2) {
            DecimalFormat formatter = new DecimalFormat("#.###");
            return "Try setting window size to about " + formatter.format(newWindowSize);
        } else return "";
    }
} // class SliceOperator