/*
* CompoundGaussianProcess.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math.distributions;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
/**
 * A {@link GaussianProcessRandomGenerator} built by concatenating several component generators.
 * <p>
 * The three constructor lists are parallel: entry {@code i} of {@code gpList} is paired with
 * entry {@code i} of {@code likelihoodList} and of {@code copyList}. A component whose paired
 * likelihood is a {@link DistributionLikelihood} is treated as univariate: it is drawn
 * {@code copies} times, each draw yielding one scalar. Any other component is drawn
 * {@code copies} times, each draw yielding a {@code double[]}. Random draws from all components
 * are concatenated in list order.
 *
 * @author Marc A. Suchard
 */
public class CompoundGaussianProcess implements GaussianProcessRandomGenerator, Reportable {

    private final List<GaussianProcessRandomGenerator> gpList;
    private final List<Integer> copyList;
    private final List<Likelihood> likelihoodList;
    private final CompoundLikelihood compoundLikelihood;
    private final ExecutorService pool;
    private final int threadCount;
    private final List<Callable<DrawResult>> callers;

    // Compile-time switch for parallel drawing. When false, no thread pool or tasks are created
    // and nextRandom() always uses the serial path.
    private static final boolean USE_POOL = false;

    // Lazily computed sum of the component dimensions; -1 means "not yet computed".
    private int dimension = -1;

    /**
     * Creates a compound process from parallel lists of components.
     * The lists are stored by reference, not copied.
     *
     * @param gpList         component generators, in output order
     * @param likelihoodList likelihood paired with each component; a {@link DistributionLikelihood}
     *                       marks the component as univariate
     * @param copyList       number of independent draws to take from each component
     */
    public CompoundGaussianProcess(List<GaussianProcessRandomGenerator> gpList, List<Likelihood> likelihoodList,
                                   List<Integer> copyList) {
        this.gpList = gpList;
        this.copyList = copyList;
        this.likelihoodList = likelihoodList;
        compoundLikelihood = new CompoundLikelihood(likelihoodList);

        if (USE_POOL) {
            callers = createTasks();
            threadCount = callers.size();
            pool = Executors.newFixedThreadPool(threadCount);
        } else {
            callers = null;
            threadCount = -1;
            pool = null;
        }
    }

    /**
     * Returns whether the given likelihood is one of this process's component likelihoods.
     *
     * @param likelihood likelihood to look up
     * @return {@code true} if it appears in the likelihood list
     */
    public boolean contains(Likelihood likelihood) {
        return likelihoodList.contains(likelihood);
    }

    /**
     * Returns the sum of {@code getDimension()} over all component generators.
     * The value is computed on first call and cached.
     * <p>
     * NOTE(review): copy counts are not included here, whereas {@link #nextRandom()} draws each
     * component {@code copies} times. Confirm this is the intended meaning of "dimension".
     *
     * @return total component dimension
     */
    public int getDimension() {
        if (dimension == -1) {
            int total = 0;
            for (GaussianProcessRandomGenerator gp : gpList) {
                total += gp.getDimension();
            }
            dimension = total;
        }
        return dimension;
    }

    /**
     * Returns the block-diagonal precision matrix formed from the component precision matrices.
     * With a single component, that component's matrix is returned directly (not copied).
     * <p>
     * NOTE(review): each component contributes one block regardless of its copy count. Verify
     * this against how callers use the result.
     *
     * @return precision matrix of size {@link #getDimension()} squared
     */
    @Override
    public double[][] getPrecisionMatrix() {
        if (gpList.size() == 1) {
            return gpList.get(0).getPrecisionMatrix();
        }

        final int dim = getDimension();
        final double[][] precision = new double[dim][dim];

        int offset = 0;
        for (GaussianProcessRandomGenerator gp : gpList) {
            final int d = gp.getDimension();
            final double[][] p = gp.getPrecisionMatrix();
            // Place this component's block on the diagonal; off-diagonal blocks stay zero.
            for (int i = 0; i < d; ++i) {
                System.arraycopy(p[i], 0, precision[offset + i], offset, d);
            }
            offset += d;
        }
        return precision;
    }

    /**
     * @return the compound likelihood built from all component likelihoods
     */
    @Override
    public Likelihood getLikelihood() {
        return compoundLikelihood;
    }

    /**
     * @return a one-line report containing the current compound log-likelihood
     */
    @Override
    public String getReport() {
        return new StringBuilder()
                .append("compoundGP: ")
                .append(getLikelihood().getLogLikelihood())
                .toString();
    }

