/* * TeamOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators; import dr.math.MathUtils; /** * Combines N oprators - each time a random subset of k out of the N are run. If operators have weight then the * probability of any subset is proportional to the sum of it's members weights. * * Currently very basic as no Coercible operators are handled correctly. Getting coercion to work is non trivial. * * @author Joseph Heled */ public class TeamOperator extends SimpleMCMCOperator /*implements CoercableMCMCOperator*/ { private final MCMCOperator[] operators; //private final ArrayList<Integer> operatorToOptimizeList; //private int currentOptimizedOperator; // private final double targetProbability; private final int nPick; private final boolean unequalWeights; private final MCMCOperator[] currentRound; private int nToReject; private int[][] binomial; public TeamOperator(MCMCOperator[] operators, int nPick, double weight) { setWeight(weight); this.operators = operators; //targetProbability = targetProb; final int N = operators.length; assert 0 < nPick && nPick <= N; this.nPick = nPick; currentRound = new MCMCOperator[N]; { boolean b = false; double w = operators[0].getWeight(); for( MCMCOperator o : operators ) { if( o.getWeight() != w ) { b = true; break; } } unequalWeights = b; } if( unequalWeights ) { final int M = N+1; binomial = new int[M][M]; for(int n = 0; n < M; ++n) { // binomial[n] = new int[operators.length]; binomial[n][0] = 1; } for(int n = 1; n < M; ++n) { for(int k = 1; k <= n; ++k) { binomial[n][k] = binomial[n-1][k] + binomial[n-1][k-1]; } } } else { for(int k = 0; k < N; ++k) { currentRound[k] = operators[k]; } } } // public void addOperator(SimpleMCMCOperator operation) { // // operatorList.add(operation); // if (operation instanceof CoercableMCMCOperator) { // // if (((CoercableMCMCOperator) operation).getMode() == CoercionMode.COERCION_ON) // // operatorToOptimizeList.add(operatorList.size() - 1); // // } // } private void choose() { final int n = operators.length; if( nPick < n ) { if( unequalWeights ) { chooseUsingWeights(); } else { // equal weights, just pick a subset of 'nPick' operators uniformly for(int k = 0; k < nPick; ++k) { final int which = k + MathUtils.nextInt(n - k); final MCMCOperator tmp = currentRound[k]; currentRound[k] = currentRound[which]; currentRound[which] = tmp; } } } } private void chooseUsingWeights() { // sum(o_w : o in operators already selected) double inSumWeights = 0.0; // sum(o_w : o in remaining operators) double sumWeightsRemaining = 0; for( MCMCOperator o : operators ) { sumWeightsRemaining += o.getWeight(); } // Number of operators still to pick int k = nPick; // Operator under consideration int j = 0; while( k > 0 ) { // remaining to choose from final int n = operators.length - j; if( k == n ) { // speedup for( ; k > 0; k--, j++ ) { currentRound[k-1] = operators[j]; } } else { final int cnk = binomial[n][k]; final int cnk1 = binomial[n-1][k-1]; final int cnk2 = k >= 2 ? binomial[n-2][k-2] : 0; final double tot = cnk1 * sumWeightsRemaining + cnk * inSumWeights; final double we0 = operators[j].getWeight(); final double has = cnk2 * (sumWeightsRemaining-we0) + cnk1 * (we0+inSumWeights); final double r = MathUtils.nextDouble(); if( r < has/tot ) { currentRound[k-1] = operators[j]; k -= 1; inSumWeights += we0; } j += 1; sumWeightsRemaining -= we0; } } } public final double doOperation() { choose(); double logP = 0; for(int k = 0; k < nPick; ++k) { MCMCOperator operation = currentRound[k]; logP += operation.operate(); } nToReject = nPick; return logP; } public void accept(double deviation) { super.accept(deviation); for(int k = 0; k < nPick; ++k) { currentRound[k].accept(deviation); } } public void reject() { super.reject(); for(int k = 0; k < nToReject; ++k) { currentRound[k].reject(); } } public void reset() { for( MCMCOperator op : operators ) { op.reset(); } } public String getOperatorName() { StringBuffer sb = new StringBuffer("Team " + nPick + " ("); for( MCMCOperator operation : operators ) { sb.append(operation.getOperatorName()+","); } return sb.substring(0, sb.length()-1) + ")"; } // public double getTargetAcceptanceProbability() { // return targetProbability; // } public String getPerformanceSuggestion() { return ""; } }