package dr.evomodel.antigenic.phyloClustering.misc.obsolete; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Set; import dr.evolution.tree.NodeRef; //import dr.evomodel.antigenic.driver.OrderDouble; import dr.evomodel.tree.TreeModel; import dr.inference.model.AbstractModelLikelihood; import dr.inference.model.CompoundParameter; import dr.inference.model.MatrixParameter; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.xml.AbstractXMLObjectParser; import dr.xml.ElementRule; import dr.xml.XMLObject; import dr.xml.XMLObjectParser; import dr.xml.XMLParseException; import dr.xml.XMLSyntaxRule; /** * @author Charles Cheung * @author Trevor Bedford */ //Some suggestion to speed up the code from Charles Cheung //(please scroll through to places marked by to see the places of changes //---------- suggestion from cykc---------------- public class ClusterViruses extends AbstractModelLikelihood { //---------- suggestion from cykc---------------- private double mostRecentTransformedValue = 0; //keep a copy of the most recent version of transformFactor, to keep track of whether the transformFactor has changed //private boolean ShouldUpdateDepMatrix = true; // this way of flagging for change is not used anymore private boolean treeChanged = false; //a flag that becomes true when treeModel changes //---------- End of suggestion from cykc---------------- public static final String CLUSTER_VIRUSES = "ClusterViruses"; //============================================================================================================================== //variables double lambda = 10; double sigmaSq = 9; // double sigmaSq = 100; //when offset is off //K - number of parameters Parameter K; // for now, there is no move to change K //E|K - excision points Parameter excisionPoints; //C Parameter clusterLabels; //mu - means MatrixParameter mu; double[] muLabels; MatrixParameter virusLocations = null; // Parameter virusOffsetsParameter; //need to read it from AntigenicLIkelihood (the year ) // use offsets instead //need to stop the virus from public ClusterViruses (TreeModel treeModel_in, Parameter K_in, Parameter excisionPoints_in, Parameter clusterLabels_in, MatrixParameter mu_in, Boolean hasDrift, // Parameter locationDrift_in, Parameter offsets_in, MatrixParameter virusLocations_in){ super(CLUSTER_VIRUSES); this.treeModel= treeModel_in; this.K = K_in; this.excisionPoints = excisionPoints_in; this.clusterLabels = clusterLabels_in; this.mu = mu_in; this.hasDrift=hasDrift; // this.locationDrift=locationDrift_in; this.offsets=offsets_in; //this.hasDrift=false; this.virusLocations = virusLocations_in; numdata = offsets.getSize(); System.out.println("numdata = " + numdata); //initialize clusterLabels clusterLabels.setDimension(numdata); for (int i = 0; i < numdata; i++) { clusterLabels.setParameterValue(i, 0); } addVariable(clusterLabels); //initialize mu mu.setColumnDimension(2); //mu.setRowDimension(numdata); //in reality, only K of them are used int K_int = (int) K.getParameterValue(0); mu.setRowDimension(K_int); //in reality, only K of them are used //System.out.println((int) K.getParameterValue(0)); //System.exit(0); for(int i=0; i < K_int; i++){ //can I set initial condition to be like this? and then let the algorithm set it properly later? double zero=0; mu.getParameter(i).setValue(0, zero); mu.getParameter(i).setValue(1, zero); } //adding the pre-clustering step. preClustering(); /* this.abc.setColumnDimension(2); //set dimension equal to 2 abc.setRowDimension(strains.size()); for (int i = 0; i < strains.size(); i++) { abc.getParameter(i).setId(strains.get(i)); } */ addVariable(virusLocations); addModel(treeModel); //addVariable(locationDrift); addVariable(offsets); addVariable(K); addVariable(excisionPoints); addVariable(mu); System.out.println("Finished loading the constructor for ClusterViruses"); } private void preClustering() { int numViruses = offsets.getSize(); System.out.println("# offsets = " + offsets.getSize()); //for(int i=0; i < offsets.getSize(); i++){ //System.out.println(offsets.getParameterValue(i)); //} List<OrderDouble> list = new ArrayList<OrderDouble>(); for(int i=0; i < numViruses; i++){ list.add(new OrderDouble(i, offsets.getParameterValue(i))); //offset of 1 } Collections.sort(list, new OrderDouble()); int initialEqualBinSize = numViruses/(int)(K.getParameterValue(0) -1); System.out.println("initial bin size = " + initialEqualBinSize); System.out.println("Initial cluster assignment:"); // System.out.println("virus index\tOffset\tCluster label"); for(int i=0; i < numViruses; i++){ // System.out.println(list.get(i).getIndex() + "\t" + list.get(i).getValue() +"\t"+ i/ initialEqualBinSize ); int label = i/initialEqualBinSize; clusterLabels.setParameterValue(list.get(i).getIndex() , label); } /* //borrow from getLogLikelihood: int K_int = (int) K.getParameterValue(0); double[] meanYear = new double[K_int]; double[] groupCount = new double[K_int]; for(int i=0; i < numdata; i++){ int label = (int) clusterLabels.getParameterValue(i); double year = 0; if (offsets != null) { // System.out.print("virus Offeset Parameter present"+ ": "); // System.out.print( virusOffsetsParameter.getParameterValue(i) + " "); // System.out.print(" drift= " + drift + " "); year = offsets.getParameterValue(i); //just want year[i] //make sure that it is equivalent to double offset = year[virusIndex] - firstYear; } else{ System.out.println("virus Offeset Parameter NOT present. We expect one though. Something is wrong."); } meanYear[ label] += year; groupCount[ label ] = groupCount[ label ] +1; } */ int maxLabel=0; for(int i=0;i< numdata; i++){ if(maxLabel < (int) clusterLabels.getParameterValue(i)){ maxLabel = (int) clusterLabels.getParameterValue(i); } } /* for(int i=0; i <= maxLabel; i++){ meanYear[i] = meanYear[i]/groupCount[i]; //System.out.println(meanYear[i]); } */ //double beta = locationDrift.getParameterValue(0); //now, change the mu.. for(int i=0; i <= maxLabel; i++){ //System.out.println(meanYear[i]*beta); //mu.getParameter(i).setParameterValue(0, meanYear[i]*beta);//now separate out mu from virusLocation mu.getParameter(i).setParameterValue(0, 0); mu.getParameter(i).setParameterValue(1, 0); } //now change the clusterOffsets //... is it necessary? //System.exit(0); } public double getLogLikelihood() { //System.out.println("getLogLikelihood of ClusterViruses"); double logL = 0; int maxLabel=0; for(int i=0;i< numdata; i++){ if(maxLabel < (int) clusterLabels.getParameterValue(i)){ maxLabel = (int) clusterLabels.getParameterValue(i); } } //P(K=k) int K_int = (int) K.getParameterValue(0); //logL += Math.log(K.getParameterValue(0)) - lambda*K.getParameterValue(0) - Math.log( (double) factorial(K_int)); logL += -lambda + K.getParameterValue(0)*Math.log(lambda) - Math.log( (double) factorial(K_int)); // p(C | K= k) logL -= numdata * Math.log(K.getParameterValue(0)); //p(mu_j | C, years) ~ N( theta , sigma^2) //logL -= Math.log(2*Math.PI); //logL -= 0.5*Math.log( sigmaSq*sigmaSq); logL -= K.getParameterValue(0) * ( Math.log(2) + Math.log(Math.PI)+ 0.5*Math.log(sigmaSq) + 0.5*Math.log(sigmaSq) ); //double[] meanYear = new double[K_int]; double[] groupCount = new double[K_int]; for(int i=0; i < numdata; i++){ int label = (int) clusterLabels.getParameterValue(i); groupCount[ label ] = groupCount[ label ] +1; } //for(int i=0; i <= maxLabel; i++){ //meanYear[i] = meanYear[i]/groupCount[i]; //System.out.println(meanYear[i]); //} for(int i=0; i <= maxLabel; i++){ double mu_i0 = mu.getParameter(i).getParameterValue(0); double mu_i1 = mu.getParameter(i).getParameterValue(1); //double beta = locationDrift.getParameterValue(0); if( groupCount[i] >0){ //System.out.println("meanYear = " + meanYear[i]); //logL -= 0.5*( (mu_i0 - beta*meanYear[i] )*(mu_i0 - beta*meanYear[i] ) + ( mu_i1 -0)*( mu_i1 - 0) )/sigmaSq; logL -= 0.5*( (mu_i0 )*(mu_i0 ) + ( mu_i1 )*( mu_i1 ) )/sigmaSq; } } //System.out.println(logL); // System.out.println("logL=" + logL); //System.out.println("done getLogLikelihood of ClusterViruses"); //System.out.println("logL=" + logL); //double logL = 0; // for testing purpose only return(logL); /* // if treeModel changes, compute the depMatrix from scratch // if only the transformFactor change, go back to the latest copy of the untransformed deptMatrix, //and transform it, so it doesn't have to go through the treeModel to get the distance. if(treeChanged==true){ // setDepMatrix(); //the super slow step if(treeChanged ==true){ treeChanged = false; } } double logL = 0.0; for (int j=0 ; j<logLikelihoodsVector.length;j++){ logLikelihoodsVector[j]=getLogLikGroup(j); logL +=logLikelihoodsVector[j]; } for (int j=0 ; j<links.getDimension();j++){ if(links.getParameterValue(j)==j){ logL += Math.log(alpha.getParameterValue(0)); } else{logL += Math.log(depMatrix[j][(int) links.getParameterValue(j)]); } double sumDist=0.0; for (int i=0;i<numdata;i++){ if(i!=j){sumDist += depMatrix[i][j]; } } logL-= Math.log(alpha.getParameterValue(0)+sumDist); } return logL; */ } //===================================================================================================================== public int factorial(int n) { int fact = 1; // this will be the result for (int i = 1; i <= n; i++) { fact *= i; } return fact; } public Model getModel() { return this; } public void makeDirty() { } public void acceptState() { // DO NOTHING } public void restoreState() { // DO NOTHING } public void storeState() { // DO NOTHING } protected void handleModelChangedEvent(Model model, Object object, int index) { //---------- suggestion from cykc---------------- // I am making an assumption that whenever treeModel changes, the changes get caught here. if(model == treeModel){ //System.out.println("==========Tree model changes!!!!!!!!====="); treeChanged = true; } else{ } //---------- End of suggestion from cykc---------------- } //---------- suggestion from cykc---------------- //I tried to catch the transformedFactor changes through this routine, but it seems that it doesn't catch all, //so I abandoned this routine and now use 'mostRecentTransformedValue' to directly test if transformedValue has changed. //This ShouldUpdateDepMatrix never gets used and is now an obsolete variable. //I am noticing that handleVariableChangedEvent doesn't always catch when transformFactor changes. //I am observing that sometimes transformFactor can change more than once within a single MCMC sample - I don't know why, //if getLogLikelihood() gets called after the transformFactor changes but ShouldUpdateDepMatrix flag doesn't catch it, //then the getLogLikelihood() will not be calculated correctly. //Hence, this way of catching when transformFactor changes now becomes obsolete // protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { /* if (variable == transformFactor) { //System.out.println("TransformFactor gets updated and is caught here!!"); ShouldUpdateDepMatrix = true; } else { //has to change at another sample instead of setting it to false right after //updating the matrix because the transformFactor value //can update twice within a sample, so setDepMatrix has to update twice if(ShouldUpdateDepMatrix == true){ //System.out.println("ShouldUpdateDepMatrix changes from true to false"); ShouldUpdateDepMatrix = false; } } */ } //---------- End of suggestion from cykc---------------- Set<NodeRef> allTips; CompoundParameter traitParameter; Parameter alpha; Parameter clusterPrec ; Parameter priorPrec ; Parameter priorMean ; Parameter assignments; Parameter links; Parameter means2; Parameter means1; Parameter locationDrift; Parameter offsets; boolean hasDrift; TreeModel treeModel; String traitName; double[][] data; double[][] depMatrix; double[][] logDepMatrix; double[][] cur_untransformedMatrix; //---------- suggestion from cykc---------------- double[] logLikelihoodsVector; int numdata; Parameter transformFactor; double k0; double v0; double[][] T0Inv; double[] m; double logDetT0; LinkedList<Integer>[] assignmentsLL; int seqLength; public int getSeqLength() { return seqLength; } char[][] seqData; public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public final static String TREEMODEL = "treeModel"; public final static String K = "k"; public final static String EXCISIONPOINTS = "excisionPoints"; public final static String CLUSTERLABELS = "clusterLabels"; public final static String MU = "mu"; //public final static String HASDRIFT = ?? // public final static String LOCATION_DRIFT = "locationDrift"; public final static String OFFSETS = "offsets"; public final static String VIRUS_LOCATIONS = "virusLocations"; boolean integrate = false; public String getParserName() { return CLUSTER_VIRUSES; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class); //String traitName = (String) xo.getAttribute(TRAIT_NAME); XMLObject cxo = xo.getChild(K); Parameter k = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(EXCISIONPOINTS); Parameter excisionPoints = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(CLUSTERLABELS); Parameter clusterLabels = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(MU); MatrixParameter mu = (MatrixParameter) cxo.getChild(MatrixParameter.class); //alternative way to load in the MatrixParameter? // MatrixParameter serumLocationsParameter = null; // if (xo.hasChildNamed(SERUM_LOCATIONS)) { // serumLocationsParameter = (MatrixParameter) xo.getElementFirstChild(SERUM_LOCATIONS); //} // cxo=xo.getChild(LOCATION_DRIFT) ; // Parameter locationDrift= (Parameter) cxo.getChild(Parameter.class); cxo=xo.getChild(OFFSETS); Parameter offsets =(Parameter) cxo.getChild(Parameter.class); cxo=xo.getChild(VIRUS_LOCATIONS); MatrixParameter virusLocations =(MatrixParameter) cxo.getChild(MatrixParameter.class); boolean hasDrift = false; if (offsets.getDimension()>1){ hasDrift=true; } // TreeTraitParserUtilities utilities = new TreeTraitParserUtilities(); // String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME; // TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = // utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, integrate); // // traitName = returnValue.traitName; // CompoundParameter traitParameter = returnValue.traitParameter; //return new ClusterViruses (treeModel,traitParameter , K, excisionPoints, clusterLabels, mu, hasDrift); // return new ClusterViruses(treeModel, k, excisionPoints, clusterLabels, mu, hasDrift, locationDrift, offsets, virusLocations); return new ClusterViruses(treeModel, k, excisionPoints, clusterLabels, mu, hasDrift, offsets, virusLocations); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "clustering viruses"; } public Class getReturnType() { return ClusterViruses.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { // new StringAttributeRule(TreeTraitParserUtilities.TRAIT_NAME, "The name of the trait for which a likelihood should be calculated"), // new ElementRule(TREEMODEL, Parameter.class), new ElementRule(K, Parameter.class), new ElementRule(EXCISIONPOINTS, Parameter.class), new ElementRule(CLUSTERLABELS, Parameter.class), new ElementRule(MU, MatrixParameter.class), // new ElementRule(LOCATION_DRIFT, Parameter.class), new ElementRule(OFFSETS, Parameter.class), new ElementRule(VIRUS_LOCATIONS, MatrixParameter.class) }; }; String Atribute = null; }