package dr.inference.operators;
import dr.math.MathUtils;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import java.util.ArrayList;
/**
 * An operator that applies every one of its sub-operators, in the order they were added,
 * as a single joint move. The values returned by the sub-operators' {@code doOperation()}
 * calls are summed and returned as this operator's result.
 * <p>
 * Coercion is delegated to the sub-operators: each call to {@link #getCoercableParameter()}
 * picks one coercible sub-operator at random, and the next call to
 * {@link #setCoercableParameter(double)} updates that same sub-operator.
 *
 * @author Marc A. Suchard
 */
public class JointOperator extends SimpleMCMCOperator implements CoercableMCMCOperator {

    /** Half-width of the [min, max] acceptance band around the target probability. */
    private static final double ACCEPTANCE_WIDTH = 0.2;
    /** Half-width of the narrower "good" acceptance band around the target probability. */
    private static final double GOOD_ACCEPTANCE_WIDTH = 0.1;
    /** Value substituted when a lower acceptance bound falls below zero. */
    private static final double LOWER_LEVEL_FLOOR = 0.01;
    /** Value substituted when an upper acceptance bound exceeds one. */
    private static final double UPPER_LEVEL_CEILING = 0.9;

    /** Sub-operators, applied in insertion order by {@link #doOperation()}. */
    private final ArrayList<SimpleMCMCOperator> operatorList;
    /** Indices into {@code operatorList} of sub-operators whose coercion mode was ON when added. */
    private final ArrayList<Integer> operatorToOptimizeList;
    /**
     * Index into {@code operatorList} of the sub-operator picked by the most recent
     * {@link #getCoercableParameter()} call; {@link #setCoercableParameter(double)} writes to it.
     */
    private int currentOptimizedOperator;
    /** Target acceptance probability supplied at construction. */
    private final double targetProbability;

    /**
     * Creates a joint operator with no sub-operators.
     *
     * @param weight     operator weight, passed to {@code setWeight}
     * @param targetProb target acceptance probability; also the centre of the
     *                   acceptance-level bands reported by this operator
     */
    public JointOperator(double weight, double targetProb) {
        operatorList = new ArrayList<SimpleMCMCOperator>();
        operatorToOptimizeList = new ArrayList<Integer>();
        targetProbability = targetProb;
        setWeight(weight);
    }

    /**
     * Appends a sub-operator. If it is a {@link CoercableMCMCOperator} whose mode is
     * {@code COERCION_ON} at the time of this call, it also becomes a candidate for coercion.
     *
     * @param operation the sub-operator to add
     */
    public void addOperator(SimpleMCMCOperator operation) {
        operatorList.add(operation);
        if (operation instanceof CoercableMCMCOperator
                && ((CoercableMCMCOperator) operation).getMode() == CoercionMode.COERCION_ON) {
            operatorToOptimizeList.add(operatorList.size() - 1);
        }
    }

    /**
     * Applies every sub-operator in order and returns the sum of their results.
     * <p>
     * A failing sub-operator does not stop the loop: the remaining sub-operators are still
     * applied. After the loop, the last {@link OperatorFailedException} caught (if any) is rethrown.
     *
     * @return the sum of the sub-operators' {@code doOperation()} return values
     * @throws OperatorFailedException if any sub-operator failed
     */
    public final double doOperation() throws OperatorFailedException {
        double logP = 0;
        OperatorFailedException failure = null;
        for (SimpleMCMCOperator operation : operatorList) {
            try {
                logP += operation.doOperation();
            } catch (OperatorFailedException ofe) {
                failure = ofe;
            }
            // TODO: after a failure the remaining operations should not have to run,
            // but their operate() would then need to be faked.
        }
        if (failure != null) {
            throw failure;
        }
        return logP;
    }

    /**
     * Picks one coercible sub-operator uniformly at random and returns its coercable
     * parameter. The choice is remembered for the next {@link #setCoercableParameter(double)}.
     *
     * @return the chosen sub-operator's coercable parameter
     * @throws IllegalArgumentException if no sub-operator had coercion turned on when added
     */
    public double getCoercableParameter() {
        if (operatorToOptimizeList.isEmpty()) {
            throw new IllegalArgumentException("JointOperator has no sub-operators with coercion turned on");
        }
        currentOptimizedOperator = operatorToOptimizeList.get(MathUtils.nextInt(operatorToOptimizeList.size()));
        return ((CoercableMCMCOperator) operatorList.get(currentOptimizedOperator)).getCoercableParameter();
    }

    /**
     * Sets the coercable parameter of the sub-operator picked by the most recent
     * {@link #getCoercableParameter()} call. Callers are expected to call the getter first.
     *
     * @param value the new coercable parameter value
     * @throws IllegalArgumentException if no sub-operator had coercion turned on when added
     */
    public void setCoercableParameter(double value) {
        if (operatorToOptimizeList.isEmpty()) {
            throw new IllegalArgumentException("JointOperator has no sub-operators with coercion turned on");
        }
        ((CoercableMCMCOperator) operatorList.get(currentOptimizedOperator)).setCoercableParameter(value);
    }

