package dr.evomodel.antigenic;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import com.sun.corba.se.impl.orbutil.graph.Node;
import dr.evolution.tree.Tree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.MCMCOperator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.StringAttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
/**
 * Likelihood over a set of tip traits in which every item is linked either to itself or to one other
 * item, with link weights derived from tree distances between tips. The XML parser registered below
 * describes it as a "conditional likelihood ddCRP".
 *
 * <p>At construction time the class
 * <ul>
 *   <li>copies the trait values of every tip into {@link #getData()},</li>
 *   <li>builds a symmetric tip-by-tip matrix of tree path lengths and transforms each off-diagonal
 *       entry {@code d} into {@code 1 / d^transformFactor} ({@link #getDepMatrix()}),</li>
 *   <li>stores the element-wise logarithm of that matrix ({@link #getLogDepMatrix()}),</li>
 *   <li>resets {@code assignments} so that every item links to itself, and</li>
 *   <li>pre-computes a per-item Gaussian log-density ({@link #getLogLikelihoodsVector()}).</li>
 * </ul>
 *
 * <p>NOTE(review): the matrices are indexed by {@code NodeRef.getNumber()} while the data rows are
 * indexed by position in {@code traitParameter}; the code assumes both orderings agree - confirm.
 * The arrays returned by the getters are the internal arrays, not copies.
 */
public class NPAntigenicLikelihood extends AbstractModelLikelihood {

    public static final String NP_ANTIGENIC_LIKELIHOOD = "NPAntigenicLikelihood";

    /**
     * Builds the likelihood and all of its cached matrices.
     *
     * @param treeModel       tree whose branch lengths define the distances between tips
     * @param traitParameter  one sub-parameter per data item holding that item's trait values
     * @param assignments     link of every item; overwritten here so that item {@code i} links to {@code i}
     * @param clusterVar      variance parameter; only element 0 is read here
     * @param priorMean       mean of the Gaussian; elements 0 and 1 are read here
     * @param priorVar        variance parameter; only element 0 is read here
     * @param transformFactor exponent {@code p} applied to tree distances as {@code 1 / d^p}
     */
    public NPAntigenicLikelihood(TreeModel treeModel, CompoundParameter traitParameter, Parameter assignments,
                                 Parameter clusterVar, Parameter priorMean, Parameter priorVar,
                                 double transformFactor) {
        super(NP_ANTIGENIC_LIKELIHOOD);

        this.assignments = assignments;
        this.clusterVar = clusterVar;
        this.priorVar = priorVar;
        this.priorMean = priorMean;
        this.treeModel = treeModel;
        this.traitParameter = traitParameter;
        this.transformFactor = transformFactor;
        this.alpha = 1.0;

        // Snapshot the trait values: one row per item, one column per trait dimension.
        int dim = traitParameter.getParameter(0).getSize();
        numdata = traitParameter.getParameterCount();
        double[][] values = new double[numdata][dim];
        for (int i = 0; i < numdata; i++) {
            for (int j = 0; j < dim; j++) {
                values[i][j] = traitParameter.getParameter(i).getParameterValue(j);
            }
        }
        this.data = values;

        // Accumulate tip-to-tip path lengths, then turn them into weights 1 / d^p.
        depMatrix = new double[numdata][numdata];
        List<NodeRef> childList = new ArrayList<NodeRef>();
        this.allTips = Tree.Utils.getExternalNodes(treeModel, treeModel.getRoot());
        recursion(treeModel.getRoot(), childList);
        logCorrectMatrix(transformFactor);
        printInformtion(depMatrix);

        // Element-wise log of the off-diagonal weights; the diagonal is left at 0.
        logDepMatrix = new double[numdata][numdata];
        for (int i = 0; i < numdata; i++) {
            for (int j = 0; j < i; j++) {
                logDepMatrix[i][j] = Math.log(depMatrix[i][j]);
                // Mirror into the upper triangle. The previous code assigned the cell to itself,
                // which left the whole upper triangle at zero.
                logDepMatrix[j][i] = logDepMatrix[i][j];
            }
        }

        //double[][] depMatrix1 = getMatrixFromTree(transformFactor);
        //printInformtion(depMatrix1);

        // Initial state: every item is linked to itself.
        for (int i = 0; i < numdata; i++) {
            assignments.setParameterValue(i, i);
        }

        // One slot per item plus one spare trailing slot that is not filled here.
        this.logLikelihoodsVector = new double[assignments.getDimension() + 1];

        // Gaussian with mean priorMean and diagonal variance (clusterVar + priorVar).
        // NOTE(review): the density is hard-coded to two dimensions even though 'dim' above is
        // read from the data; only the first two trait coordinates are used.
        double variance = clusterVar.getParameterValue(0) + priorVar.getParameterValue(0);
        double[][] var = new double[2][2];
        var[0][0] = variance;
        var[1][1] = variance;
        var[1][0] = 0.0;
        var[0][1] = 0.0;
        double[][] precision = new SymmetricMatrix(var).inverse().toComponents();

        double[] mean = new double[2];
        mean[0] = priorMean.getParameterValue(0);
        mean[1] = priorMean.getParameterValue(1);

        MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(mean, precision);
        for (int i = 0; i < logLikelihoodsVector.length - 1; i++) {
            double[] d = new double[2];
            d[0] = data[i][0];
            d[1] = data[i][1];
            logLikelihoodsVector[i] = mvn.logPdf(d);
        }
    }

