/******************************************************************************* * Copyright 2013 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.gibbs.samplers.generic; import static com.analog.lyric.dimple.environment.DimpleEnvironment.*; import org.apache.commons.math3.random.RandomGenerator; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.values.DiscreteValue; import com.analog.lyric.math.Utilities; public class SuwaTodoSampler extends AbstractGenericSampler implements IDiscreteDirectSampler { protected double[] _samplerScratch = ArrayUtil.EMPTY_DOUBLE_ARRAY; protected int _lengthRoundedUp = 0; protected int _length = 0; @Override public void initialize(Domain variableDomain) { int length = ((DiscreteDomain)variableDomain).size(); _length = length; _lengthRoundedUp = Utilities.nextPow2(length); _samplerScratch = new double[_lengthRoundedUp]; } @Override public void nextSample(DiscreteValue sampleValue, double[] energy, double minEnergy, IDiscreteSamplerClient samplerClient) { RandomGenerator rand = activeRandom(); final int length = sampleValue.getDomain().size(); // energy may be longer than domain size int sampleIndex; // Special-case length 2 for speed // This case is equivalent to MH if (length == 2) { final int previousIndex = sampleValue.getIndex(); final double pdf0 = Math.exp(minEnergy - energy[0]); final double pdf1 = Math.exp(minEnergy - energy[1]); if (previousIndex == 0) { double rejectProb = pdf0 - pdf1; if (rejectProb < 0) sampleIndex = 1; // Flip else if (rand.nextDouble() < rejectProb) sampleIndex = 0; else sampleIndex = 1; // Flip } else { double rejectProb = pdf1 - pdf0; if (rejectProb < 0) sampleIndex = 0; // Flip if (rand.nextDouble() < rejectProb) sampleIndex = 1; else sampleIndex = 0; // Flip } } else // For all other lengths { // Calculate cumulative conditional probability (unnormalized) double sum = 0; final double[] samplerScratch = _samplerScratch; final int previousIndex = sampleValue.getIndex(); double previousIntervalValue = 0; samplerScratch[0] = 0; for (int m = 1; m < length; m++) { final int mm1 = m - 1; final double unnormalizedValue = Math.exp(minEnergy-energy[mm1]); if (mm1 == previousIndex) previousIntervalValue = unnormalizedValue; sum += unnormalizedValue; samplerScratch[m] = sum; } final int lm1 = length - 1; final double unnormalizedValue = Math.exp(minEnergy-energy[lm1]); if (previousIndex == lm1) previousIntervalValue = unnormalizedValue; sum += unnormalizedValue; for (int m = length; m < _lengthRoundedUp; m++) samplerScratch[m] = Double.POSITIVE_INFINITY; // Sample from a range circularly shifted by the largest interval with size of the previous value interval // In this scale, the largest interval is always 1 double randomValue = samplerScratch[previousIndex] + 1 + previousIntervalValue * rand.nextDouble(); randomValue = randomValue % sum; // Circularly wrap // Sample from the CDF using a binary search final int half = _lengthRoundedUp >> 1; sampleIndex = 0; for (int bitValue = half; bitValue > 0; bitValue >>= 1) { final int testIndex = sampleIndex | bitValue; if (randomValue > samplerScratch[testIndex]) sampleIndex = testIndex; } } samplerClient.setNextSampleIndex(sampleIndex); } }