/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs.customFactors;
import static java.util.Objects.*;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.Multinomial;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableBlock;
import com.analog.lyric.dimple.schedulers.schedule.IGibbsSchedule;
import com.analog.lyric.dimple.schedulers.scheduleEntry.BlockScheduleEntry;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DirichletParameters;
import com.analog.lyric.dimple.solvers.gibbs.GibbsDirichletEdge;
import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete;
import com.analog.lyric.dimple.solvers.gibbs.GibbsRealFactor;
import com.analog.lyric.dimple.solvers.gibbs.GibbsRealJoint;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverEdge;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.GibbsVariableBlock;
import com.analog.lyric.dimple.solvers.gibbs.samplers.block.BlockMHInitializer;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.DirichletSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealJointConjugateSamplerFactory;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
/**
 * Gibbs solver custom factor for the {@link Multinomial} factor function.
 * <p>
 * The factor offers a conjugate Dirichlet sampler on the alpha parameter edge (when alpha is a
 * variable) and, on {@link #initialize()}, registers a block sampler driven by a
 * {@link MultinomialBlockProposal} that covers the N variable (when N is not constant) together
 * with all non-constant output variables.
 * <p>
 * Argument layout of the factor function, as encoded by the index constants below:
 * <ul>
 * <li>N fixed in the factor-function constructor: {@code alpha, output...}</li>
 * <li>otherwise: {@code N, alpha, output...}</li>
 * </ul>
 */
public class CustomMultinomial extends GibbsRealFactor implements IRealJointConjugateFactor, MultinomialBlockProposal.ICustomMultinomial
{
    // Non-constant output variables, stored compactly in argument order (constants are skipped)
    private @Nullable GibbsDiscrete[] _outputVariables;
    // Solver variable for N; null when N is constant
    private @Nullable GibbsDiscrete _NVariable;
    // Solver variable for alpha; null when alpha is constant
    private @Nullable GibbsRealJoint _alphaVariable;
    // Total number of outputs (variable plus constant)
    private int _dimension;
    // Sibling number of the alpha edge, or NO_PORT when alpha is constant
    private int _alphaParameterEdge;
    // Constant value of N; reset to -1 by determineConstantsAndEdges() when N is not constant
    private int _constantN;
    private @Nullable double[] _constantAlpha;
    // Constant output counts, stored compactly in argument order (variables are skipped)
    private @Nullable int[] _constantOutputCounts;
    private boolean _hasConstantN;
    private boolean _hasConstantAlpha;
    // True if at least one output is constant
    private boolean _hasConstantOutputs;
    // Per-output flag (length _dimension); only allocated when _hasConstantOutputs is true
    private @Nullable boolean[] _hasConstantOutput;

    private static final int NO_PORT = -1;
    private static final int ALPHA_PARAMETER_INDEX_FIXED_N = 0; // If N is in constructor then alpha is first index (0)
    private static final int OUTPUT_MIN_INDEX_FIXED_N = 1; // If N is in constructor then output starts at second index (1)
    private static final int N_PARAMETER_INDEX = 0; // If N is not in constructor then N is the first index (0)
    private static final int ALPHA_PARAMETER_INDEX = 1; // If N is not in constructor then alpha is second index (1)
    private static final int OUTPUT_MIN_INDEX = 2; // If N is not in constructor then output starts at third index (2)

    /**
     * Constructs the custom factor.
     *
     * @param factor the model factor this solver factor is attached to
     * @param parent the Gibbs solver graph that owns this solver factor
     */
    public CustomMultinomial(Factor factor, GibbsSolverGraph parent)
    {
        super(factor, parent);
    }

    /**
     * Creates the solver edge state for the given model edge.
     *
     * @param edge the model edge state
     * @return a {@link GibbsDirichletEdge} of size {@code _dimension} when the edge is the alpha
     *         parameter edge, otherwise {@code null}
     */
    @Override
    public @Nullable GibbsSolverEdge<?> createEdge(EdgeState edge)
    {
        // NOTE(review): _alphaParameterEdge and _dimension keep their default value of 0 until
        // determineConstantsAndEdges() has run; confirm that edges created before that point are
        // rebuilt (determineConstantsAndEdges() calls removeSiblingEdgeState() only when the alpha
        // edge number changes).
        if (edge.getFactorToVariableEdgeNumber() == _alphaParameterEdge)
        {
            return new GibbsDirichletEdge(_dimension);
        }
        return null;
    }

