/* * This file is part of ADDIS (Aggregate Data Drug Information System). * ADDIS is distributed from http://drugis.org/. * Copyright © 2009 Gert van Valkenhoef, Tommi Tervonen. * Copyright © 2010 Gert van Valkenhoef, Tommi Tervonen, Tijs Zwinkels, * Maarten Jacobs, Hanno Koeslag, Florin Schimbinschi, Ahmad Kamal, Daniel * Reid. * Copyright © 2011 Gert van Valkenhoef, Ahmad Kamal, Daniel Reid, Florin * Schimbinschi. * Copyright © 2012 Gert van Valkenhoef, Daniel Reid, Joël Kuiper, Wouter * Reckman. * Copyright © 2013 Gert van Valkenhoef, Joël Kuiper. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. 
*/

package org.drugis.addis.mcmcmodel;

import gov.lanl.yadas.ArgumentMaker;
import gov.lanl.yadas.BasicMCMCBond;
import gov.lanl.yadas.ConstantArgument;
import gov.lanl.yadas.Gaussian;
import gov.lanl.yadas.GroupArgument;
import gov.lanl.yadas.IdentityArgument;
import gov.lanl.yadas.MCMCParameter;
import gov.lanl.yadas.Uniform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.drugis.addis.entities.Measurement;
import org.drugis.common.stat.EstimateWithPrecision;
import org.drugis.mtc.Parameter;
import org.drugis.mtc.summary.NormalSummary;
import org.drugis.mtc.util.DerSimonianLairdPooling;
import org.drugis.mtc.yadas.AbstractYadasModel;
import org.drugis.mtc.yadas.YadasSettings;

/**
 * Base class for YADAS-backed baseline models built from a list of measurements.
 *
 * <p>The model has one value per measurement ("study mean"), tied through a
 * {@link Gaussian} bond to two shared parameters named {@code "mu"} and
 * {@code "sd"}. Only {@code "mu"} is registered as a direct parameter of the
 * results and exposed through {@link #getSummary()}.
 *
 * <p>Subclasses supply the per-measurement estimate
 * ({@link #estimateTreatmentEffect(int)}) and the bond that ties the study
 * means to the observed data ({@link #createDataBond(MCMCParameter)}).
 *
 * @param <T> the type of measurement this model is built from
 */
public abstract class AbstractBaselineModel<T extends Measurement> extends AbstractYadasModel {
	/** The measurements the model is built from, one per study mean. */
	protected List<T> d_measurements;

	/** Source of randomness for the chain starting values. */
	protected final RandomGenerator d_rng = new JDKRandomGenerator();

	/** The shared parameter reported under the name "mu". */
	private final Parameter d_muParam = new Parameter() {
		public String getName() {
			return "mu";
		}

		@Override
		public String toString() {
			return getName();
		}
	};

	/** The shared parameter reported under the name "sd". */
	private final Parameter d_sigmaParam = new Parameter() {
		public String getName() {
			return "sd";
		}

		@Override
		public String toString() {
			return getName();
		}
	};

	/** Summary of the "mu" parameter over the model's results. */
	private final NormalSummary d_summary;

	/**
	 * Creates the model for the given measurements.
	 *
	 * @param measurements the measurements to model; the list is stored by
	 *        reference, not copied
	 */
	public AbstractBaselineModel(List<T> measurements) {
		// NOTE(review): the first two settings arguments are presumably the tuning
		// and simulation iteration counts (they match the setter calls below) --
		// confirm against YadasSettings.
		super(new YadasSettings(5000, 15000, 10, 4, 2.5));
		setTuningIterations(5000);
		setSimulationIterations(15000);
		d_results.setDirectParameters(Collections.singletonList(d_muParam));
		d_summary = new NormalSummary(d_results, d_muParam);
		d_measurements = measurements;
	}

	/**
	 * @return the summary of the "mu" parameter
	 */
	public NormalSummary getSummary() {
		return d_summary;
	}

	/**
	 * @return the two model parameters, "mu" followed by "sd"
	 */
	@Override
	protected List<Parameter> getParameters() {
		return Arrays.asList(d_muParam, d_sigmaParam);
	}

	/**
	 * Computes the scale used for the priors: twice the largest absolute
	 * difference between the point estimates of any two measurements.
	 *
	 * <p>Each point estimate is computed once and the pairwise comparison then
	 * runs on the cached values, so {@link #estimateTreatmentEffect(int)} is
	 * called once per measurement instead of once per pair.
	 *
	 * <p>NOTE(review): with fewer than two measurements, or when all point
	 * estimates coincide, this returns 0, and that value is used directly as a
	 * prior scale in {@link #createChain(int)} -- confirm callers never hit
	 * this case.
	 *
	 * @return twice the maximum pairwise deviation of the point estimates, or
	 *         0 when there are fewer than two measurements
	 */
	protected double getStandardDeviationPrior() {
		// FIXME: the factor 2 below is rather arbitrary. However, it is required to make
		// the tests pass for network-br. Until baselines can be specified explicitly, it
		// should remain there.
		final int count = d_measurements.size();
		if (count < 2) {
			// No pairs to compare; no estimates are computed in this case.
			return 0.0;
		}

		final double[] pointEstimates = new double[count];
		for (int i = 0; i < count; ++i) {
			pointEstimates[i] = estimateTreatmentEffect(i).getPointEstimate();
		}

		double maxDev = 0.0;
		for (int i = 0; i < count - 1; ++i) {
			for (int j = i + 1; j < count; ++j) {
				maxDev = Math.max(maxDev, Math.abs(pointEstimates[j] - pointEstimates[i]));
			}
		}
		return 2 * maxDev;
	}

