/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs.samplers.generic;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import org.apache.commons.math3.random.RandomGenerator;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.values.DiscreteValue;
import com.analog.lyric.math.Utilities;
public class CDFSampler extends AbstractGenericSampler implements IDiscreteDirectSampler
{
protected double[] _samplerScratch = ArrayUtil.EMPTY_DOUBLE_ARRAY;
protected int _lengthRoundedUp = 0;
protected int _length = 0;
@Override
public void initialize(Domain variableDomain)
{
int length = ((DiscreteDomain)variableDomain).size();
_length = length;
_lengthRoundedUp = Utilities.nextPow2(length);
_samplerScratch = new double[_lengthRoundedUp];
}
@Override
public void nextSample(DiscreteValue sampleValue, double[] energy, double minEnergy, IDiscreteSamplerClient samplerClient)
{
final RandomGenerator rand = activeRandom();
final int length = sampleValue.getDomain().size(); //energy may be longer than domain size
int sampleIndex;
// Special-case lengths 2, 3, and 4 for speed
switch (length)
{
case 2:
{
sampleIndex = (rand.nextDouble() * (1 + Math.exp(energy[1]-energy[0])) > 1) ? 0 : 1;
break;
}
case 3:
{
final double cumulative1 = Math.exp(minEnergy-energy[0]);
final double cumulative2 = cumulative1 + Math.exp(minEnergy-energy[1]);
final double sum = cumulative2 + Math.exp(minEnergy-energy[2]);
final double randomValue = sum * rand.nextDouble();
sampleIndex = (randomValue > cumulative2) ? 2 : (randomValue > cumulative1) ? 1 : 0;
break;
}
case 4:
{
final double cumulative1 = Math.exp(minEnergy-energy[0]);
final double cumulative2 = cumulative1 + Math.exp(minEnergy-energy[1]);
final double cumulative3 = cumulative2 + Math.exp(minEnergy-energy[2]);
final double sum = cumulative3 + Math.exp(minEnergy-energy[3]);
final double randomValue = sum * rand.nextDouble();
sampleIndex = (randomValue > cumulative2) ? ((randomValue > cumulative3) ? 3 : 2) : ((randomValue > cumulative1) ? 1 : 0);
break;
}
default: // For all other lengths
{
// Calculate cumulative conditional probability (unnormalized)
double sum = 0;
final double[] samplerScratch = _samplerScratch;
samplerScratch[0] = 0;
for (int m = 1; m < length; m++)
{
sum += expApprox(minEnergy-energy[m-1]);
samplerScratch[m] = sum;
}
sum += expApprox(minEnergy-energy[length-1]);
for (int m = length; m < _lengthRoundedUp; m++)
samplerScratch[m] = Double.POSITIVE_INFINITY;
final int half = _lengthRoundedUp >> 1;
while (true)
{
// Sample from the distribution using a binary search.
final double randomValue = sum * rand.nextDouble();
sampleIndex = 0;
for (int bitValue = half; bitValue > 0; bitValue >>= 1)
{
final int testIndex = sampleIndex | bitValue;
if (randomValue > samplerScratch[testIndex]) sampleIndex = testIndex;
}
// Rejection sampling, since the approximation of the exponential function is so coarse
final double logp = minEnergy-energy[sampleIndex];
if (Double.isNaN(logp))
throw new DimpleException("The energy for all values of this variable is infinite. This may indicate a state inconsistent with the model.");
if (rand.nextDouble()*expApprox(logp) <= Math.exp(logp)) break;
}
}
}
samplerClient.setNextSampleIndex(sampleIndex);
}
// This is an approximation to the exponential function; inputs must be non-positive
// To facilitate subsequent rejection sampling, the error versus the correct exponential function needs to be always positive
// This is true except for very large negative inputs, for values just as the output approaches zero
// To ensure rejection is never in an infinite loop, this must reach 0 for large negative inputs before the Math.exp function does
public final static double expApprox(double value)
{
// Convert input to base2 log, then convert integer part into IEEE754 exponent
final long expValue = (long)((int)(1512775.395195186 * value) + 0x3FF00000) << 32; // 1512775.395195186 = 2^20/log(2)
return Double.longBitsToDouble(expValue & ~(expValue >> 63)); // Clip result if negative and convert to a double
}
}