/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import java.util.Objects;
import java.util.Random;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.proposalKernels.BlockProposal;
import com.analog.lyric.dimple.solvers.core.proposalKernels.IBlockProposalKernel;
import com.analog.lyric.math.Utilities;
/**
 * Block proposal kernel used for testing: proposes a joint value for a block of
 * discrete variables by sampling a joint index from a fixed array of weights,
 * independently of the current value.
 * <p>
 * Joint indices are laid out in row-major (mixed-radix) order over the block:
 * the first variable is the most significant digit and the last variable varies
 * fastest. The weight at joint index {@code k} is the weight of the combination
 * of element indices returned by {@link #indexToArray(int)} for {@code k}.
 *
 * @since 0.06
 * @author jbernst2
 */
public class TrivialNonuniformBlockProposer implements IBlockProposalKernel
{
    private final double[] _weights;
    private final int[] _domainSizes;
    /**
     * {@code _strides[i]} is the joint-index stride of variable {@code i}, i.e. the
     * product of the sizes of all variables that come after it in the block.
     */
    private final int[] _strides;
    private final int _numDomains;
    private final Random _random;

    /**
     * Creates a proposer over a block of discrete variables.
     * <p>
     * Both arrays are stored by reference, not copied.
     *
     * @param weights weight of each joint state, indexed by joint index; expected to hold one
     *        entry per joint state (the product of {@code domainSizes}) - this is not verified here
     * @param domainSizes number of elements of each variable's domain, in block order
     */
    public TrivialNonuniformBlockProposer(double[] weights, int[] domainSizes)
    {
        _domainSizes = domainSizes;
        _weights = weights;
        _random = activeRandom();
        _numDomains = _domainSizes.length;
        _strides = new int[_numDomains];

        // Walk from the last variable to the first so that each variable's stride is the
        // product of the sizes of the variables after it. Deriving the strides from the
        // sizes of the *preceding* variables (as was done before) makes element i of a
        // decoded array range over the size of domain (n-1-i) instead of domain i.
        int stride = 1;
        for (int i = _numDomains - 1; i >= 0; i--)
        {
            _strides[i] = stride;
            stride *= _domainSizes[i];
        }
    }

    /**
     * Proposes a new joint value by sampling a joint index according to the weights.
     * <p>
     * The forward energy of the returned proposal is {@code -log} of the weight of the
     * sampled joint index, and the reverse energy is {@code -log} of the weight of the
     * joint index of {@code currentValue}.
     *
     * @param currentValue current values of the block's variables; only their indices are read
     * @param variableDomain domains of the block's variables, in the same order
     * @return the proposed values together with the forward and reverse proposal energies
     * @throws DimpleException if any of the domains is not discrete
     */
    @Override
    public BlockProposal next(Value[] currentValue, Domain[] variableDomain)
    {
        // Sample a joint index at random given the weights
        int newIndex = Utilities.sampleFromMultinomial(_weights, _random);
        int[] newArray = indexToArray(newIndex);
        int currentIndex = arrayToIndex(currentValue);

        Value[] newValue = new Value[_numDomains];
        for (int i = 0; i < _numDomains; i++)
        {
            Domain domain = variableDomain[i];
            if (!domain.isDiscrete())
            {
                throw new DimpleException("Not supported");
            }

            DiscreteDomain discreteDomain = Objects.requireNonNull(domain.asDiscrete());
            Value v = Value.create(discreteDomain);
            v.setIndex(newArray[i]);
            newValue[i] = v;
        }

        double proposalForwardEnergy = -Math.log(_weights[newIndex]);
        double proposalReverseEnergy = -Math.log(_weights[currentIndex]);
        return new BlockProposal(newValue, proposalForwardEnergy, proposalReverseEnergy);
    }

    /**
     * Does nothing; this kernel takes its configuration from the constructor.
     */
    @Deprecated
    @Override
    public void setParameters(Object... parameters)
    {
    }

    /**
     * @return always {@code null}
     */
    @Deprecated
    @Override
    public @Nullable Object[] getParameters()
    {
        return null;
    }

    /**
     * Decodes a joint index into one element index per variable.
     *
     * @param index joint index in row-major order over the block
     * @return a new array whose element {@code i} is the element index for variable {@code i}
     */
    protected int[] indexToArray(int index)
    {
        int[] newArray = new int[_numDomains];
        int remainder = index;
        for (int i = 0; i < _numDomains; i++)
        {
            int divisor = _strides[i];
            newArray[i] = remainder / divisor;
            remainder %= divisor;
        }
        return newArray;
    }

    /**
     * Encodes per-variable element indices into a joint index; the inverse of
     * {@link #indexToArray(int)}.
     *
     * @param array element index of each variable, in block order
     * @return the corresponding joint index
     */
    protected int arrayToIndex(int[] array)
    {
        int index = 0;
        for (int i = 0; i < _numDomains; i++)
        {
            index += array[i] * _strides[i];
        }
        return index;
    }

    /**
     * Encodes the indices of the given values into a joint index.
     *
     * @param array value of each variable, in block order; each value's index is read
     * @return the corresponding joint index
     */
    protected int arrayToIndex(Value[] array)
    {
        int index = 0;
        for (int i = 0; i < _numDomains; i++)
        {
            index += array[i].getIndex() * _strides[i];
        }
        return index;
    }

    /**
     * Self-check of the joint-index encoding: for every joint index, verifies that each
     * decoded element index lies within its own domain and that re-encoding the decoded
     * array gives back the original joint index.
     */
    public static void main(String[] args)
    {
        int[] domainSizes = new int[]{2, 5, 3, 6, 4};
        int product = 1;
        for (int size : domainSizes)
        {
            product *= size;
        }

        double[] weights = new double[product];
        for (int i = 0; i < product; i++)
        {
            weights[i] = activeRandom().nextDouble();
        }

        TrivialNonuniformBlockProposer t = new TrivialNonuniformBlockProposer(weights, domainSizes);
        for (int i = 0; i < product; i++)
        {
            int[] array = t.indexToArray(i);

            // A round trip alone cannot detect strides that are self-consistent but assigned
            // to the wrong variables, so also check every element against its domain size.
            for (int d = 0; d < domainSizes.length; d++)
            {
                if (array[d] < 0 || array[d] >= domainSizes[d])
                {
                    throw new DimpleException("Element index out of range for its domain");
                }
            }

            int j = t.arrayToIndex(array);
            if (i != j)
            {
                throw new DimpleException("!");
            }
        }
    }
}