    /** A block of drawn values together with its destination offset in the output vector. */
    private static class DrawResult {
        final double[] result;
        final int offset;

        DrawResult(double[] result, int offset) {
            this.result = result;
            this.offset = offset;
        }
    }

    /** Task that draws from one component and tags the result with its output offset. */
    private static class DrawCaller implements Callable<DrawResult> {

        private final GaussianProcessRandomGenerator gp;
        private final int copies;
        private final int offset;
        private final boolean isUnivariate;

        DrawCaller(GaussianProcessRandomGenerator gp, int copies, int offset, boolean isUnivariate) {
            this.gp = gp;
            this.copies = copies;
            this.offset = offset;
            this.isUnivariate = isUnivariate;
        }

        public DrawResult call() throws Exception {
            final double[] vector = isUnivariate
                    ? drawScalars(gp, copies)
                    : (double[]) gp.nextRandom();
            return new DrawResult(vector, offset);
        }
    }

    /** Returns whether component {@code index} is univariate, i.e. paired with a {@link DistributionLikelihood}. */
    private boolean isUnivariate(int index) {
        return likelihoodList.get(index) instanceof DistributionLikelihood;
    }

    /** Draws {@code copies} scalars from a univariate component; each draw is cast to {@code Double}. */
    private static double[] drawScalars(GaussianProcessRandomGenerator gp, int copies) {
        final double[] vector = new double[copies];
        for (int i = 0; i < copies; ++i) {
            vector[i] = (Double) gp.nextRandom();
        }
        return vector;
    }

    /**
     * Builds one task per univariate component and one task per copy of each multivariate
     * component. Offsets follow the same layout that {@link #nextRandomSerial()} produces.
     */
    private List<Callable<DrawResult>> createTasks() {
        final List<Callable<DrawResult>> tasks = new ArrayList<Callable<DrawResult>>();
        int offset = 0;
        int index = 0;
        for (GaussianProcessRandomGenerator gp : gpList) {
            final int copies = copyList.get(index);
            if (isUnivariate(index)) {
                tasks.add(new DrawCaller(gp, copies, offset, true));
                offset += copies;
            } else {
                for (int i = 0; i < copies; ++i) {
                    tasks.add(new DrawCaller(gp, 1, offset, false));
                    offset += gp.getDimension();
                }
            }
            ++index; // previously missing: every component was read at index 0
        }
        return tasks;
    }

    /**
     * Draws from every component and concatenates the results in list order.
     *
     * @return a {@code double[]} of concatenated draws
     */
    @Override
    public Object nextRandom() {
        if (USE_POOL) {
            return nextRandomParallel();
        } else {
            return nextRandomSerial();
        }
    }

    /**
     * Parallel draw using the thread pool.
     *
     * @throws RuntimeException if interrupted (the interrupt flag is restored) or if any task fails
     */
    private Object nextRandomParallel() {
        final List<DrawResult> draws = new ArrayList<DrawResult>(callers.size());
        int length = 0;
        try {
            for (Future<DrawResult> future : pool.invokeAll(callers)) {
                final DrawResult draw = future.get();
                draws.add(draw);
                // Size the output from the task layout (which accounts for copies) rather than
                // getDimension() (which does not), so the copies below cannot overflow.
                length = Math.max(length, draw.offset + draw.result.length);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while drawing from compound Gaussian process", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Failed to draw from compound Gaussian process", e.getCause());
        }

        final double[] vector = new double[length];
        for (DrawResult draw : draws) {
            System.arraycopy(draw.result, 0, vector, draw.offset, draw.result.length);
        }
        return vector;
    }

    /** Serial draw: draws each component in order, then concatenates the pieces into one array. */
    private Object nextRandomSerial() {
        final List<double[]> randomList = new ArrayList<double[]>();
        int size = 0;
        int index = 0;
        for (GaussianProcessRandomGenerator gp : gpList) {
            final int copies = copyList.get(index);
            if (isUnivariate(index)) {
                final double[] vector = drawScalars(gp, copies);
                randomList.add(vector);
                size += vector.length;
            } else {
                for (int i = 0; i < copies; ++i) {
                    final double[] vector = (double[]) gp.nextRandom();
                    randomList.add(vector);
                    size += vector.length;
                }
            }
            ++index;
        }

        final double[] result = new double[size];
        int offset = 0;
        for (double[] vector : randomList) {
            System.arraycopy(vector, 0, result, offset, vector.length);
            offset += vector.length;
        }
        return result;
    }

    /**
     * Not supported by this compound process.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double logPdf(Object x) {
        throw new UnsupportedOperationException("Not yet implemented");
    }
}