package beast.evolution.tree.coalescent; import java.io.PrintStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import beast.core.Description; import beast.core.Input; import beast.core.Loggable; import beast.core.parameter.BooleanParameter; import beast.core.parameter.RealParameter; /** * @author joseph * Date: 26/08/2010 */ @Description("An effective population size function based on coalecent times from a set of trees.") public class CompoundPopulationFunction extends PopulationFunction.Abstract implements Loggable { final public Input<RealParameter> popSizeParameterInput = new Input<>("populationSizes", "population value at each point.", Input.Validate.REQUIRED); final public Input<BooleanParameter> indicatorsParameterInput = new Input<>("populationIndicators", "Include/exclude population value from the population function.", Input.Validate.REQUIRED); final public Input<List<TreeIntervals>> treesInput = new Input<>("itree", "Coalecent intervals of this tree are " + "used in the compound population function.", new ArrayList<>(), Input.Validate.REQUIRED); final public Input<String> demographicTypeInput = new Input<>("type", "Flavour of demographic: either linear or stepwise for " + " piecewise-linear or piecewise-constant.", "linear"); final public Input<Boolean> useMiddleInput = new Input<>("useIntervalsMiddle", "When true, the demographic X axis points are " + "in the middle of the coalescent intervals. By default they are at the beginning.", false); private RealParameter popSizeParameter; private BooleanParameter indicatorsParameter; private List<TreeIntervals> trees; private Type type; private boolean useMid; public enum Type { LINEAR("linear"), //EXPONENTIAL("exponential"), STEPWISE("stepwise"); Type(String name) { this.name = name; } @Override public String toString() { return name; } String name; } private void getParams() { popSizeParameter = popSizeParameterInput.get(); indicatorsParameter = indicatorsParameterInput.get(); assert popSizeParameter != null && popSizeParameter.getArrayValue(0) > 0 && indicatorsParameter != null; } // why do we need this additional level on top of initAndValidate - does not seem to do anything? @Override public void prepare() { getParams(); // is that safe??? trees = treesInput.get(); useMid = useMiddleInput.get(); // used to work without upper case ??? type = Type.valueOf(demographicTypeInput.get().toUpperCase()); // errors? // set lengths int events = 0; for (TreeIntervals ti : trees) { // number of coalescent events events += ti.treeInput.get().getLeafNodeCount() - 1; } // all trees share time 0, need fixing for serial data events += type == Type.STEPWISE ? 0 : 1; try { if (popSizeParameter.getDimension() != events) { final RealParameter p = new RealParameter(); p.initByName("value", popSizeParameter.getValue() + "", "upper", popSizeParameter.getUpper(), "lower", popSizeParameter.getLower(), "dimension", events); p.setID(popSizeParameter.getID()); popSizeParameter.assignFromWithoutID(p); } if (indicatorsParameter.getDimension() != events - 1) { final BooleanParameter p = new BooleanParameter(); p.initByName("value", "" + indicatorsParameter.getValue(), "dimension", events - 1); p.setID(indicatorsParameter.getID()); indicatorsParameter.assignFrom(p); } } catch (Exception e) { // what to do? e.printStackTrace(); } initInternals(); for (int nt = 0; nt < trees.size(); ++nt) { setTreeTimes(nt); } mergeTreeTimes(); setDemographicArrays(); shadow = new Shadow(); } @Override public List<String> getParameterIds() { List<String> paramIDs = new ArrayList<>(); paramIDs.add(popSizeParameter.getID()); paramIDs.add(indicatorsParameter.getID()); for (TreeIntervals t : trees) { // I think this may be wrong, and we need the trees themselves paramIDs.add(t.getID()); } return paramIDs; } @Override public double getPopSize(double t) { double p; switch (type) { case STEPWISE: { final int j = getIntervalIndexStep(t); p = values[j]; break; } case LINEAR: { p = linPop(t); break; } default: throw new IllegalArgumentException(""); } return p; } @Override public double getIntensity(double t) { return getIntegral(0, t); } @Override public double getInverseIntensity(double x) { throw new UnsupportedOperationException(); } // values participating in the demographic private double[] values; // times participating in the demographic private double[] times; // convenience: intervals[n] = times[n+1] - times[n] private double[] intervals; // sorted times from each tree private double[][] ttimes; // sorted times from all trees (merge of ttimes above) private double[] alltimes; // no allocations, minimal copying class Shadow { double[] values; double[] times; double[] intervals; double[][] ttimes; double[] alltimes; boolean c_demo, c_alltimes; boolean[] c_ttimes; Shadow() { values = CompoundPopulationFunction.this.values.clone(); times = CompoundPopulationFunction.this.times.clone(); intervals = CompoundPopulationFunction.this.intervals.clone(); alltimes = CompoundPopulationFunction.this.alltimes.clone(); ttimes = CompoundPopulationFunction.this.ttimes.clone(); for (int nt = 0; nt < ttimes.length; ++nt) { ttimes[nt] = CompoundPopulationFunction.this.ttimes[nt].clone(); } c_ttimes = new boolean[ttimes.length]; reset(); } void reset() { c_alltimes = false; c_demo = false; Arrays.fill(c_ttimes, false); } void protect_demo() { values = CompoundPopulationFunction.this.values; times = CompoundPopulationFunction.this.times; intervals = CompoundPopulationFunction.this.intervals; CompoundPopulationFunction.this.values = null; CompoundPopulationFunction.this.times = null; CompoundPopulationFunction.this.intervals = null; // { // final double[] src = CompoundPopulationFunction.this.values; // final double[] target = values; // if( src.length == target.length ) { // System.arraycopy(src, 0, target, 0, src.length); // } else { // values = src.clone(); // } // } // { // final double[] src = CompoundPopulationFunction.this.times; // final double[] target = times; // if( src.length == target.length ) { // System.arraycopy(src, 0, target, 0, src.length); // } else { // times = src.clone(); // } // } // { // final double[] src = CompoundPopulationFunction.this.intervals; // final double[] target = intervals; // if( src.length == target.length ) { // System.arraycopy(src, 0, target, 0, src.length); // } else { // intervals = src.clone(); // } // } c_demo = true; } void protect_alltimes() { final double[] src = CompoundPopulationFunction.this.alltimes; System.arraycopy(src, 0, alltimes, 0, src.length); c_alltimes = true; } void protect_ttimes(int nt) { final double[] src = CompoundPopulationFunction.this.ttimes[nt]; System.arraycopy(src, 0, ttimes[nt], 0, src.length); c_ttimes[nt] = true; } void accept() { values = times = intervals = null; } void reject() { if (c_alltimes) { final double[] v = CompoundPopulationFunction.this.alltimes; CompoundPopulationFunction.this.alltimes = alltimes; alltimes = v; } if (c_demo) { CompoundPopulationFunction.this.values = values; CompoundPopulationFunction.this.times = times; CompoundPopulationFunction.this.intervals = intervals; values = times = intervals = null; // double[] v = CompoundPopulationFunction.this.values; // CompoundPopulationFunction.this.values = values; // values = v; // // v = CompoundPopulationFunction.this.times; // CompoundPopulationFunction.this.times = times; // times = v; // // v = CompoundPopulationFunction.this.intervals; // CompoundPopulationFunction.this.intervals = intervals; // intervals = v; } for (int nt = 0; nt < c_ttimes.length; ++nt) { if (c_ttimes[nt]) { double[] v = CompoundPopulationFunction.this.ttimes[nt]; CompoundPopulationFunction.this.ttimes[nt] = ttimes[nt]; ttimes[nt] = v; } } } } private Shadow shadow; private void initInternals() { ttimes = new double[trees.size()][]; int tot = 0; for (int k = 0; k < ttimes.length; ++k) { ttimes[k] = new double[trees.get(k).treeInput.get().getLeafNodeCount() - 1]; tot += ttimes[k].length; } alltimes = new double[tot]; } private int getIntervalIndexStep(final double t) { int j = 0; // ugly hack, // when doubles are added in a different order and compared later, they can be a tiny bit off. With a // stepwise model this creates a "one off" situation here, which is unpleasant. // use float comparison here to avoid it final float tf = (float) t; while (tf > (float) times[j + 1]) ++j; return j; } private int getIntervalIndexLin(final double t) { int j = 0; while (t > times[j + 1]) ++j; return j; } private double linPop(double t) { final int j = getIntervalIndexLin(t); if (j == values.length - 1) { return values[j]; } final double a = (t - times[j]) / (intervals[j]); return a * values[j + 1] + (1 - a) * values[j]; } private double intensityLinInterval(double start, double end, int index) { final double dx = end - start; if (dx == 0) { return 0; } final double popStart = values[index]; final double popDiff = (index < values.length - 1) ? values[index + 1] - popStart : 0.0; if (popDiff == 0.0) { return dx / popStart; } final double time0 = times[index]; final double interval = intervals[index]; assert (float) start <= (float) (time0 + interval) && start >= time0 && (float) end <= (float) (time0 + interval) && end >= time0; //better numerical stability but not perfect final double p1minusp0 = ((end - start) / interval) * popDiff; final double v = interval * (popStart / popDiff); final double p1overp0 = (v + (end - time0)) / (v + (start - time0)); if (p1minusp0 == 0.0 || p1overp0 <= 0) { // either dx == 0 or is very small (numerical inaccuracy) final double pop0 = popStart + ((start - time0) / interval) * popDiff; return dx / pop0; } return dx * Math.log(p1overp0) / p1minusp0; // return dx * Math.log(pop1/pop0) / (pop1 - pop0);*/ } private double intensityLinInterval(int index) { final double interval = intervals[index]; final double pop0 = values[index]; final double pop1 = values[index + 1]; if (pop0 == pop1) { return interval / pop0; } return interval * Math.log(pop1 / pop0) / (pop1 - pop0); } @Override public double getIntegral(double start, double finish) { double intensity = 0.0; switch (type) { case STEPWISE: { final int first = getIntervalIndexStep(start); final int last = getIntervalIndexStep(finish); final double popStart = values[first]; if (first == last) { intensity = (finish - start) / popStart; } else { intensity = (times[first + 1] - start) / popStart; for (int k = first + 1; k < last; ++k) { intensity += intervals[k] / values[k]; } intensity += (finish - times[last]) / values[last]; } break; } case LINEAR: { final int first = getIntervalIndexLin(start); final int last = getIntervalIndexLin(finish); if (first == last) { intensity += intensityLinInterval(start, finish, first); } else { // from first to end of interval intensity += intensityLinInterval(start, times[first + 1], first); // intervals until (not including) last for (int k = first + 1; k < last; ++k) { intensity += intensityLinInterval(k); } // last interval intensity += intensityLinInterval(times[last], finish, last); } break; } } return intensity; } /** * Get times of the (presumably changed) nt'th tree into the local array. * * @param nt */ private void setTreeTimes(int nt) { TreeIntervals nti = trees.get(nt); nti.setMultifurcationLimit(0); // code probably incorrect for serial samples final int lineages = nti.getIntervalCount(); assert lineages >= ttimes[nt].length : lineages + " " + ttimes[nt].length; int count = 0; for (int k = 0; k < ttimes[nt].length; ++k) { double timeToCoal = nti.getInterval(count); while (nti.getIntervalType(count) != IntervalType.COALESCENT) { ++count; timeToCoal += nti.getInterval(count); } int linAtStart = nti.getLineageCount(count); ++count; assert !(count == lineages && linAtStart != 2); int linAtEnd = (count == lineages) ? 1 : nti.getLineageCount(count); while (linAtStart <= linAtEnd) { ++count; timeToCoal += nti.getInterval(count); linAtStart = linAtEnd; ++count; linAtEnd = nti.getLineageCount(count); } ttimes[nt][k] = timeToCoal + (k == 0 ? 0 : ttimes[nt][k - 1]); } } /** * Merge sorted times in each ttimes[] array into one sorted array (alltimes) * */ private void mergeTreeTimes() { // now we want to merge times together int[] inds = new int[ttimes.length]; for (int k = 0; k < alltimes.length; ++k) { int j = 0; while (inds[j] == ttimes[j].length) { ++j; } for (int l = j + 1; l < inds.length; ++l) { if (inds[l] < ttimes[l].length) { if (ttimes[l][inds[l]] < ttimes[j][inds[j]]) { j = l; } } } alltimes[k] = ttimes[j][inds[j]]; inds[j]++; } } /** * Setup the internal times,values,intervals from the rest * */ private void setDemographicArrays() { // assumes lowest node has time 0. this is probably problematic when we come // to deal with multiple trees int tot = 1; final int nd = indicatorsParameter.getDimension(); assert nd == alltimes.length + (type == Type.STEPWISE ? -1 : 0) : " nd=" + nd + " alltimes.length=" + alltimes.length + " type=" + type; for (int k = 0; k < nd; ++k) { if (indicatorsParameter.getValue(k)) { ++tot; } } times = new double[tot + 1]; values = new double[tot]; intervals = new double[tot - 1]; times[0] = 0.0; times[tot] = Double.POSITIVE_INFINITY; values[0] = popSizeParameter.getValue(0); int n = 0; for (int k = 0; k < nd && n + 1 < tot; ++k) { if (indicatorsParameter.getValue(k)) { times[n + 1] = useMid ? ((alltimes[k] + (k > 0 ? alltimes[k - 1] : 0)) / 2) : alltimes[k]; values[n + 1] = popSizeParameter.getValue(k + 1); intervals[n] = times[n + 1] - times[n]; ++n; } } } @Override protected void store() { super.store(); } @Override protected boolean requiresRecalculation() { boolean anyTreesChanged = false; for (int nt = 0; nt < trees.size(); ++nt) { TreeIntervals ti = trees.get(nt); if (ti.isDirtyCalculation()) { shadow.protect_ttimes(nt); setTreeTimes(nt); anyTreesChanged = true; } } // we access parameters in any case getParams(); if (anyTreesChanged) { shadow.protect_alltimes(); shadow.protect_demo(); mergeTreeTimes(); setDemographicArrays(); } else { if (popSizeParameter.somethingIsDirty() && !indicatorsParameter.somethingIsDirty()) { } shadow.protect_demo(); setDemographicArrays(); } return true; } @Override protected void restore() { shadow.reject(); shadow.reset(); super.restore(); } @Override protected void accept() { shadow.accept(); shadow.reset(); super.accept(); } @Override public void init(PrintStream out) { // interval sizes out.print("popsSize0\t"); for (int i = 0; i < alltimes.length; i++) { out.print(getID() + ".times." + i + "\t"); } } @Override public void log(int sample, PrintStream out) { // interval sizes out.print("0:" + popSizeParameter.getArrayValue(0) + "\t"); for (int i = 0; i < alltimes.length - (type == Type.STEPWISE ? 1 : 0); i++) { out.print(alltimes[i]); if (indicatorsParameter.getArrayValue(i) > 0) { out.print(":" + popSizeParameter.getArrayValue(i + 1)); } out.print("\t"); } if( type == Type.STEPWISE ) { out.print(alltimes[alltimes.length-1]); } } @Override public void close(PrintStream out) { } }