    /** @return this object, which acts as its own model */
    public Model getModel() {
        return this;
    }

    /** @return the internal per-slot log-likelihood array (not a copy) */
    public double[] getLogLikelihoodsVector() {
        return logLikelihoodsVector;
    }

    /** @return the internal copy of the trait values, one row per item */
    public double[][] getData() {
        return data;
    }

    /** @return the internal matrix of transformed tree distances {@code 1 / d^p} (diagonal is 0) */
    public double[][] getDepMatrix() {
        return depMatrix;
    }

    /** @return the internal matrix holding the logarithm of the off-diagonal entries of the weight matrix */
    public double[][] getLogDepMatrix() {
        return logDepMatrix;
    }

    /** @return the prior mean parameter passed to the constructor */
    public Parameter getPriorMean() {
        return priorMean;
    }

    /** @return the prior variance parameter passed to the constructor */
    public Parameter getPriorVar() {
        return priorVar;
    }

    /** @return the cluster variance parameter passed to the constructor */
    public Parameter getClusterVar() {
        return clusterVar;
    }

    /**
     * Overwrites one slot of the per-slot log-likelihood array.
     *
     * @param pos   index of the slot
     * @param value new log-likelihood value for that slot
     */
    public void setLogLikelihoodsVector(int pos, double value) {
        logLikelihoodsVector[pos] = value;
    }

    /**
     * Computes the total log-likelihood from the current assignments.
     *
     * <p>The result is the sum of all non-zero entries in the first {@code assignments.getDimension()}
     * slots of the log-likelihood vector, plus, for every item {@code j},
     * {@code log(alpha)} if {@code j} links to itself or {@code log(depMatrix[j][link])} otherwise,
     * minus {@code log(alpha + sum over i != j of depMatrix[i][j])}.
     *
     * @return the log-likelihood
     */
    public double getLogLikelihood() {
        double logL = 0.0;

        // Data term. Slots equal to zero are skipped, exactly as before.
        // NOTE(review): presumably a zero marks an unused slot - confirm with the code that
        // calls setLogLikelihoodsVector.
        for (int j = 0; j < assignments.getDimension(); j++) {
            if (logLikelihoodsVector[j] != 0) {
                logL += logLikelihoodsVector[j];
            }
        }

        // Link term: weight of the chosen link, normalised over all possible links of item j.
        for (int j = 0; j < assignments.getDimension(); j++) {
            if (assignments.getParameterValue(j) == j) {
                logL += Math.log(alpha);
            } else {
                logL += Math.log(depMatrix[j][(int) assignments.getParameterValue(j)]);
            }

            double sumDist = 0.0;
            for (int i = 0; i < numdata; i++) {
                if (i != j) {
                    sumDist += depMatrix[i][j];
                }
            }
            logL -= Math.log(alpha + sumDist);
        }
        return logL;
    }

    /* Marc's suggestion on recursion for getting matrix from tree*/

