/* * WeightedMixtureModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.model; import dr.inference.loggers.LogColumn; import dr.math.LogTricks; import dr.util.Citable; import dr.util.Citation; import dr.util.CommonCitations; import dr.xml.*; import java.util.*; import java.util.logging.Logger; /** * @author Marc A. Suchard * @author Andrew Rambaut */ public class WeightedMixtureModel extends AbstractModelLikelihood implements Citable { public static final String MIXTURE_MODEL = "mixtureModel"; // public static final String MIXTURE_WEIGHTS = "weights"; public static final String NORMALIZE = "normalize"; public WeightedMixtureModel(List<AbstractModelLikelihood> likelihoodList, Parameter mixtureWeights) { super(MIXTURE_MODEL); this.likelihoodList = likelihoodList; this.mixtureWeights = mixtureWeights; for (AbstractModelLikelihood model : likelihoodList) { addModel(model); } addVariable(mixtureWeights); StringBuilder sb = new StringBuilder(); sb.append("Constructing a finite mixture model\n"); sb.append("\tComponents:\n"); for (AbstractModelLikelihood model : likelihoodList) { sb.append("\t\t\t").append(model.getId()).append("\n"); } sb.append("\tMixing parameter: ").append(mixtureWeights.getId()).append("\n"); sb.append("\tPlease cite:\n"); sb.append(Citable.Utils.getCitationString((this))); Logger.getLogger("dr.inference.model").info(sb.toString()); } protected void handleModelChangedEvent(Model model, Object object, int index) { } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { } protected void storeState() { } protected void restoreState() { } protected void acceptState() { } public Model getModel() { return this; } public double getLogLikelihood() { double logSum = Double.NEGATIVE_INFINITY; for (int i = 0; i < likelihoodList.size(); ++i) { double pi = mixtureWeights.getParameterValue(i); if (pi > 0.0) { logSum = LogTricks.logSum(logSum, Math.log(pi) + likelihoodList.get(i).getLogLikelihood()); } } return logSum; } public void makeDirty() { } public LogColumn[] getColumns() { return new LogColumn[0]; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MIXTURE_MODEL; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { Parameter weights = (Parameter) xo.getChild(Parameter.class); List<AbstractModelLikelihood> likelihoodList = new ArrayList<AbstractModelLikelihood>(); for (int i = 0; i < xo.getChildCount(); i++) { if (xo.getChild(i) instanceof Likelihood) likelihoodList.add((AbstractModelLikelihood) xo.getChild(i)); } if (weights.getDimension() != likelihoodList.size()) { throw new XMLParseException("Dim of " + weights.getId() + " does not match the number of likelihoods"); } if (xo.hasAttribute(NORMALIZE)) { if (xo.getBooleanAttribute(NORMALIZE)) { double sum = 0; for (int i = 0; i < weights.getDimension(); i++) sum += weights.getParameterValue(i); for (int i = 0; i < weights.getDimension(); i++) weights.setParameterValue(i, weights.getParameterValue(i) / sum); } } if (!normalized(weights)) throw new XMLParseException("Parameter +" + weights.getId() + " must lie on the simplex"); return new WeightedMixtureModel(likelihoodList, weights); } private boolean normalized(Parameter p) { double sum = 0; for (int i = 0; i < p.getDimension(); i++) sum += p.getParameterValue(i); return (sum == 1.0); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents a finite mixture of likelihood models."; } public Class getReturnType() { return CompoundModel.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(NORMALIZE, true), new ElementRule(Likelihood.class,2,Integer.MAX_VALUE), new ElementRule(Parameter.class) }; }; private final Parameter mixtureWeights; List<AbstractModelLikelihood> likelihoodList; public static void main(String[] args) { final double l1 = -10; final double l2 = -2; AbstractModelLikelihood like1 = new AbstractModelLikelihood("dummy") { public Model getModel() { return null; } public double getLogLikelihood() { return l1; } public void makeDirty() { } public String prettyName() { return null; } public boolean isUsed() { return false; } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) { } @Override protected void storeState() { } @Override protected void restoreState() { } @Override protected void acceptState() { } public void setUsed() { } public LogColumn[] getColumns() { return new LogColumn[0]; } public String getId() { return null; } public void setId(String id) { } }; AbstractModelLikelihood like2 = new AbstractModelLikelihood("dummy") { public Model getModel() { return null; } public double getLogLikelihood() { return l2; } public void makeDirty() { } public String prettyName() { return null; } public boolean isUsed() { return false; } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) { } @Override protected void storeState() { } @Override protected void restoreState() { } @Override protected void acceptState() { } public void setUsed() { } public LogColumn[] getColumns() { return new LogColumn[0]; } public String getId() { return null; } public void setId(String id) { } }; List<AbstractModelLikelihood> likelihoodList = new ArrayList<AbstractModelLikelihood>(); likelihoodList.add(like1); likelihoodList.add(like2); Parameter weights = new Parameter.Default(2); double p1 = 0.05; weights.setParameterValue(0, p1); weights.setParameterValue(1, 1.0 - p1); WeightedMixtureModel mixture = new WeightedMixtureModel(likelihoodList, weights); System.err.println("getLogLikelihood() = " + mixture.getLogLikelihood()); double test = Math.log(p1 * Math.exp(l1) + (1.0 - p1) * Math.exp(l2)); System.err.println("correct = " + test); } @Override public Citation.Category getCategory() { return Citation.Category.MISC; } @Override public String getDescription() { return "Weighted mixture model"; } @Override public List<Citation> getCitations() { return Collections.singletonList(CommonCitations.LEMEY_MIXTURE_2012); } }