package dr.evomodel.antigenic.phyloClustering.misc.obsolete;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.xml.*;
/**
 * An MCMC operator for the antigenic cluster model. Each call to {@link #doOperation()} picks,
 * with probability 1/2 each, one of two proposals:
 * <ul>
 *   <li>a uniform step in [-1, 1) on one randomly chosen coordinate of one randomly chosen
 *       cluster location in {@code mu}; or</li>
 *   <li>moving one randomly chosen virus to a cluster label drawn uniformly from [0, K). When this
 *       empties the source cluster ("death") or occupies a previously empty cluster ("birth"),
 *       {@code K} is adjusted and the affected cluster location is redrawn from N(0, sigmaSq I).</li>
 * </ul>
 * After either proposal, every virus location is set to the location of its cluster. When a
 * cluster-offsets parameter is supplied, each virus's cluster offset is set to the mean virus
 * offset of its cluster.
 *
 * <p>NOTE(review): the returned log Hastings ratio is always 0, including for birth/death moves.
 * The original author marked the birth move as needing a correction that was never implemented.
 * Random numbers come from {@link Math#random()}, so runs are presumably not reproducible from the
 * MCMC seed.
 *
 * @author Charles Cheung
 * @author Trevor Bedford
 */
public class ClusterAlgorithmOperator extends SimpleMCMCOperator {

    /** Per-virus offsets (read as the "year" of each virus); its size defines the number of viruses. */
    Parameter virusOffsetsParameter;
    /** Variance of the isotropic 2D normal used to redraw the location of a born/dead cluster. */
    private double sigmaSq = 1;
    /** Number of viruses, taken from the size of {@link #virusOffsetsParameter}. */
    private int numdata = 0;
    /** Cluster locations; {@code mu.getParameter(label)} holds the 2D location of that cluster. */
    private MatrixParameter mu = null;
    /** Cluster label of each virus. */
    private Parameter clusterLabels = null;
    /** Number of clusters, stored in dimension 0. */
    private Parameter K = null;
    /** Per-virus 2D locations, overwritten with their cluster's {@code mu} after every proposal. */
    private MatrixParameter virusLocations = null;
    /** Largest cluster label seen at construction time. */
    private int maxLabel = 0;
    /** Labels of the non-empty clusters at construction time, in increasing order. */
    private int[] muLabels = null;
    /** Number of viruses per cluster label, as of the most recent recount. */
    private int[] groupSize;
    /** Optional per-virus cluster offsets (may be {@code null}). */
    private Parameter clusterOffsetsParameter;

    // Details of the most recent proposal (1 = mu move, 0 = relabel move), kept for debugging.
    private int isMoveMu = -1;
    private int groupSelectedChange = -1;
    private int virusIndexChange = -1;
    private double originalValueChange = -1;
    private int dimSelectChange = -1;

    /** Mean virus offset per cluster label, as of the most recent proposal. */
    private double[] mu0_offset;

    /**
     * Creates the operator and tabulates the initial cluster sizes.
     *
     * @param virusLocations          per-virus 2D locations
     * @param mu                      per-cluster 2D locations
     * @param clusterLabels           cluster label of each virus
     * @param K                       number of clusters (dimension 0)
     * @param weight                  operator weight
     * @param virusOffsetsParameter   per-virus offsets; must be non-null (its size is the number of viruses)
     * @param clusterOffsetsParameter per-virus cluster offsets to keep in sync, or {@code null}
     */
    public ClusterAlgorithmOperator(MatrixParameter virusLocations, MatrixParameter mu, Parameter clusterLabels,
                                    Parameter K, double weight, Parameter virusOffsetsParameter,
                                    Parameter clusterOffsetsParameter) {
        System.out.println("Loading the constructor for ClusterAlgorithmOperator");
        this.mu = mu;
        this.K = K;
        this.clusterLabels = clusterLabels;
        this.virusLocations = virusLocations;
        this.virusOffsetsParameter = virusOffsetsParameter;
        this.clusterOffsetsParameter = clusterOffsetsParameter;

        numdata = virusOffsetsParameter.getSize();
        System.out.println("numdata=" + numdata);

        int K_int = (int) K.getParameterValue(0);
        System.out.println("K_int=" + K_int);

        maxLabel = findMaxLabel();
        // Size by the largest label as well as K, so a label >= K cannot overflow the array.
        groupSize = countGroupSizes(K_int);

        // Record the labels of the non-empty clusters. The write index must persist across
        // iterations, and the largest label itself must be included.
        muLabels = new int[K_int];
        int next = 0;
        for (int label = 0; label <= maxLabel && next < muLabels.length; label++) {
            if (groupSize[label] > 0) {
                muLabels[next] = label;
                next++;
            }
        }

        setWeight(weight);
        System.out.println("Finished loading the constructor for ClusterAlgorithmOperator");
    }

    /**
     * Returns the largest cluster label currently assigned to any virus (0 if there are none).
     */
    private int findMaxLabel() {
        int max = 0;
        for (int i = 0; i < numdata; i++) {
            max = Math.max(max, (int) clusterLabels.getParameterValue(i));
        }
        return max;
    }

