/* * SelfControlledCaseSeries.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.regression; import dr.inference.model.AbstractModelLikelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.xml.*; import java.util.*; /** * @author Marc Suchard * @author Trevor Shaddox */ public class SelfControlledCaseSeries extends AbstractModelLikelihood { public static final String SCCS_NAME = "selfControlledCaseSeries"; public static final String FILE_NAME = "fileName"; public static final String BETA = "beta"; public static final String PRECISION = "precision"; public SelfControlledCaseSeries(String name, String fileName, Parameter beta, Parameter precision) { super(name); regressionInterface = RegressionJNIWrapper.loadLibrary(); // Load data and find mode instance = regressionInterface.loadData(fileName); regressionInterface.setPriorType(instance, RegressionJNIWrapper.NORMAL_PRIOR); this.precision = precision; setPrecision(); precisionChanged = true; // Set beta to mode final int dim = regressionInterface.getBetaSize(instance); if (dim != beta.getDimension()) { beta.setDimension(dim); } // Start beta at mode (given precision) this.beta = beta; double[] mode = getMode(); for (int i = 0; i < beta.getDimension(); ++i) { beta.setParameterValue(i, mode[i]); } logSCCSLikelihood = regressionInterface.getLogLikelihood(instance); logSCCSPrior = regressionInterface.getLogPrior(instance); betaChanged = false; // Internal state is at mode addVariable(beta); addVariable(precision); } private void setPrecision() { regressionInterface.setHyperprior(instance, 1.0 / precision.getParameterValue(0)); } public double[] getMode() { if (precisionChanged) { setPrecision(); mode = null; } if (mode == null) { regressionInterface.findMode(instance); mode = new double[beta.getDimension()]; for (int i = 0; i < beta.getDimension(); ++i) { mode[i] = regressionInterface.getBeta(instance, i); } betaChanged = true; // Internal beta-state is at mode, not betaParameter // betaFlag.clear(); newMode = true; // System.err.println("A"); if (DEBUG_MODE) { System.err.println("Recomputed mode!"); } } double[] rtn = new double[mode.length]; System.arraycopy(mode, 0, rtn, 0, mode.length); return rtn; } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { // Do nothing } /** * This method is called whenever a parameter is changed. * <p/> * It is strongly recommended that the model component sets a "dirty" flag and does no * further calculations. Recalculation is typically done when the model component is asked for * some information that requires them. This mechanism is 'lazy' so that this method * can be safely called multiple times with minimal computational cost. */ @Override protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) { if (variable == beta) { betaChanged = true; if (type == Variable.ChangeType.ALL_VALUES_CHANGED) { betaFlag.clear(); } else { betaFlag.add(index); } } else if (variable == precision) { precisionChanged = true; } else { throw new IllegalArgumentException("Unknown variable in SCCS"); } } /** * Additional state information, outside of the sub-model is stored by this call. */ @Override protected void storeState() { storedLogSCCSLikelihood = logSCCSLikelihood; storedLogSCCSPrior = logSCCSPrior; storedBetaChanged = betaChanged; storedPrecisionChanged = precisionChanged; storedPrecision = precision.getParameterValue(0); } /** * After this call the model is guaranteed to have returned its extra state information to * the values coinciding with the last storeState call. * Sub-models are handled automatically and do not need to be considered in this method. */ @Override protected void restoreState() { logSCCSLikelihood = storedLogSCCSLikelihood; logSCCSPrior = storedLogSCCSPrior; betaChanged = storedBetaChanged; precisionChanged = storedPrecisionChanged; } /** * This call specifies that the current state is accept. Most models will not need to do anything. * Sub-models are handled automatically and do not need to be considered in this method. */ @Override protected void acceptState() { if (storedPrecision != precision.getParameterValue(0)) { mode = null; // Accepted new precision state; mode has moved } } /** * Get the model. * * @return the model. */ public Model getModel() { return this; } /** * Get the log likelihood. * * @return the log likelihood. */ public double getLogLikelihood() { return calculateLogLikelihood(); } private double calculateLogLikelihood() { if (betaChanged) { if (betaFlag.isEmpty() || newMode) { regressionInterface.setBeta(instance, beta.getParameterValues()); newMode = false; } else { while (!betaFlag.isEmpty()) { final int index = betaFlag.remove(); regressionInterface.setBeta(instance, index, beta.getParameterValue(index)); } } } if (precisionChanged) { setPrecision(); } if (betaChanged) { logSCCSLikelihood = regressionInterface.getLogLikelihood(instance); } if (betaChanged || precisionChanged) { logSCCSPrior = regressionInterface.getLogPrior(instance); } betaChanged = false; precisionChanged = false; double logLike = logSCCSLikelihood + logSCCSPrior; if (DEBUG_LAZY) { double checkLike = regressionInterface.getLogLikelihood(instance); double checkPrior = regressionInterface.getLogPrior(instance); double check = checkLike + checkPrior; if (check != logLike) { System.err.println("Error in internal state in calculateLogLikelihood()"); System.err.println(checkLike + " " + logSCCSLikelihood + " d: " + (checkLike - logSCCSLikelihood)); System.err.println(checkPrior + " " + logSCCSPrior + " d: " + (checkPrior - logSCCSPrior)); System.err.println(betaChanged + " " + precisionChanged); System.exit(-1); } } return logLike; } /** * Forces a complete recalculation of the likelihood next time getLikelihood is called */ public void makeDirty() { betaChanged = true; newMode = true; precisionChanged = true; regressionInterface.makeDirty(instance); } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return SCCS_NAME; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FILE_NAME); Parameter beta = (Parameter) xo.getElementFirstChild(BETA); Parameter precision = (Parameter) xo.getElementFirstChild(PRECISION); return new SelfControlledCaseSeries(xo.getId(), fileName, beta, precision); } public String getParserDescription() { return "Self-controlled case series design."; } public Class getReturnType() { return SelfControlledCaseSeries.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule(FILE_NAME), new ElementRule(BETA, Parameter.class), new ElementRule(PRECISION, Parameter.class), }; }; private final RegressionJNIWrapper regressionInterface; private final int instance; private final Parameter beta; private final Parameter precision; private double logSCCSLikelihood; private double logSCCSPrior; private double storedLogSCCSLikelihood; private double storedLogSCCSPrior; private boolean betaChanged; private boolean precisionChanged; private boolean storedBetaChanged; private boolean storedPrecisionChanged; // private boolean[] betaFlag; // private boolean betaFlagAll; private Queue<Integer> betaFlag = new LinkedList<Integer>(); private boolean newMode = false; private double[] mode = null; private double storedPrecision; private static final boolean DEBUG_MODE = false; private static final boolean DEBUG_LAZY = false; }