    /**
     * Post-order traversal that accumulates tip-to-tip path lengths into {@code depMatrix}.
     *
     * <p>For an internal node, the branch above each of its first two children is added to every
     * matrix cell pairing a tip below that child with a tip that is not below it. After the whole
     * tree has been visited each off-diagonal cell therefore holds the sum of the branch lengths
     * separating the two tips. Only children 0 and 1 are visited.
     *
     * @param node      root of the subtree to process
     * @param childList receives every tip found below (or equal to) {@code node}
     */
    void recursion(NodeRef node, List<NodeRef> childList) {
        if (treeModel.isExternal(node)) {
            childList.add(node);
            return;
        }

        NodeRef leftChild = treeModel.getChild(node, 0);
        NodeRef rightChild = treeModel.getChild(node, 1);

        List<NodeRef> leftChildTipList = new ArrayList<NodeRef>();
        List<NodeRef> rightChildTipList = new ArrayList<NodeRef>();
        recursion(leftChild, leftChildTipList);
        recursion(rightChild, rightChildTipList);

        double lBranch = treeModel.getBranchLength(leftChild);
        double rBranch = treeModel.getBranchLength(rightChild);

        // Left branch first, then right branch: keeps the original order of the additions.
        addBranchLength(leftChildTipList, lBranch);
        addBranchLength(rightChildTipList, rBranch);

        childList.addAll(leftChildTipList);
        childList.addAll(rightChildTipList);
    }

    /** Adds {@code branchLength} to both symmetric cells of every (tip below, tip not below) pair. */
    private void addBranchLength(List<NodeRef> tipsBelow, double branchLength) {
        Set<NodeRef> tipsOutside = new HashSet<NodeRef>(allTips);
        for (NodeRef tip : tipsBelow) {
            tipsOutside.remove(tip);
        }
        for (NodeRef inside : tipsBelow) {
            for (NodeRef outside : tipsOutside) {
                depMatrix[outside.getNumber()][inside.getNumber()] += branchLength;
                depMatrix[inside.getNumber()][outside.getNumber()] += branchLength;
            }
        }
    }

    /**
     * Replaces every off-diagonal entry {@code d} of {@code depMatrix} by {@code 1 / d^p},
     * reading the lower triangle and mirroring the result into the upper triangle.
     *
     * @param p exponent applied to the distances
     */
    void logCorrectMatrix(double p) {
        for (int i = 0; i < numdata; i++) {
            for (int j = 0; j < i; j++) {
                depMatrix[i][j] = 1 / Math.pow(depMatrix[i][j], p);
                depMatrix[j][i] = depMatrix[i][j];
            }
        }
    }

    // Slow method for computing matrix from tree

    /**
     * Builds a symmetric matrix whose off-diagonal entries are {@code -p * log(treeDist(i, j))},
     * computing each distance separately through the common ancestor of the two tips.
     *
     * @param p factor applied to the log-distances
     * @return a new {@code numdata x numdata} matrix; the diagonal is 0
     */
    public double[][] getMatrixFromTree(double p) {
        double[][] mat = new double[numdata][numdata];
        for (int i = 0; i < numdata; i++) {
            for (int j = 0; j < i; j++) {
                mat[i][j] = -p * Math.log(getTreeDist(i, j));
                mat[j][i] = mat[i][j];
            }
        }
        return mat;
    }

    /**
     * Sums the branch lengths on the path between two tips.
     *
     * @param i index of the first external node
     * @param j index of the second external node
     * @return total branch length from tip {@code i} up to the common ancestor and down to tip {@code j}
     */
    public double getTreeDist(int i, int j) {
        NodeRef mrca = findMRCA(i, j);
        return distanceToAncestor(treeModel.getExternalNode(i), mrca)
                + distanceToAncestor(treeModel.getExternalNode(j), mrca);
    }

    /** Walks from {@code start} up to (but excluding) {@code ancestor}, summing branch lengths. */
    private double distanceToAncestor(NodeRef start, NodeRef ancestor) {
        double dist = 0;
        NodeRef current = start;
        while (current != ancestor) {
            dist += treeModel.getBranchLength(current);
            current = treeModel.getParent(current);
        }
        return dist;
    }

    /** Looks up the common ancestor node of two tips through their taxon ids. */
    private NodeRef findMRCA(int iTip, int jTip) {
        Set<String> leafNames = new HashSet<String>();
        leafNames.add(treeModel.getTaxonId(iTip));
        leafNames.add(treeModel.getTaxonId(jTip));
        return Tree.Utils.getCommonAncestorNode(treeModel, leafNames);
    }