    /**
     * Counts the viruses carrying each label.
     *
     * @param minLength minimum length of the returned array; the array is also long enough to be
     *                  indexed by every label currently in use
     * @return counts indexed by cluster label
     */
    private int[] countGroupSizes(int minLength) {
        int[] counts = new int[Math.max(minLength, findMaxLabel() + 1)];
        for (int i = 0; i < numdata; i++) {
            counts[(int) clusterLabels.getParameterValue(i)]++;
        }
        return counts;
    }

    /**
     * Proposes either a random-walk step on one cluster location or a relabelling of one virus,
     * then resynchronises virus locations and cluster offsets.
     *
     * @return the log Hastings ratio (always 0, see the class-level note)
     */
    public final double doOperation() {
        final double[] zeroVector2D = {0, 0};
        final double[][] sigmaSqMatrix2D = {{sigmaSq, 0}, {0, sigmaSq}};

        double logHastingRatio = 0;
        final double chooseOperator = Math.random();
        final int K_int = (int) K.getParameterValue(0);

        // Cluster sizes before the proposal, used to detect the death/birth of a cluster.
        // toBin < K_int, so including K_int in the length keeps toBin in range.
        final int[] original_groupSize = countGroupSizes(Math.max(K_int, groupSize.length));

        if (chooseOperator < 0.5) {
            // Move 1: uniform step in [-1, 1) on one coordinate of one cluster location.
            isMoveMu = 1;
            int groupSelect = (int) Math.floor(Math.random() * K_int);
            groupSelectedChange = groupSelect;
            int dimSelect = (int) Math.floor(Math.random() * 2);
            dimSelectChange = dimSelect;
            double change = Math.random() * 2 - 1;
            double originalValue = mu.getParameter(groupSelect).getParameterValue(dimSelect);
            originalValueChange = originalValue;
            mu.getParameter(groupSelect).setParameterValue(dimSelect, originalValue + change);
            // The step distribution is symmetric, so the proposal ratio is 1.
            logHastingRatio = 0;
        } else {
            // Move 2: move one virus to a uniformly chosen cluster label.
            isMoveMu = 0;
            int virusIndex = (int) Math.floor(Math.random() * numdata);
            virusIndexChange = virusIndex;
            int toBin = (int) Math.floor(Math.random() * K_int);
            int fromBin = (int) clusterLabels.getParameterValue(virusIndex);
            clusterLabels.setParameterValue(virusIndex, toBin); // the proposal

            groupSize = countGroupSizes(original_groupSize.length);

            // Accumulate the change in K so that a simultaneous death and birth leaves K unchanged.
            int deltaK = 0;

            if (original_groupSize[fromBin] > 0 && groupSize[fromBin] == 0) {
                // Death: the source cluster is now empty; its label is no longer used.
                deltaK--;
                System.out.println("propose the fromBin " + fromBin + " becomes 0 in size - death of a bin");
                double[] ranNormal = MultivariateNormalDistribution.nextMultivariateNormalVariance(zeroVector2D, sigmaSqMatrix2D);
                mu.getParameter(fromBin).setParameterValue(0, ranNormal[0]);
                mu.getParameter(fromBin).setParameterValue(1, ranNormal[1]);
            }

            if (original_groupSize[toBin] == 0 && groupSize[toBin] == 1) {
                // Birth: a previously empty cluster now holds this virus. Its time offset is not
                // added to mu here; the original author notes ClusterViruses handles it by default.
                deltaK++;
                System.out.println("propose the birth of bin " + toBin);
                if (virusOffsetsParameter == null) {
                    System.out.println("virus Offeset Parameter NOT present. We expect one though. Something is wrong.");
                }
                double[] ranNormal = MultivariateNormalDistribution.nextMultivariateNormalVariance(zeroVector2D, sigmaSqMatrix2D);
                mu.getParameter(toBin).setParameterValue(0, ranNormal[0]);
                mu.getParameter(toBin).setParameterValue(1, ranNormal[1]);
                // TODO: a birth move should contribute to the Hastings ratio; not implemented.
            }

            if (deltaK != 0) {
                K.setParameterValue(0, K_int + deltaK);
            }
        }

        updateVirusLocationsAndOffsets();
        return logHastingRatio;
    }

