/******************************************************************************* * Copyright 2013 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.gibbs.customFactors; import static java.util.Objects.*; import java.util.HashSet; import java.util.List; import java.util.Set; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.MultinomialEnergyParameters; import com.analog.lyric.dimple.factorfunctions.MultinomialUnnormalizedParameters; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableBlock; import com.analog.lyric.dimple.schedulers.schedule.IGibbsSchedule; import com.analog.lyric.dimple.schedulers.scheduleEntry.BlockScheduleEntry; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.GammaParameters; import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete; import com.analog.lyric.dimple.solvers.gibbs.GibbsGammaEdge; import com.analog.lyric.dimple.solvers.gibbs.GibbsReal; import com.analog.lyric.dimple.solvers.gibbs.GibbsRealFactor; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverEdge; import 
com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.GibbsVariableBlock;
import com.analog.lyric.dimple.solvers.gibbs.samplers.block.BlockMHInitializer;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.GammaSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealConjugateSamplerFactory;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.NegativeExpGammaSampler;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;

/**
 * Custom Gibbs solver factor for factors whose factor function is either
 * {@link MultinomialUnnormalizedParameters} or {@link MultinomialEnergyParameters}.
 * <p>
 * The factor arguments are laid out as: optional N (only when N is not fixed in the factor function
 * constructor), followed by {@code _dimension} alpha parameters, followed by the outputs. Any of
 * these may be constants rather than variables; {@link #determineConstantsAndEdges()} records which.
 * <p>
 * Edges to alpha parameters carry {@link GammaParameters} messages, and the N/output variables are
 * registered as a block that is sampled with a {@link MultinomialBlockProposal}.
 */
