package beast.evolution.speciation; // Arguably the most complex iterator I have ever written. /** * @author Joseph Heled */ public class CalibrationLineagesIterator { // taxaPartialOrder[i] contains clades immediately contained in the i'th clade as indices (so, strictly smaller than i) final int[][] taxaPartialOrder; // Per calibration point, the number of taxa which is not below any other point. final int[] cladesFreeLins; private final boolean rootCalibrated; // Use iterators 0 to curIters-1 (i.e. iters[0:curIters]) private int curIters; // per clade Iterator private final LinsIterator[] iters; // last returned value from iterators: vals[i] for iters[i] private int[][] vals; // Number of taxa not below any calibration point private int freeLineages; // indices of maximal clades private final int[] maximalClades; CalibrationLineagesIterator(final int[][] clades, final int[][] taxaPartialOrder, final boolean[] maximal, final int leafCount) { cladesFreeLins = new int[clades.length]; for(int k = 0; k < cladesFreeLins.length; ++k) { cladesFreeLins[k] = clades[k].length; for( final int l : taxaPartialOrder[k] ) { cladesFreeLins[k] -= clades[l].length; } assert cladesFreeLins[k] >= 0; } this.taxaPartialOrder = taxaPartialOrder; iters = new LinsIterator[clades.length+1]; vals = new int[iters.length][]; // number of maximal clades int max = 0; for(final boolean b : maximal) { max += b ? 1 : 0; } // indices of maximal clades in a list maximalClades = new int[max]; freeLineages = leafCount; max = 0; for(int m = 0; m < maximal.length; ++m) { if( maximal[m] ) { maximalClades[max] = m; ++max; freeLineages -= clades[m].length; } } rootCalibrated = ( max == 1 && clades[maximalClades[0]].length == leafCount ); assert ! (rootCalibrated && freeLineages > 0); assert freeLineages >= 0; } boolean isRootCalibrated() { return rootCalibrated; } // Prepare to iterate: ranks[i] gives the rank of the i'th clade. ranks is a permutation of (1,2,...,#points) int setup(final int[] ranks) { final int n = cladesFreeLins.length; // reset iterators used. each call to setOneIterator will increment it by one. curIters = 0; for(int k = 0; k < n; ++k) { setOneIterator(ranks, taxaPartialOrder[k], cladesFreeLins[k], ranks[k]); } if( ! rootCalibrated ) { setOneIterator(ranks, maximalClades, freeLineages, n+1); } for(int k = 0; k < curIters-1; ++k) { vals[k] = iters[k].next(); } return curIters; } private void setOneIterator(final int[] ranks, final int[] joinerClades, final int nl, final int rank) { final int subs = joinerClades.length; LinsIterator itr/* = null*/; if( subs == 0 ) { itr = new LinsIterator(nl, rank, null); } else /*if( nl > 0 || subs > 2 ) */ { final int[] s = new int[subs]; for(int i = 0; i < subs; ++i) { s[i] = ranks[joinerClades[i]]; } itr = new LinsIterator(nl, rank, s); } //assert itr != null; //if( itr != null ) { // sorted according to rank iters[itr.rank-1] = itr; itr.startIter(); ++curIters; //} } int[][] next() { final int[] l = iters[curIters-1].next(); if( l != null ) { vals[curIters-1] = l; return vals; } int i = curIters-2; for( ; i >= 0; --i) { if( (vals[i] = iters[i].next()) != null) { break; } } if( i < 0 ) { return null; } ++i; for( ; i < curIters; ++i) { iters[i].startIter(); vals[i] = iters[i].next(); } return vals; } public int[][] allJoiners() { final int[][] joiners = new int[curIters][]; for(int i = 0; i < curIters; ++i) { joiners[i] = iters[i].ljoins(); } return joiners; } public int start(final int i) { return iters[i].start; } class LinsIterator { private final int rank; private final int start; private final int[] joiners; private final int[] aStart; // Current count of lineages at all relevant time points, from 0 (start) to clade top. private final int[] lins; private int lastJoinger; private boolean stopIter; LinsIterator(final int ns, final int r, final int[] jnr) { rank = r; start = ns; joiners = new int [r]; lastJoinger = -1; // 2 for start+end, rank-1 intermediate levels aStart = new int [2 + rank-1]; lins = new int [2 + rank-1]; for(int k = 0; k < rank; ++k) { joiners[k] = 0; } if( jnr != null ) { for (final int j : jnr) { joiners[j] = 1; if (lastJoinger < j) { lastJoinger = j; } } } aStart[0] = ns; if( lastJoinger <= 0 ) { for(int i = 1; i < rank+1; ++i) { aStart[i] = 2; } if( rank > 1 ) { // first iteration increments this aStart[rank-1] -= 1; } } else { //assert(rank > 1); if( start > 0 ) { int i = 1; for(; i < lastJoinger+1; ++i) { aStart[i] = 1; } for(; i < rank+1; ++i) { aStart[i] = 2; } } else { assert jnr != null; int mj = jnr[0]; for (int aJnr : jnr) { mj = Math.min(mj, aJnr); } int i = 1; for(; i < mj+1; ++i) { aStart[i] = 0; } for(; i < lastJoinger+1; ++i) { aStart[i] = 1; } for(; i < rank+1; ++i) { aStart[i] = 2; } } // first iteration increments this aStart[rank-1] -= 1; } } void startIter() { for(int i = 0; i < rank+1; ++i) { lins[i] = aStart[i]; } stopIter = false; } final int[] next() { int i = rank - 1; if( lastJoinger <= 0 ) { while( i >= 1 && lins[i] == lins[i-1]) { --i; } if( i == 0 ) { if( rank == 1 ) { if( !stopIter ) { stopIter = true; return lins; } } return null; } lins[i] += 1; ++i; while( i < rank ) { lins[i] = 2; ++i; } } else { while( i >= 1 && lins[i] == lins[i-1] + joiners[i-1] ) { --i; } if( i == 0 ) { return null; } lins[i] += 1; i++; while( i < rank ) { lins[i] = (i <= lastJoinger) ? 1 : 2; i++; } } return lins; } final int[] ljoins() { return joiners; } } }