    /**
     * Logs the leading {@code numdata x numdata} part of a matrix at INFO level on the
     * "dr.evomodel" logger, one row per line with tab-separated entries.
     *
     * @param Mat matrix to print
     */
    public void printInformtion(double[][] Mat) {
        StringBuilder sb = new StringBuilder("matrix \n");
        for (int i = 0; i < numdata; i++) {
            sb.append(" \n");
            for (int j = 0; j < numdata; j++) {
                sb.append(Mat[i][j]).append(" \t");
            }
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /** Logs the taxon id of the first {@code numdata} taxa of the tree, one per line. */
    public void printOrder() {
        StringBuilder sb = new StringBuilder("taxa \n");
        for (int i = 0; i < numdata; i++) {
            sb.append(" \n");
            sb.append(treeModel.getTaxonId(i));
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /**
     * Logs a single number at INFO level on the "dr.evomodel" logger.
     *
     * @param x value to print
     */
    public void printInformtion(double x) {
        StringBuilder sb = new StringBuilder("Info \n");
        sb.append(x);
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    // The model hooks below are intentionally empty: this class keeps no stored/restorable
    // state of its own and does not react to change events.

    public void makeDirty() {
    }

    public void acceptState() {
        // DO NOTHING
    }

    public void restoreState() {
        // DO NOTHING
    }

    public void storeState() {
        // DO NOTHING
    }

    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // DO NOTHING
    }

    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // DO NOTHING
    }

    // All external tips of the tree, collected once in the constructor.
    Set<NodeRef> allTips;
    CompoundParameter traitParameter;
    // Weight of a self-link; fixed to 1.0 in the constructor.
    double alpha;
    Parameter clusterVar;
    Parameter priorVar;
    Parameter priorMean;
    // Link target of every item.
    Parameter assignments;
    TreeModel treeModel;
    // Never assigned in this class.
    String traitName;
    // Trait values, one row per item.
    double[][] data;
    // Transformed tree distances 1 / d^p; diagonal is 0.
    double[][] depMatrix;
    // Logarithm of the off-diagonal entries of depMatrix; diagonal is 0.
    double[][] logDepMatrix;
    // Per-item log-densities plus one spare trailing slot.
    double[] logLikelihoodsVector;
    // Number of data items (sub-parameters of traitParameter).
    int numdata;
    double transformFactor;

    /** XML parser that builds an {@link NPAntigenicLikelihood} from its XML element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public final static String CLUSTER_VAR = "clusterVar";
        public final static String PRIOR_VAR = "priorVar";
        public final static String PRIOR_MEAN = "priorMean";
        public final static String ASSIGNMENTS = "assignments";
        public final static String TRAIT_NAME = "traitName";
        public final static String TRANSFORM_FACTOR = "transformFactor";

        // Passed straight through to the trait parsing utility.
        boolean integrate = false;

        public String getParserName() {
            return NP_ANTIGENIC_LIKELIHOOD;
        }

        /**
         * Reads the tree, the four wrapped parameters, the optional transform factor
         * (default 1.0) and the tip traits, then constructs the likelihood.
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);

            //String traitName = (String) xo.getAttribute(TRAIT_NAME);

            Parameter clusterVar = (Parameter) xo.getChild(CLUSTER_VAR).getChild(Parameter.class);
            Parameter priorVar = (Parameter) xo.getChild(PRIOR_VAR).getChild(Parameter.class);
            Parameter priorMean = (Parameter) xo.getChild(PRIOR_MEAN).getChild(Parameter.class);
            Parameter assignments = (Parameter) xo.getChild(ASSIGNMENTS).getChild(Parameter.class);

            double transformFactor = 1.0;
            if (xo.hasAttribute(TRANSFORM_FACTOR)) {
                transformFactor = xo.getDoubleAttribute(TRANSFORM_FACTOR);
            }

            TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
            String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;

            TreeTraitParserUtilities.TraitsAndMissingIndices returnValue =
                    utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, integrate);
            // traitName = returnValue.traitName;
            CompoundParameter traitParameter = returnValue.traitParameter;

            return new NPAntigenicLikelihood(treeModel, traitParameter, assignments, clusterVar, priorMean,
                    priorVar, transformFactor);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "conditional likelihood ddCRP";
        }

        public Class getReturnType() {
            return NPAntigenicLikelihood.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new StringAttributeRule(TreeTraitParserUtilities.TRAIT_NAME, "The name of the trait for which a likelihood should be calculated"),
                AttributeRule.newDoubleRule(TRANSFORM_FACTOR, true, "p in transformation of distances -p*log(dist)"),
                new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{
                        new ElementRule(Parameter.class)
                }),
                new ElementRule(PRIOR_VAR,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(CLUSTER_VAR,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(PRIOR_MEAN,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(ASSIGNMENTS,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(TreeModel.class),
        };
    };

    // Not used inside this class; kept because it is package-visible.
    String Atribute = null;
}