/*
* DirichletProcessPrior.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchmodel.lineagespecific;
import java.util.ArrayList;
import java.util.List;
import dr.app.bss.Utils;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.model.Variable.ChangeType;
/**
 * Likelihood term for a Dirichlet process prior over category assignments.
 *
 * <p>The log-likelihood returned by {@link #getLogLikelihood()} is the sum of
 * <ul>
 * <li>the log density of the vector of category assignments
 * ({@link #getCategoriesLogDensity()}), and</li>
 * <li>the log density, under {@link #baseModel}, of every parameter held in the
 * compound parameter of realized values
 * ({@link #getRealizedValuesLogDensity()}).</li>
 * </ul>
 *
 * <p>The number of categories is fixed at construction time to the dimension of
 * the compound parameter of realized values; category assignments are expected
 * to be integers in {@code [0, categoryCount)}.
 */
@SuppressWarnings("serial")
public class DirichletProcessPrior extends AbstractModelLikelihood {

    // Shared by all instances; switched on through setVerbose().
    private static boolean VERBOSE = false;

    // vector z of category (cluster) assignments, one entry per element
    private final Parameter categoriesParameter;
    // one sub-parameter per category
    private final CompoundParameter uniquelyRealizedParameters;
    // base distribution used to score each realized parameter; may be null
    // (see main), in which case only getCategoriesLogDensity() is usable
    public ParametricMultivariateDistributionModel baseModel;
    // concentration parameter; only its first value is read
    private final Parameter gamma;
    // K: number of categories (dimension of uniquelyRealizedParameters)
    private final int categoryCount;
    // N: number of assigned elements (dimension of categoriesParameter)
    private final int N;

    private boolean likelihoodKnown = false;
    private double logLikelihood;

    // cachedLogFactorials.get(i) == log(i!); grown lazily by getLogFactorial
    private final List<Double> cachedLogFactorials;

    /**
     * Builds the prior and registers all parameters (and the base model, when
     * present) so that their change events reach this model.
     *
     * @param categoriesParameter        category assignment of every element
     * @param uniquelyRealizedParameters one sub-parameter per category; its
     *                                   dimension defines the category count
     * @param baseModel                  distribution scoring the realized
     *                                   parameters, or {@code null}
     * @param gamma                      concentration parameter
     */
    public DirichletProcessPrior(Parameter categoriesParameter, //
            CompoundParameter uniquelyRealizedParameters, //
            ParametricMultivariateDistributionModel baseModel, //
            Parameter gamma //
    ) {

        super("");

        // vector z of cluster assignments
        this.categoriesParameter = categoriesParameter;
        this.baseModel = baseModel;
        this.uniquelyRealizedParameters = uniquelyRealizedParameters;
        this.gamma = gamma;

        // K clusters
        this.categoryCount = uniquelyRealizedParameters.getDimension();
        this.N = categoriesParameter.getDimension();

        cachedLogFactorials = new ArrayList<Double>();
        cachedLogFactorials.add(0, 0.0); // log(0!) = 0

        // listen to everything the density depends on
        this.addVariable(this.categoriesParameter);
        this.addVariable(this.gamma);
        this.addVariable(this.uniquelyRealizedParameters);

        if (baseModel != null) {
            this.addModel(baseModel);
        }

        this.likelihoodKnown = false;
    }// END: Constructor

    /**
     * Returns {@code log(i!)}, extending the cache up to index {@code i} when
     * necessary.
     *
     * @param i non-negative integer
     * @return natural logarithm of {@code i!}
     */
    private double getLogFactorial(int i) {
        // Fill exactly the missing entries: log(j!) = log((j - 1)!) + log(j).
        for (int j = cachedLogFactorials.size(); j <= i; j++) {
            cachedLogFactorials.add(cachedLogFactorials.get(j - 1) + Math.log(j));
        }
        return cachedLogFactorials.get(i);
    }

    /**
     * Counts how many elements are assigned to each category.
     * Assumes mappings start from index 0.
     *
     * @return array of length {@code categoryCount} with the occupancy of each
     *         category
     * @throws ArrayIndexOutOfBoundsException if an assignment lies outside
     *                                        {@code [0, categoryCount)}
     */
    private int[] getCounts() {

        // eta_k parameters (number of assignments to each category)
        int[] counts = new int[categoryCount];

        for (int i = 0; i < N; i++) {
            int category = getMapping(i);
            if (category < 0 || category >= categoryCount) {
                // same exception type the raw array access would raise, but
                // with enough context to locate the offending assignment
                throw new ArrayIndexOutOfBoundsException("Category " + category
                        + " assigned to element " + i + " is outside [0, "
                        + categoryCount + ")");
            }
            counts[category]++;
        }// END: i loop

        return counts;
    }// END: getCounts

    /**
     * @return the first value of the concentration parameter
     */
    public double getGamma() {
        return gamma.getParameterValue(0);
    }

    /**
     * @param i index of an element
     * @return category assigned to element {@code i}, obtained by truncating
     *         the stored value to an {@code int}
     */
    private int getMapping(int i) {
        return (int) categoriesParameter.getParameterValue(i);
    }

    /**
     * Scores a single parameter under the base model.
     *
     * @param parameter parameter whose values are evaluated
     * @return {@code baseModel.logPdf} of the parameter's values
     */
    public double getLogDensity(Parameter parameter) {
        double value[] = parameter.getAttributeValue();
        return baseModel.logPdf(value);
    }

    /**
     * @return sum over all categories of the base-model log density of the
     *         corresponding realized parameter
     */
    public double getRealizedValuesLogDensity() {

        double total = 0.0;

        for (int i = 0; i < categoryCount; i++) {
            Parameter param = uniquelyRealizedParameters.getParameter(i);
            total += getLogDensity(param);
        }

        return total;
    }// END: getRealizedValuesLogDensity