    /**
     * Updates the factor-to-variable message on the given edge.
     * <p>
     * For the alpha parameter edge, the outgoing {@link DirichletParameters} message is reset and
     * then, for each output dimension, incremented by either the constant output count or the
     * current sample index of the corresponding output variable. All other edges are delegated to
     * the superclass.
     *
     * @param modelEdge  the model edge state
     * @param solverEdge the solver edge state holding the message to update
     */
    @SuppressWarnings("null")
    @Override
    public void updateEdgeMessage(EdgeState modelEdge, GibbsSolverEdge<?> solverEdge)
    {
        final int portNum = modelEdge.getFactorToVariableEdgeNumber();
        if (portNum == _alphaParameterEdge)
        {
            // Output port is the joint alpha parameter input
            // Determine sample alpha vector of the conjugate Dirichlet distribution
            DirichletParameters outputMsg = (DirichletParameters)solverEdge.factorToVarMsg;

            // Clear the output counts
            outputMsg.setNull(_dimension);

            // Get the current output counts
            if (!_hasConstantOutputs)
            {
                for (int i = 0; i < _dimension; i++)
                    outputMsg.add(i, _outputVariables[i].getCurrentSampleIndex());
            }
            else // Some or all outputs are constant
            {
                // _constantOutputCounts and _outputVariables are compact arrays, so each has its
                // own cursor that advances only when an entry of that kind is consumed
                for (int i = 0, iVar = 0, iConst = 0; i < _dimension; i++)
                    outputMsg.add(i, _hasConstantOutput[i] ? _constantOutputCounts[iConst++] : _outputVariables[iVar++].getCurrentSampleIndex());
            }
        }
        else
            super.updateEdgeMessage(modelEdge, solverEdge);
    }

    /**
     * Returns the conjugate samplers available on the given port.
     *
     * @param portNumber the sibling (port) number to query
     * @return a new set containing {@link DirichletSampler#factory} when the port is the alpha
     *         parameter edge, otherwise an empty set
     */
    @Override
    public Set<IRealJointConjugateSamplerFactory> getAvailableRealJointConjugateSamplers(int portNumber)
    {
        Set<IRealJointConjugateSamplerFactory> availableSamplers = new HashSet<IRealJointConjugateSamplerFactory>();
        if (isPortAlphaParameter(portNumber)) // Conjugate sampler if edge is alpha parameter input
            availableSamplers.add(DirichletSampler.factory); // Parameter inputs have conjugate Dirichlet distribution
        return availableSamplers;
    }

    /**
     * Indicates whether the given port is the alpha parameter edge. Recomputes the
     * constant/edge bookkeeping first.
     *
     * @param portNumber the sibling (port) number to test
     * @return true if {@code portNumber} is the alpha parameter edge
     */
    public boolean isPortAlphaParameter(int portNumber)
    {
        determineConstantsAndEdges(); // Call this here since initialize may not have been called yet
        return (portNumber == _alphaParameterEdge);
    }

    // For MultinomialBlockProposal.ICustomMultinomial interface

    /**
     * Returns a copy of the current alpha vector: the constant alpha when alpha is constant,
     * otherwise the current sample of the alpha variable.
     */
    @SuppressWarnings("null")
    @Override
    public final double[] getCurrentAlpha()
    {
        return (_hasConstantAlpha ? _constantAlpha : _alphaVariable.getCurrentSample()).clone();
    }

    /**
     * @return always false for this factor
     */
    @Override
    public final boolean isAlphaEnergyRepresentation()
    {
        return false;
    }

    /**
     * @return true if N was found to be constant by the most recent bookkeeping pass
     */
    @Override
    public final boolean hasConstantN()
    {
        return _hasConstantN;
    }

    /**
     * @return the constant value of N recorded by the most recent bookkeeping pass; that pass
     *         sets it to -1 when N is not constant
     */
    @Override
    public final int getN()
    {
        return _constantN;
    }

    /**
     * Initializes the factor: refreshes the constant/edge bookkeeping, builds a variable block
     * from N (when not constant) followed by the non-constant output variables, and registers a
     * block sampler for that block both in the root graph's schedule and as a block initializer.
     */
    @Override
    public void initialize()
    {
        super.initialize();

        // Determine what parameters are constants or edges, and save the state
        determineConstantsAndEdges();

        // Create a block schedule entry whose sampler uses a MultinomialBlockProposal kernel.
        // The block lists N first (only when N is a variable), then every output variable.
        final GibbsDiscrete[] outputVariables = Objects.requireNonNull(_outputVariables);
        Variable[] nodeList = new Variable[outputVariables.length + (_hasConstantN ? 0 : 1)];
        int nodeIndex = 0;
        if (!_hasConstantN)
            nodeList[nodeIndex++] = Objects.requireNonNull(_NVariable).getModelObject();
        for (int i = 0; i < outputVariables.length; i++, nodeIndex++)
            nodeList[nodeIndex] = outputVariables[i].getModelObject();

        GibbsSolverGraph parent = getParentGraph();
        VariableBlock block = parent.getModel().addVariableBlock(nodeList);
        GibbsVariableBlock sblock = requireNonNull(parent.getSolverVariableBlock(block, true));
        BlockMHInitializer blockSampler = new BlockMHInitializer(sblock, new MultinomialBlockProposal(this));
        BlockScheduleEntry blockScheduleEntry = new BlockScheduleEntry(blockSampler, block);

        // Add the block updater to the schedule
        GibbsSolverGraph rootGraph = (GibbsSolverGraph)parent.getRootSolverGraph(); // FIXME don't assume root
        IGibbsSchedule schedule = rootGraph.getSchedule(); // Assumes scheduler for Gibbs solver is flattened to root graph
        schedule.addBlockScheduleEntry(blockScheduleEntry);

        // Use the block sampler to initialize the neighboring variables
        rootGraph.addBlockInitializer(blockSampler);
    }