    /** @return the number of sub-operators added so far */
    public int getNumberOfSubOperators() {
        return operatorList.size();
    }

    /**
     * Returns the raw parameter of sub-operator {@code i}. (The method name's misspelling is
     * kept for compatibility with existing callers.)
     *
     * @param i sub-operator index
     * @return the sub-operator's raw parameter
     * @throws IllegalArgumentException if {@code i} is out of range, or if that sub-operator
     *                                  is not a {@link CoercableMCMCOperator}
     */
    public double getRawParamter(int i) {
        checkSubOperatorIndex(i);
        SimpleMCMCOperator operator = operatorList.get(i);
        if (!(operator instanceof CoercableMCMCOperator)) {
            throw new IllegalArgumentException("Sub-operator " + i + " (" + operator.getOperatorName()
                    + ") is not coercable and has no raw parameter");
        }
        return ((CoercableMCMCOperator) operator).getRawParameter();
    }

    /**
     * Not supported for a joint operator; use {@link #getRawParamter(int)} instead.
     *
     * @throws UnsupportedOperationException always
     */
    public double getRawParameter() {
        throw new UnsupportedOperationException("More than one raw parameter for a joint operator");
    }

    /**
     * @return {@code COERCION_ON} if at least one sub-operator had coercion on when it was
     *         added, otherwise {@code COERCION_OFF}
     */
    public CoercionMode getMode() {
        return operatorToOptimizeList.isEmpty() ? CoercionMode.COERCION_OFF : CoercionMode.COERCION_ON;
    }

    /**
     * @param i sub-operator index
     * @return the sub-operator at index {@code i}; an out-of-range index makes
     *         {@code ArrayList.get} throw {@link IndexOutOfBoundsException}
     */
    public MCMCOperator getSubOperator(int i) {
        return operatorList.get(i);
    }

    /**
     * @param i sub-operator index
     * @return the current coercion mode of sub-operator {@code i}, or {@code COERCION_OFF}
     *         if it is not a {@link CoercableMCMCOperator}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public CoercionMode getSubOperatorMode(int i) {
        checkSubOperatorIndex(i);
        SimpleMCMCOperator operator = operatorList.get(i);
        if (operator instanceof CoercableMCMCOperator) {
            return ((CoercableMCMCOperator) operator).getMode();
        }
        return CoercionMode.COERCION_OFF;
    }

    /**
     * @param i sub-operator index
     * @return the sub-operator's name prefixed with {@code "Joint."}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public String getSubOperatorName(int i) {
        checkSubOperatorIndex(i);
        return "Joint." + operatorList.get(i).getOperatorName();
    }

    /** @return the fixed name {@code "JointOperator"} */
    public String getOperatorName() {
        return "JointOperator";
    }

    /**
     * XML serialisation is not implemented for this operator.
     *
     * @throws UnsupportedOperationException always
     */
    public Element createOperatorElement(Document d) {
        throw new UnsupportedOperationException("not implemented");
    }

    /** @return the target acceptance probability supplied at construction */
    public double getTargetAcceptanceProbability() {
        return targetProbability;
    }

    /** @return target − 0.2, or 0.01 if that is negative */
    public double getMinimumAcceptanceLevel() {
        return lowerLevel(ACCEPTANCE_WIDTH);
    }

    /** @return target + 0.2, or 0.9 if that exceeds 1 */
    public double getMaximumAcceptanceLevel() {
        return upperLevel(ACCEPTANCE_WIDTH);
    }

    /** @return target − 0.1, or 0.01 if that is negative */
    public double getMinimumGoodAcceptanceLevel() {
        return lowerLevel(GOOD_ACCEPTANCE_WIDTH);
    }

    /**
     * @return target + 0.1, or 0.9 if that exceeds 1. This is symmetric with
     *         {@link #getMinimumGoodAcceptanceLevel()}; it previously used the outer width 0.2.
     */
    public double getMaximumGoodAcceptanceLevel() {
        return upperLevel(GOOD_ACCEPTANCE_WIDTH);
    }

    /**
     * Not implemented for joint operators.
     *
     * @return always the empty string
     */
    public final String getPerformanceSuggestion() {
        return "";
    }

    /** Lower band edge: target − width, replaced by the floor when negative. */
    private double lowerLevel(double width) {
        double min = targetProbability - width;
        return min < 0 ? LOWER_LEVEL_FLOOR : min;
    }

    /** Upper band edge: target + width, replaced by the ceiling when above one. */
    private double upperLevel(double width) {
        double max = targetProbability + width;
        return max > 1 ? UPPER_LEVEL_CEILING : max;
    }

    /** Throws IllegalArgumentException unless {@code i} indexes an existing sub-operator. */
    private void checkSubOperatorIndex(int i) {
        if (i < 0 || i >= operatorList.size()) {
            throw new IllegalArgumentException("Sub-operator index " + i
                    + " out of range [0, " + operatorList.size() + ")");
        }
    }
}