	/**
	 * Estimates the effect for a single measurement.
	 *
	 * @param i index into the list of measurements
	 * @return the estimate (point estimate and standard error) for measurement {@code i}
	 */
	protected abstract EstimateWithPrecision estimateTreatmentEffect(int i);

	/**
	 * Creates the bond that ties the per-study means to the observed data.
	 *
	 * @param studyMu the MCMC parameter holding one mean per measurement
	 */
	protected abstract void createDataBond(MCMCParameter studyMu);

	/**
	 * Draws the starting value for "sd": a uniform random fraction of
	 * {@link #getStandardDeviationPrior()}.
	 *
	 * @return a single-element array with the starting value
	 */
	protected double[] initializeStandardDeviation() {
		return new double[] {d_rng.nextDouble() * getStandardDeviationPrior()};
	}

	/**
	 * Draws the starting value for "mu" around the DerSimonian-Laird pooled
	 * estimate of all per-measurement estimates.
	 *
	 * @return a single-element array with the starting value
	 */
	private double[] initializeMean() {
		final int count = d_measurements.size();
		List<EstimateWithPrecision> estimates = new ArrayList<EstimateWithPrecision>(count);
		for (int i = 0; i < count; ++i) {
			estimates.add(estimateTreatmentEffect(i));
		}
		EstimateWithPrecision pooled = new DerSimonianLairdPooling(estimates).getPooled();
		return new double[] {generate(pooled)};
	}

	/**
	 * Draws one starting value per measurement, each around that measurement's
	 * own estimate.
	 *
	 * @return the starting values, indexed like the list of measurements
	 */
	private double[] initializeStudyMeans() {
		double[] means = new double[d_measurements.size()];
		for (int i = 0; i < means.length; ++i) {
			means[i] = generate(estimateTreatmentEffect(i));
		}
		return means;
	}

	/**
	 * Draws a random value centred on the estimate: the point estimate plus a
	 * standard Gaussian draw scaled by the configured variance scaling factor
	 * and the estimate's standard error.
	 *
	 * @param e the estimate to draw around
	 * @return the random starting value
	 */
	private double generate(final EstimateWithPrecision e) {
		return e.getPointEstimate()
				+ d_rng.nextGaussian() * getSettings().getVarianceScalingFactor() * e.getStandardError();
	}

	/** No preparation is needed for this model. */
	@Override
	protected void prepareModel() {
	}

	/**
	 * Builds one MCMC chain: the three parameters with random starting values,
	 * the data bond, the bond between study means and ("mu", "sd"), the priors,
	 * and the tuners and writers.
	 *
	 * @param chain index of the chain being created
	 */
	@Override
	protected void createChain(int chain) {
		final int count = d_measurements.size();

		MCMCParameter studyMu = new MCMCParameter(initializeStudyMeans(), doubleArray(0.1, count), null);
		MCMCParameter mu = new MCMCParameter(initializeMean(), new double[] {0.1}, null);
		MCMCParameter sd = new MCMCParameter(initializeStandardDeviation(), new double[] {0.1}, null);

		// data bond
		createDataBond(studyMu);

		// studyMu bond
		// The all-zero index arrays presumably map every study onto element 0 of
		// the shared mu / sd parameters -- confirm against GroupArgument.
		new BasicMCMCBond(new MCMCParameter[] {studyMu, mu, sd},
				new ArgumentMaker[] {
					new IdentityArgument(0),
					new GroupArgument(1, new int[count]),
					new GroupArgument(2, new int[count])
				}, new Gaussian());

		// priors
		// The prior scale is computed once per chain and shared by both priors;
		// this assumes getStandardDeviationPrior() is deterministic.
		final double sdPrior = getStandardDeviationPrior();
		new BasicMCMCBond(new MCMCParameter[] {mu},
				new ArgumentMaker[] {
					new IdentityArgument(0),
					new ConstantArgument(0.0),
					new ConstantArgument(15 * sdPrior)
				}, new Gaussian());
		new BasicMCMCBond(new MCMCParameter[] {sd},
				new ArgumentMaker[] {
					new IdentityArgument(0),
					new ConstantArgument(0.0),
					new ConstantArgument(sdPrior)
				}, new Uniform());

		List<MCMCParameter> parameters = new ArrayList<MCMCParameter>();
		parameters.add(studyMu);
		parameters.add(mu);
		parameters.add(sd);
		addTuners(parameters);

		addWriters(Arrays.asList(
				d_results.getParameterWriter(d_muParam, chain, mu, 0),
				d_results.getParameterWriter(d_sigmaParam, chain, sd, 0)));
	}

	/**
	 * Creates an array of the given size with every element set to the same value.
	 *
	 * @param val the value for every element
	 * @param size the length of the array
	 * @return the filled array
	 */
	protected double[] doubleArray(double val, int size) {
		double[] arr = new double[size];
		Arrays.fill(arr, val);
		return arr;
	}

	/**
	 * Returns the error for a single measurement; not used by this class itself.
	 *
	 * @param i index into the list of measurements
	 * @return the error for measurement {@code i}
	 */
	protected abstract double getError(int i);
}