/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import static java.util.Objects.*;
import java.util.List;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Dirichlet;
import com.analog.lyric.dimple.factorfunctions.ExchangeableDirichlet;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DirichletParameters;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState;
import com.analog.lyric.math.DimpleRandom;
/**
 * Conjugate Gibbs sampler for a real-joint variable whose inputs and neighboring factor-to-variable
 * messages are all Dirichlet-parameterized.
 * <p>
 * The parameters of every input and every incoming message are combined into a single
 * {@link DirichletParameters} aggregate. A sample is then drawn by taking one Gamma draw per component
 * and normalizing the draws to sum to one.
 */
public class DirichletSampler implements IRealJointConjugateSampler
{
	// Scratch aggregate reused by nextSample(edges, inputs); resized once the dimension is known.
	private final DirichletParameters _parameters = new DirichletParameters();

	// Dimension of the sampled vector; -1 until successfully determined by setDimension.
	private int _dimension = -1;

	/**
	 * Aggregates the parameters from {@code edges} and {@code inputs} into this sampler's scratch
	 * parameters and draws a sample from the result.
	 *
	 * @param edges solver edges whose factor-to-variable messages are {@link DirichletParameters}
	 * @param inputs additional Dirichlet inputs to the variable
	 * @return a newly allocated sample vector
	 * @throws DimpleException if the dimension cannot be determined or the inputs/messages disagree on it
	 */
	@Override
	public final double[] nextSample(ISolverEdgeState[] edges, List<? extends IDatum> inputs)
	{
		aggregateParameters(_parameters, edges, inputs);
		return nextSample(_parameters);
	}

	/**
	 * Combines the Dirichlet parameters of all recognized inputs and all factor-to-variable messages into
	 * {@code aggregateParameters}, using {@link DirichletParameters#addFrom}.
	 * <p>
	 * The aggregate is first resized if needed and reset with {@code setNull()}. A message of size zero is
	 * treated as uninitialized and uninformative. Such a message is resized to the variable's dimension
	 * and set null <em>in place</em>, and it contributes nothing to the aggregate.
	 *
	 * @param aggregateParameters destination; must be a {@link DirichletParameters} (it is cast)
	 * @param edges solver edges whose factor-to-variable messages are {@link DirichletParameters}
	 * @param inputs inputs to the variable; only {@link DirichletParameters}, {@link Dirichlet} and
	 *        {@link ExchangeableDirichlet} inputs contribute, and any other input is skipped
	 * @throws DimpleException if an input or initialized message has a different dimension than the variable
	 */
	@Override
	public final void aggregateParameters(IParameterizedMessage aggregateParameters, ISolverEdgeState[] edges,
		List<? extends IDatum> inputs)
	{
		if (_dimension < 0) // Just do this once
			setDimension(edges, inputs);

		final int dimension = _dimension;

		DirichletParameters parameters = (DirichletParameters)aggregateParameters;
		if (parameters.getSize() != dimension)
			parameters.setSize(dimension);
		parameters.setNull();

		for (IDatum input : inputs)
		{
			DirichletParameters inputParameters = null;
			if (input instanceof DirichletParameters)
				inputParameters = (DirichletParameters)input;
			else if (input instanceof Dirichlet)
				inputParameters = ((Dirichlet)input).getParameterizedMessage();
			else if (input instanceof ExchangeableDirichlet)
				inputParameters = ((ExchangeableDirichlet)input).getParameterizedMessage();
			else
				continue; // This should not be possible

			if (inputParameters.getSize() != dimension)
				throw new DimpleException("All inputs to Dirichlet sampler must have the same number of dimensions");

			parameters.addFrom(inputParameters);
		}

		for (ISolverEdgeState edge : edges)
		{
			// The message from each neighboring factor is a vector of Dirichlet parameters, one per dimension
			DirichletParameters message = requireNonNull((DirichletParameters)edge.getFactorToVarMsg());
			int messageSize = message.getSize();

			if (messageSize == 0) // Uninitialized message, which implies uninformative
			{
				message.setSize(dimension);
				message.setNull();
				continue;
			}
			else if (messageSize != dimension)
				throw new DimpleException("All inputs to Dirichlet sampler must have the same number of dimensions");

			parameters.addFrom(message);
		}
	}