    /**
     * Sets every virus location to its cluster's {@code mu}. When {@code clusterOffsetsParameter}
     * is present, also sets each virus's cluster offset to the mean virus offset of its cluster,
     * so that a virus changing cluster also updates its offset.
     */
    private void updateVirusLocationsAndOffsets() {
        // Size the accumulators by the labels actually in use; K alone may be smaller after a death.
        final int currentMaxLabel = findMaxLabel();
        final double[] meanYear = new double[currentMaxLabel + 1];
        final int[] groupCount = new int[currentMaxLabel + 1];

        for (int i = 0; i < numdata; i++) {
            int label = (int) clusterLabels.getParameterValue(i);
            double year = 0;
            if (virusOffsetsParameter != null) {
                year = virusOffsetsParameter.getParameterValue(i);
            } else {
                System.out.println("virus Offeset Parameter NOT present. We expect one though. Something is wrong.");
            }
            meanYear[label] += year;
            groupCount[label]++;
        }

        mu0_offset = new double[currentMaxLabel + 1];
        for (int label = 0; label <= currentMaxLabel; label++) {
            // Empty labels get 0 rather than 0/0 = NaN; they are never read below.
            mu0_offset[label] = groupCount[label] > 0 ? meanYear[label] / groupCount[label] : 0;
        }

        for (int i = 0; i < numdata; i++) {
            int label = (int) clusterLabels.getParameterValue(i);
            Parameter vLoc = virusLocations.getParameter(i);
            Parameter clusterLocation = mu.getParameter(label);
            vLoc.setParameterValue(0, clusterLocation.getParameterValue(0));
            vLoc.setParameterValue(1, clusterLocation.getParameterValue(1));
            if (clusterOffsetsParameter != null) {
                clusterOffsetsParameter.setParameterValue(i, mu0_offset[label]);
            }
        }
    }

    /** Accepts the last proposal; no extra bookkeeping beyond the superclass. */
    public void accept(double deviation) {
        super.accept(deviation);
    }

    /**
     * Rejects the last proposal.
     * NOTE(review): mu, clusterLabels, K and the virus locations are not rolled back here;
     * presumably the MCMC framework restores parameter state on rejection. Confirm.
     */
    public void reject() {
        super.reject();
    }

    public final static String CLUSTERALGORITHM_OPERATOR = "ClusterAlgorithmOperator";

    //MCMCOperator INTERFACE
    public final String getOperatorName() {
        return CLUSTERALGORITHM_OPERATOR;
    }

    public final void optimize(double targetProb) {
        throw new RuntimeException("This operator cannot be optimized!");
    }

    public boolean isOptimizing() {
        return false;
    }

    public void setOptimizing(boolean opt) {
        throw new RuntimeException("This operator cannot be optimized!");
    }

    public double getMinimumAcceptanceLevel() {
        return 0.1;
    }

    public double getMaximumAcceptanceLevel() {
        return 0.4;
    }

    public double getMinimumGoodAcceptanceLevel() {
        return 0.20;
    }

    public double getMaximumGoodAcceptanceLevel() {
        return 0.30;
    }

    /** This operator has no tunable parameters, so no suggestion is ever made. */
    public String getPerformanceSuggestion() {
        return "";
    }

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public final static String VIRUSLOCATIONS = "virusLocations";
        public final static String MU = "mu";
        public final static String CLUSTERLABELS = "clusterLabels";
        public final static String K = "k";
        public final static String OFFSETS = "offsets";
        public final static String CLUSTER_OFFSETS = "clusterOffsetsParameter";

        public String getParserName() {
            return CLUSTERALGORITHM_OPERATOR;
        }

        /* (non-Javadoc)
         * @see dr.xml.AbstractXMLObjectParser#parseXMLObject(dr.xml.XMLObject)
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);

            XMLObject cxo = xo.getChild(VIRUSLOCATIONS);
            MatrixParameter virusLocations = (MatrixParameter) cxo.getChild(MatrixParameter.class);

            cxo = xo.getChild(MU);
            MatrixParameter mu = (MatrixParameter) cxo.getChild(MatrixParameter.class);

            cxo = xo.getChild(CLUSTERLABELS);
            Parameter clusterLabels = (Parameter) cxo.getChild(Parameter.class);

            cxo = xo.getChild(K);
            Parameter k = (Parameter) cxo.getChild(Parameter.class);

            cxo = xo.getChild(OFFSETS);
            Parameter offsets = (Parameter) cxo.getChild(Parameter.class);

            // Optional element.
            Parameter clusterOffsetsParameter = null;
            if (xo.hasChildNamed(CLUSTER_OFFSETS)) {
                clusterOffsetsParameter = (Parameter) xo.getElementFirstChild(CLUSTER_OFFSETS);
            }

            return new ClusterAlgorithmOperator(virusLocations, mu, clusterLabels, k, weight, offsets, clusterOffsetsParameter);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "An operator that picks a new allocation of an item to a cluster under the Dirichlet process.";
        }

        public Class getReturnType() {
            return ClusterAlgorithmOperator.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule(MCMCOperator.WEIGHT),
                new ElementRule(VIRUSLOCATIONS, Parameter.class),
                new ElementRule(MU, Parameter.class),
                new ElementRule(CLUSTERLABELS, Parameter.class),
                new ElementRule(K, Parameter.class),
                new ElementRule(OFFSETS, Parameter.class),
                // Optional: parseXMLObject only reads it when present.
                new ElementRule(CLUSTER_OFFSETS, Parameter.class, "Parameter of cluster offsets of all virus", true),
        };
    };

    public int getStepCount() {
        return 1;
    }
}