    /**
     * Inspects the model factor and records, for N, alpha and each output, whether it is a
     * constant or a variable edge, caching the constant values and solver variables.
     * <p>
     * If the alpha edge number differs from the one recorded before this call, the sibling edge
     * state is removed via {@code removeSiblingEdgeState()}.
     */
    private void determineConstantsAndEdges()
    {
        final Factor factor = _model;
        FactorFunction factorFunction = factor.getFactorFunction();
        Multinomial specificFactorFunction = (Multinomial)factorFunction;
        final int prevAlphaParameterEdge = _alphaParameterEdge;

        // Pre-determine whether or not the parameters are constant
        List<? extends Variable> siblings = factor.getSiblings();
        int alphaParameterIndex;
        int outputMinIndex;
        _constantN = -1;
        _NVariable = null;
        if (specificFactorFunction.hasConstantNParameter()) // N parameter is constructor constant
        {
            _hasConstantN = true;
            _constantN = specificFactorFunction.getN();
            alphaParameterIndex = ALPHA_PARAMETER_INDEX_FIXED_N;
            outputMinIndex = OUTPUT_MIN_INDEX_FIXED_N;
        }
        else // Variable or constant N parameter
        {
            _hasConstantN = factor.hasConstantAtIndex(N_PARAMETER_INDEX);
            if (_hasConstantN)
                _constantN = requireNonNull(factor.getConstantValueByIndex(N_PARAMETER_INDEX)).getInt();
            else
                _NVariable = (GibbsDiscrete)getSibling(factor.argIndexToSiblingNumber(N_PARAMETER_INDEX));
            alphaParameterIndex = ALPHA_PARAMETER_INDEX;
            outputMinIndex = OUTPUT_MIN_INDEX;
        }

        // Save the alpha parameter constant or variables
        _hasConstantAlpha = false;
        _constantAlpha = null;
        _alphaVariable = null;
        _alphaParameterEdge = NO_PORT;
        if (factor.hasConstantAtIndex(alphaParameterIndex))
        {
            _hasConstantAlpha = true;
            _constantAlpha =
                requireNonNull(factor.getConstantValueByIndex(alphaParameterIndex)).getDoubleArray();
        }
        else
        {
            _alphaParameterEdge = factor.argIndexToSiblingNumber(alphaParameterIndex);
            _alphaVariable = (GibbsRealJoint)getSibling(_alphaParameterEdge);
        }

        final SolverNodeMapping solvers = getSolverMapping();

        // Save the output constant or variables as well
        final int nEdges = getSiblingCount();
        int numOutputEdges = nEdges - factor.argIndexToSiblingNumber(outputMinIndex);
        final GibbsDiscrete[] outputVariables = _outputVariables = new GibbsDiscrete[numOutputEdges];
        _hasConstantOutputs = factor.hasConstantAtOrAboveIndex(outputMinIndex);
        _constantOutputCounts = null;
        _hasConstantOutput = null;
        _dimension = -1;
        if (_hasConstantOutputs)
        {
            int numConstantOutputs = factor.numConstantsAtOrAboveIndex(outputMinIndex);
            _dimension = numOutputEdges + numConstantOutputs;
            final boolean[] hasConstantOutput = _hasConstantOutput = new boolean[_dimension];
            final int[] constantOutputCounts = _constantOutputCounts = new int[numConstantOutputs];

            // constantOutputCounts and outputVariables are shorter than _dimension whenever
            // constants and variables are mixed, so each is filled through its own cursor. This
            // matches how updateEdgeMessage() reads them and keeps outputVariables free of gaps
            // for initialize().
            for (int i = 0, index = outputMinIndex, iVar = 0, iConst = 0; i < _dimension; i++, index++)
            {
                if (factor.hasConstantAtIndex(index))
                {
                    hasConstantOutput[i] = true;
                    constantOutputCounts[iConst++] = requireNonNull(factor.getConstantValueByIndex(index)).getInt();
                }
                else
                {
                    hasConstantOutput[i] = false;
                    int outputEdge = factor.argIndexToSiblingNumber(index);
                    outputVariables[iVar++] = (GibbsDiscrete)solvers.getSolverVariable(siblings.get(outputEdge));
                }
            }
        }
        else // No constant outputs
        {
            _dimension = numOutputEdges;
            for (int i = 0, index = outputMinIndex; i < _dimension; i++, index++)
            {
                int outputEdge = factor.argIndexToSiblingNumber(index);
                outputVariables[i] = (GibbsDiscrete)solvers.getSolverVariable(siblings.get(outputEdge));
            }
        }

        // The alpha edge moved, so any previously created edge state no longer matches
        if (_alphaParameterEdge != prevAlphaParameterEdge)
        {
            removeSiblingEdgeState();
        }
    }
}