	/**
	 * Draws a sample from the Dirichlet distribution described by {@code parameters}.
	 * <p>
	 * Each component is drawn from Gamma(shape = {@code getAlphaMinusOne(i) + 1}, scale = 1), and the draws
	 * are divided by their sum.
	 * <p>
	 * Zero draws can occur, e.g. from numerical underflow, and are handled as follows:
	 * <ul>
	 * <li>If some components are zero, they are set to {@link Double#MIN_VALUE} and the rest are normalized.</li>
	 * <li>If all components are zero, one uniformly chosen component is set to (nearly) one and the others to
	 * {@link Double#MIN_VALUE}.</li>
	 * </ul>
	 *
	 * @param parameters Dirichlet parameters; its size determines the length of the result
	 * @return a newly allocated sample vector of length {@code parameters.getSize()}
	 */
	public final double[] nextSample(DirichletParameters parameters)
	{
		final DimpleRandom rand = activeRandom();

		// Sample from a series of Gamma distributions, then normalize to sum to 1
		int dimension = parameters.getSize();
		double[] sample = new double[dimension];
		double sum = 0;
		int numZeros = 0;
		for (int i = 0; i < dimension; i++)
		{
			double nextSample = rand.nextGamma(parameters.getAlphaMinusOne(i) + 1, 1);
			sample[i] = nextSample;
			sum += nextSample;
			if (nextSample == 0)
				numZeros++;
		}

		if (numZeros == 0)
		{
			for (int i = 0; i < dimension; i++)
				sample[i] /= sum;
		}
		else if (numZeros < dimension)
		{
			// Corner case where some, but not all, of the samples are zero
			// Add a little to the zero sample values and adjust the others accordingly
			double zeroAdjustment = Double.MIN_VALUE * numZeros / (dimension - numZeros);
			for (int i = 0; i < dimension; i++)
			{
				if (sample[i] == 0)
					sample[i] = Double.MIN_VALUE;
				else
					sample[i] = (sample[i] / sum) - zeroAdjustment;
			}
		}
		else
		{
			// Corner case where all samples were zero
			// Choose one sample value at random, make that (nearly) one, and the others (nearly) zero
			int randomChoice = rand.nextInt(dimension);
			for (int i = 0; i < dimension; i++)
				if (i != randomChoice)
					sample[i] = Double.MIN_VALUE;
			sample[randomChoice] = 1 - Double.MIN_VALUE * (dimension - 1);
		}

		return sample;
	}

	/**
	 * Creates an empty parameter message of the type this sampler aggregates into.
	 *
	 * @return a new {@link DirichletParameters}
	 */
	@Override
	public IParameterizedMessage createParameterMessage()
	{
		return new DirichletParameters();
	}

	/**
	 * Determines the dimension of the sampled variable and sizes the scratch parameters accordingly.
	 * <p>
	 * The first factor-to-variable message with a non-zero size is used. Size-zero (uninitialized) messages
	 * are skipped because they are resized later in {@link #aggregateParameters}. If no edge message is
	 * initialized, the first input from which a dimension can be read is used.
	 *
	 * @throws DimpleException if no edge or input yields a non-zero dimension
	 */
	@SuppressWarnings("null")
	private void setDimension(ISolverEdgeState[] sedges, List<? extends IDatum> inputs)
	{
		int dimension = 0;

		for (ISolverEdgeState sedge : sedges)
		{
			Object message = sedge.getFactorToVarMsg();
			if (message instanceof DirichletParameters)
			{
				dimension = ((DirichletParameters)message).getSize();
				if (dimension > 0)
					break;
			}
		}

		if (dimension == 0)
		{
			for (IDatum input : inputs)
			{
				dimension = inputDimension(input);
				if (dimension > 0)
					break;
			}
		}

		if (dimension == 0)
		{
			throw new DimpleException("Cannot determine Dirichlet dimension from edges or inputs.");
		}

		_parameters.setSize(dimension);
		_dimension = dimension;
	}

	/**
	 * Returns the dimension implied by a single input, or zero if it cannot be determined from this input.
	 */
	@SuppressWarnings("null")
	private static int inputDimension(IDatum input)
	{
		if (input instanceof DirichletParameters)
			return ((DirichletParameters)input).getSize();
		else if (input instanceof Value)
			return ((Value)input).getDoubleArray().length;
		else if (input instanceof Dirichlet)
			return ((Dirichlet)input).getDimension();
		else if (input instanceof ExchangeableDirichlet)
			return ((ExchangeableDirichlet)input).getDimension();
		else
			return 0;
	}

	// A static factory that creates a sampler of this type
	public static final IRealJointConjugateSamplerFactory factory = new IRealJointConjugateSamplerFactory()
	{
		@Override
		public IRealJointConjugateSampler create() {return new DirichletSampler();}

		/**
		 * Returns true when there is no factor function, or it is Dirichlet-parameterized
		 * ({@link Dirichlet}, {@link DirichletParameters} or {@link ExchangeableDirichlet}).
		 */
		@Override
		public boolean isCompatible(@Nullable IUnaryFactorFunction factorFunction)
		{
			return factorFunction == null
				|| factorFunction instanceof Dirichlet
				|| factorFunction instanceof DirichletParameters
				|| factorFunction instanceof ExchangeableDirichlet;
		}

		/**
		 * Returns true only if every component domain includes the whole interval [0, 1], which any
		 * Dirichlet sample component can take.
		 */
		@Override
		public boolean isCompatible(RealJointDomain domain)
		{
			for (RealDomain d : domain.getRealDomains())
			{
				if (d.getLowerBound() > 0 || d.getUpperBound() < 1)
					return false;
			}
			return true;
		}
	};
}