public class CustomMultinomialUnnormalizedOrEnergyParameters extends GibbsRealFactor
    implements IRealConjugateFactor, MultinomialBlockProposal.ICustomMultinomial
{
    // Solver variables for the outputs, N, and the alpha parameters (entries for constants stay null)
    private @Nullable GibbsDiscrete[] _outputVariables;
    private @Nullable GibbsDiscrete _NVariable;
    private @Nullable GibbsReal[] _alphaVariables;

    private int _dimension;
    private int _alphaParameterMinIndex; // Argument index of the first alpha parameter
    private int _alphaParameterMinEdge;  // Sibling (edge) number of the first alpha parameter
    private int _alphaParameterMaxEdge;  // Sibling (edge) number of the last alpha parameter
    private int _constantN;
    private @Nullable int[] _constantOutputCounts;
    private boolean _hasConstantN;
    private boolean _hasConstantOutputs;
    private boolean _hasConstantAlphas;
    private @Nullable boolean[] _hasConstantAlpha;
    private @Nullable double[] _constantAlpha;
    private boolean _useEnergyParameters;

    private static final int ALPHA_PARAMETER_MIN_INDEX_FIXED_N = 0; // If N is in constructor then alpha is first index (0)
    private static final int N_PARAMETER_INDEX = 0;                 // If N is not in constructor then N is the first index (0)
    private static final int ALPHA_PARAMETER_MIN_INDEX = 1;         // If N is not in constructor then alpha is second index (1)

    /**
     * Creates the solver factor; all construction work is delegated to the superclass.
     *
     * @param factor the model factor this solver object is attached to
     * @param parent the Gibbs solver graph that owns this solver factor
     */
    public CustomMultinomialUnnormalizedOrEnergyParameters(Factor factor, GibbsSolverGraph parent)
    {
        super(factor, parent);
    }

    /**
     * Creates the solver edge for the given model edge.
     *
     * @return a new {@link GibbsGammaEdge} if the edge lies in the current alpha parameter edge
     *         range, otherwise {@code null}
     */
    @Override
    public @Nullable GibbsSolverEdge<?> createEdge(EdgeState edge)
    {
        if (isAlphaEdge(edge.getFactorToVariableEdgeNumber()))
        {
            return new GibbsGammaEdge();
        }

        return null;
    }

    /**
     * Updates the factor-to-variable message on the given edge.
     * <p>
     * For alpha parameter edges the message is a {@link GammaParameters} whose alpha-minus-one is
     * the count taken from the corresponding output and whose beta is zero. All other edges are
     * handled by the superclass.
     */
    @SuppressWarnings("null")
    @Override
    public void updateEdgeMessage(EdgeState modelEdge, GibbsSolverEdge<?> solverEdge)
    {
        final int portNum = modelEdge.getFactorToVariableEdgeNumber();
        if (isAlphaEdge(portNum))
        {
            // Output port is a parameter input
            // Determine sample alpha and beta parameters
            // NOTE: This case works for either MultinomialUnnormalizedParameters or MultinomialEnergyParameters
            // factor functions since the actual parameter value doesn't come into play in determining the
            // message in this direction

            GammaParameters outputMsg = (GammaParameters)solverEdge.factorToVarMsg;

            // The parameter being updated corresponds to this value
            int parameterOffset = _model.siblingNumberToArgIndex(portNum) - _alphaParameterMinIndex;

            // Get the count from the corresponding output
            int count = _hasConstantOutputs
                ? _constantOutputCounts[parameterOffset]
                : _outputVariables[parameterOffset].getCurrentSampleIndex();

            outputMsg.setAlphaMinusOne(count); // Sample alpha
            outputMsg.setBeta(0);              // Sample beta
        }
        else
        {
            super.updateEdgeMessage(modelEdge, solverEdge);
        }
    }

    /**
     * Returns the conjugate samplers usable on the given port.
     *
     * @param portNumber the edge (sibling) number to query
     * @return a set containing {@link NegativeExpGammaSampler#factory} (energy parameters) or
     *         {@link GammaSampler#factory} (unnormalized parameters) when the port is an alpha
     *         parameter; an empty set otherwise
     */
    @Override
    public Set<IRealConjugateSamplerFactory> getAvailableRealConjugateSamplers(int portNumber)
    {
        Set<IRealConjugateSamplerFactory> availableSamplers = new HashSet<IRealConjugateSamplerFactory>();

        if (isPortAlphaParameter(portNumber)) // Conjugate sampler if edge is alpha parameter input
        {
            if (_useEnergyParameters)
            {
                // Parameter inputs have conjugate negative exp-Gamma distribution
                availableSamplers.add(NegativeExpGammaSampler.factory);
            }
            else
            {
                // Parameter inputs have conjugate Gamma distribution
                availableSamplers.add(GammaSampler.factory);
            }
        }

        return availableSamplers;
    }

    /**
     * Indicates whether the given port is one of the alpha parameter edges.
     * <p>
     * Recomputes the constant/edge bookkeeping first, since {@link #initialize()} may not have been
     * called yet.
     */
    public boolean isPortAlphaParameter(int portNumber)
    {
        determineConstantsAndEdges(); // Call this here since initialize may not have been called yet
        return isAlphaEdge(portNumber);
    }

    // For MultinomialBlockProposal.ICustomMultinomial interface

    /**
     * Returns a newly allocated array of length {@code _dimension} holding, for each alpha, either
     * its constant value or the current sample of its variable.
     */
    @SuppressWarnings("null")
    @Override
    public final double[] getCurrentAlpha()
    {
        double[] alphas = new double[_dimension];
        if (_hasConstantAlphas)
        {
            for (int i = 0; i < _dimension; i++)
            {
                alphas[i] = _hasConstantAlpha[i] ? _constantAlpha[i] : _alphaVariables[i].getCurrentSample();
            }
        }
        else // Only variable alphas
        {
            for (int i = 0; i < _dimension; i++)
            {
                alphas[i] = _alphaVariables[i].getCurrentSample();
            }
        }
        return alphas;
    }

    /** Returns {@code true} when the factor function is {@link MultinomialEnergyParameters}. */
    @Override
    public final boolean isAlphaEnergyRepresentation()
    {
        return _useEnergyParameters;
    }

    /** Returns {@code true} when N is fixed, either in the factor function or as a constant argument. */
    @Override
    public final boolean hasConstantN()
    {
        return _hasConstantN;
    }

    /**
     * Returns the stored constant N.
     * NOTE(review): presumably only meaningful when {@link #hasConstantN()} is true -- confirm with callers.
     */
    @Override
    public final int getN()
    {
        return _constantN;
    }

    /**
     * Initializes the solver factor: refreshes the constant/edge bookkeeping, groups the N variable
     * (if any) and the output variables into a variable block, and registers a block sampler for it
     * both as a schedule entry and as a block initializer on the root solver graph.
     */
    @SuppressWarnings("null")
    @Override
    public void initialize()
    {
        super.initialize();

        // Determine what parameters are constants or edges, and save the state
        determineConstantsAndEdges();

        // Create a block schedule entry with a BlockMHSampler and a MultinomialBlockProposal kernel
        Variable[] nodeList = new Variable[_outputVariables.length + (_hasConstantN ? 0 : 1)];
        int nodeIndex = 0;
        if (!_hasConstantN)
        {
            nodeList[nodeIndex++] = _NVariable.getModelObject();
        }
        for (int i = 0; i < _outputVariables.length; i++, nodeIndex++)
        {
            nodeList[nodeIndex] = _outputVariables[i].getModelObject();
        }

        GibbsSolverGraph parent = getParentGraph();
        VariableBlock block = parent.getModel().addVariableBlock(nodeList);
        GibbsVariableBlock sblock = requireNonNull(parent.getSolverVariableBlock(block, true));
        BlockMHInitializer blockSampler = new BlockMHInitializer(sblock, new MultinomialBlockProposal(this));
        BlockScheduleEntry blockScheduleEntry = new BlockScheduleEntry(blockSampler, block);

        // Add the block updater to the schedule
        GibbsSolverGraph rootGraph = (GibbsSolverGraph)parent.getRootSolverGraph(); // FIXME don't assume root
        IGibbsSchedule schedule = rootGraph.getSchedule(); // Assumes scheduler for Gibbs solver is flattened to root graph
        schedule.addBlockScheduleEntry(blockScheduleEntry);

        // Use the block sampler to initialize the neighboring variables
        rootGraph.addBlockInitializer(blockSampler);
    }

    /** Tests a sibling number against the currently stored alpha parameter edge range (inclusive). */
    private boolean isAlphaEdge(int portNumber)
    {
        return portNumber >= _alphaParameterMinEdge && portNumber <= _alphaParameterMaxEdge;
    }

    /**
     * Inspects the model factor and records which of N, the alpha parameters and the outputs are
     * constants versus variables, along with the argument-index and edge-number ranges of the alpha
     * parameters. If the alpha edge range differs from the previously stored one, the sibling edge
     * state is removed.
     *
     * @throws DimpleException if the factor function is neither
     *         {@link MultinomialUnnormalizedParameters} nor {@link MultinomialEnergyParameters}
     */
    private void determineConstantsAndEdges()
    {
        final int prevAlphaParameterMinEdge = _alphaParameterMinEdge;
        final int prevAlphaParameterMaxEdge = _alphaParameterMaxEdge;

        final Factor factor = _model;
        final FactorFunction factorFunction = factor.getFactorFunction();
        boolean hasFactorFunctionConstructorConstantN;
        if (factorFunction instanceof MultinomialUnnormalizedParameters)
        {
            MultinomialUnnormalizedParameters specificFactorFunction =
                (MultinomialUnnormalizedParameters)factorFunction;
            hasFactorFunctionConstructorConstantN = specificFactorFunction.hasConstantNParameter();
            _dimension = specificFactorFunction.getDimension();
            _constantN = specificFactorFunction.getN();
            _useEnergyParameters = false;
        }
        else if (factorFunction instanceof MultinomialEnergyParameters)
        {
            MultinomialEnergyParameters specificFactorFunction = (MultinomialEnergyParameters)factorFunction;
            hasFactorFunctionConstructorConstantN = specificFactorFunction.hasConstantNParameter();
            _dimension = specificFactorFunction.getDimension();
            _constantN = specificFactorFunction.getN();
            _useEnergyParameters = true;
        }
        else
        {
            throw new DimpleException("Invalid factor function");
        }

        // Pre-determine whether or not the parameters are constant
        List<? extends Variable> siblings = factor.getSiblings();
        _NVariable = null;
        _hasConstantOutputs = false;
        _outputVariables = null;

        _alphaParameterMinIndex = hasFactorFunctionConstructorConstantN
            ? ALPHA_PARAMETER_MIN_INDEX_FIXED_N
            : ALPHA_PARAMETER_MIN_INDEX;
        int alphaParameterMaxIndex = _alphaParameterMinIndex + _dimension - 1;
        int outputMinIndex = alphaParameterMaxIndex + 1;

        if (hasFactorFunctionConstructorConstantN)
        {
            _hasConstantN = true;
        }
        else // Variable or constant N
        {
            _hasConstantN = factor.hasConstantAtIndex(N_PARAMETER_INDEX);
            if (_hasConstantN)
            {
                _constantN = requireNonNull(factor.getConstantValueByIndex(N_PARAMETER_INDEX)).getInt();
            }
            else
            {
                _NVariable = (GibbsDiscrete)getSibling(factor.argIndexToSiblingNumber(N_PARAMETER_INDEX));
            }
        }

        final SolverNodeMapping solvers = getSolverMapping();

        // Save the alpha parameter constant or variables as well
        _hasConstantAlphas = false;
        _hasConstantAlpha = null;
        _constantAlpha = null;
        final GibbsReal[] alphaVariables = _alphaVariables = new GibbsReal[_dimension];

        // Both edge bounds are derived from argument indices. (The minimum was previously derived
        // from the stale _alphaParameterMinEdge, which is an edge number, not an argument index.)
        _alphaParameterMinEdge = factor.argIndexToSiblingNumber(_alphaParameterMinIndex);
        _alphaParameterMaxEdge = factor.argIndexToSiblingNumber(alphaParameterMaxIndex);

        if (factor.hasConstantsInIndexRange(_alphaParameterMinIndex, alphaParameterMaxIndex)) // Some constant alphas
        {
            _hasConstantAlphas = true;
            final boolean[] hasConstantAlpha = _hasConstantAlpha = new boolean[_dimension];
            final double[] constantAlpha = _constantAlpha = new double[_dimension];
            for (int i = 0, index = _alphaParameterMinIndex; i < _dimension; i++, index++)
            {
                if (factor.hasConstantAtIndex(index))
                {
                    hasConstantAlpha[i] = true;
                    constantAlpha[i] = requireNonNull(factor.getConstantValueByIndex(index)).getDouble();
                }
                else
                {
                    hasConstantAlpha[i] = false;
                    int alphaEdge = factor.argIndexToSiblingNumber(index);
                    alphaVariables[i] = (GibbsReal)solvers.getSolverVariable(siblings.get(alphaEdge));
                }
            }
        }
        else // No constant alphas
        {
            for (int i = 0, index = _alphaParameterMinIndex; i < _dimension; i++, index++)
            {
                int alphaEdge = factor.argIndexToSiblingNumber(index);
                alphaVariables[i] = (GibbsReal)solvers.getSolverVariable(siblings.get(alphaEdge));
            }
        }

        // Save the output constant or variables as well
        final int nEdges = getSiblingCount();
        int numOutputEdges = nEdges - factor.argIndexToSiblingNumber(outputMinIndex);
        final GibbsDiscrete[] outputVariables = _outputVariables = new GibbsDiscrete[numOutputEdges];
        _hasConstantOutputs = factor.hasConstantAtOrAboveIndex(outputMinIndex);
        _constantOutputCounts = null;
        if (_hasConstantOutputs)
        {
            // NOTE(review): the arrays here are sized by the number of constant outputs / output
            // edges but indexed by i < _dimension, so a mix of constant and variable outputs looks
            // like it would overrun one of them -- confirm whether the mixed case is meant to be
            // supported.
            int numConstantOutputs = factor.numConstantsAtOrAboveIndex(outputMinIndex);
            final int[] constantOutputCounts = _constantOutputCounts = new int[numConstantOutputs];
            for (int i = 0, index = outputMinIndex; i < _dimension; i++, index++)
            {
                if (factor.hasConstantAtIndex(index))
                {
                    constantOutputCounts[i] = requireNonNull(factor.getConstantValueByIndex(index)).getInt();
                }
                else
                {
                    int outputEdge = factor.argIndexToSiblingNumber(index);
                    outputVariables[i] = (GibbsDiscrete)solvers.getSolverVariable(siblings.get(outputEdge));
                }
            }
        }
        else // No constant outputs
        {
            for (int i = 0, index = outputMinIndex; i < _dimension; i++, index++)
            {
                int outputEdge = factor.argIndexToSiblingNumber(index);
                outputVariables[i] = (GibbsDiscrete)solvers.getSolverVariable(siblings.get(outputEdge));
            }
        }

        // The set of alpha edges changed, so previously created solver edge state is stale
        if (_alphaParameterMaxEdge != prevAlphaParameterMaxEdge
            || _alphaParameterMinEdge != prevAlphaParameterMinEdge)
        {
            removeSiblingEdgeState();
        }
    }
}