    /**
     * Log density of the current category assignments:
     *
     * <pre>
     * K * log(gamma)
     *   + sum over occupied categories k of log((eta_k - 1)!)
     *   - sum for i = 1..N of log(gamma + i - 1)
     * </pre>
     *
     * where {@code K} is {@code categoryCount}, {@code eta_k} is the number of
     * elements assigned to category {@code k} and {@code N} is the number of
     * elements.
     *
     * <p>NOTE(review): {@code K} counts every category, including empty ones.
     * The usual partition probability uses the number of occupied categories;
     * the two differ whenever a category is empty and gamma != 1. Confirm this
     * is intended before changing it.
     *
     * @return log density of the assignment vector
     */
    public double getCategoriesLogDensity() {

        int[] counts = getCounts();

        if (VERBOSE) {
            Utils.printArray(counts);
        }

        // read once: the value cannot change during this evaluation
        final double gammaValue = getGamma();

        double loglike = categoryCount * Math.log(gammaValue);

        for (int k = 0; k < categoryCount; k++) {
            int eta = counts[k];
            if (eta > 0) {
                loglike += getLogFactorial(eta - 1);
            }
        }// END: k loop

        for (int i = 1; i <= N; i++) {
            loglike -= Math.log(gammaValue + i - 1);
        }// END: i loop

        return loglike;
    }// END: getCategoriesLogDensity

    @Override
    public Model getModel() {
        return this;
    }

    /**
     * Recomputes and returns the log-likelihood.
     *
     * <p>The value is recomputed on every call and a model-changed event is
     * fired first; the cached value is never reused.
     * NOTE(review): this looks like a deliberate safeguard rather than an
     * oversight, so it is kept; the invalidation hooks in this class now keep
     * {@code likelihoodKnown} accurate should caching be enabled later.
     *
     * @return current log-likelihood
     */
    @Override
    public double getLogLikelihood() {

        this.fireModelChanged();

        logLikelihood = calculateLogLikelihood();
        likelihoodKnown = true;

        return logLikelihood;
    }

    /**
     * @return log density of the assignments plus log density of the realized
     *         values
     */
    private double calculateLogLikelihood() {
        return getCategoriesLogDensity() + getRealizedValuesLogDensity();
    }// END: calculateLogLikelihood

    /**
     * Marks the stored log-likelihood as stale.
     */
    @Override
    public void makeDirty() {
        likelihoodKnown = false;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        likelihoodKnown = false;
    }

    /**
     * @return number of categories, fixed at construction
     */
    public int getCategoryCount() {
        return categoryCount;
    }

    /**
     * @return the compound parameter holding one sub-parameter per category
     */
    public Parameter getUniqueParameters() {
        return uniquelyRealizedParameters;
    }

    /**
     * @param index category index
     * @return the sub-parameter realized for that category
     */
    public Parameter getUniqueParameter(int index) {
        return uniquelyRealizedParameters.getParameter(index);
    }

    /**
     * Invalidates the stored log-likelihood and notifies listeners when any of
     * the registered parameters changes.
     *
     * @throws IllegalArgumentException if the event comes from a variable this
     *                                  model did not register
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index,
            ChangeType type) {

        boolean known = variable == categoriesParameter //
                || variable == gamma //
                || variable == uniquelyRealizedParameters;

        if (!known) {
            throw new IllegalArgumentException("Unknown parameter");
        }

        // all three parameters enter the density, so each change makes the
        // stored value stale
        likelihoodKnown = false;
        this.fireModelChanged();
    }// END: handleVariableChangedEvent

    /**
     * Turns on printing of the category counts each time
     * {@link #getCategoriesLogDensity()} is evaluated. The flag is static, so
     * it affects every instance.
     */
    public void setVerbose() {
        VERBOSE = true;
    }

    @Override
    protected void storeState() {
        // nothing to store: the log-likelihood is recomputed on demand
    }

    @Override
    protected void restoreState() {
        likelihoodKnown = false;
    }

    @Override
    protected void acceptState() {
        // nothing to do
    }

    /**
     * Prints the computed and the expected assignment log density for a few
     * small configurations.
     */
    public static void main(String args[]) {

        testDirichletProcess(new double[] { 0, 1, 2 }, 3, 1.0, -Math.log(6.0));
        testDirichletProcess(new double[] { 0, 0, 1 }, 3, 1.0, -Math.log(6.0));
        testDirichletProcess(new double[] { 0, 1, 2, 3, 4 }, 5, 0.5, -6.851184927493743);

    }// END: main

    /**
     * Builds a prior without a base model for the given assignments and prints
     * its assignment log density next to the expected value.
     *
     * @param mapping       category assignment of every element
     * @param categoryCount number of categories to create
     * @param gamma         concentration parameter value
     * @param expectedLogL  value printed for comparison
     */
    private static void testDirichletProcess(double[] mapping, int categoryCount, double gamma,
            double expectedLogL) {

        Parameter categoriesParameter = new Parameter.Default(mapping);
        Parameter gammaParameter = new Parameter.Default(gamma);

        // placeholder realized values: only their number matters here
        CompoundParameter dummy = new CompoundParameter("dummy");
        for (int i = 0; i < categoryCount; i++) {
            dummy.addParameter(new Parameter.Default(1.0));
        }

        DirichletProcessPrior dpp = new DirichletProcessPrior(
                categoriesParameter, dummy, null, gammaParameter);
        dpp.setVerbose();

        System.out.println("lnL: " + dpp.getCategoriesLogDensity());
        System.out.println("expected lnL: " + expectedLogL);
    }// END: testDirichletProcess

}// END: class