/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.gibbs; import java.util.Collection; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.DiscreteValue; import com.analog.lyric.dimple.model.values.IndexedValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.solvers.core.STableFactorBase; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; import com.analog.lyric.util.misc.Matlab; /** * Solver table factor for Gibbs solver. * * @since 0.07 */ public class GibbsTableFactor extends STableFactorBase implements ISolverFactorGibbs { /*------- * State */ protected Value[] _currentSamples = new DiscreteValue[0]; protected boolean _isDeterministicDirected; private boolean _visited = false; private int _topologicalOrder = 0; /*-------------- * Construction */ public GibbsTableFactor(Factor factor, GibbsSolverGraph parent) { super(factor, parent); _isDeterministicDirected = _model.getFactorFunction().isDeterministicDirected(); } /*--------------- * SNode methods */ @Override public void doUpdateEdge(int outPortNum) { // The Gibbs solver doesn't directly update factors, but the equivalent is instead done calls from variables // This is ignored and doesn't throw an error so that a custom schedule that updates factors won't cause a problem } @Override protected void doUpdate() { // The Gibbs solver doesn't directly update factors, but the equivalent is instead done calls from variables // This is ignored and doesn't throw an error so that a custom schedule that updates factors won't cause a problem } @Override public GibbsSolverGraph getParentGraph() { return (GibbsSolverGraph)_parent; } @Override public GibbsDiscrete getSibling(int edge) { return (GibbsDiscrete)super.getSibling(edge); } /*---------------------------- * ISolverFactorGibbs methods */ @Override public @Nullable GibbsSolverEdge<?> createEdge(EdgeState edge) { return null; } @Override public void updateEdgeMessage(EdgeState modelEdge, GibbsSolverEdge<?> solverEdge) { // Generate message representing conditional distribution of selected edge variable // This should be called only for a table factor that is not a deterministic directed factor if (_isDeterministicDirected) throw new DimpleException("Invalid call to updateEdge"); final Factor factor = _model; final int outPortNum = modelEdge.getFactorToVariableEdgeNumber(); final int outIndex = factor.siblingNumberToArgIndex(outPortNum); final double[] outMessage = ((DiscreteMessage)solverEdge.factorToVarMsg).representation(); final IFactorTable factorTable = getFactorTableIfComputed(); if (factorTable != null) { factorTable.getEnergySlice(outMessage, outPortNum, samplesForFactorTable()); } else { final Value changedValue = _currentSamples[outIndex]; final FactorFunction function = _model.getFactorFunction(); final int savedIndex = changedValue.getIndex(); final int sliceLength = outMessage.length; changedValue.setIndex(0); outMessage[0] = function.evalEnergy(_currentSamples); if (function.useUpdateEnergy(_currentSamples, 1)) { final Value prevValue = changedValue.clone(); final IndexedValue[] changedValues = new IndexedValue[] { new IndexedValue(outIndex, prevValue) }; double energy = outMessage[0]; for (int i = 1; i < sliceLength; ++i) { changedValue.setIndex(i); prevValue.setIndex(i - 1); outMessage[i] = energy = function.updateEnergy(_currentSamples, changedValues, energy); } } else { for (int i = 1; i < sliceLength; ++i) { changedValue.setIndex(i); outMessage[i] = function.evalEnergy(_currentSamples); } } changedValue.setIndex(savedIndex); } } @Override public double getPotential() { if (_isDeterministicDirected) return 0; final int size = _currentSamples.length; if (size == 0) { // Probably because initalize() not yet called. return Double.POSITIVE_INFINITY; } IFactorTable factorTable = getFactorTableIfComputed(); if (factorTable == null) { // Avoid creating table because it may be very large. // FIXME - think more about this. Should this be conditional on something? final double energy = getFactor().getFactorFunction().evalEnergy(_currentSamples); if (energy != energy) // Faster isNaN return Double.POSITIVE_INFINITY; return energy; } return factorTable.getEnergyForValues(samplesForFactorTable()); } @Matlab @Deprecated public double getPotential(int[] inputs) { return getFactorTable().getEnergyForIndices(inputs); } @Override public final int getTopologicalOrder() { return _topologicalOrder ; } @Override public final void setTopologicalOrder(int order) { _topologicalOrder = order; } @Override public void updateNeighborVariableValue(int variableIndex, Value oldValue) { final int argIndex = _model.siblingNumberToArgIndex(variableIndex); ((GibbsSolverGraph)getRootSolverGraph()).scheduleDeterministicDirectedUpdate(this, argIndex, oldValue); } @Override public void updateNeighborVariableValuesNow(@Nullable Collection<IndexedValue> oldValues) { // Compute the output values of the deterministic factor function from the input values final Factor factor = _model; Value[] values = factor.getFactorFunction().evalDeterministicToCopy(_currentSamples); // Update the directed-to variables with the computed values SolverNodeMapping solvers = getSolverMapping(); int[] directedTo = factor.getDirectedTo(); if (directedTo != null) { for (int to : directedTo) { final int outputIndex = factor.siblingNumberToArgIndex(to); Variable variable = factor.getSibling(to); // FIXME: is sample value already set? Just need to handle side effects? ISolverVariableGibbs svar = (ISolverVariableGibbs) solvers.getSolverVariable(variable); svar.setCurrentSample(values[outputIndex]); } } } @Override public void initialize() { super.initialize(); _isDeterministicDirected = _model.getFactorFunction().isDeterministicDirected(); _currentSamples = _model.fillInArgumentValues(_parent, _currentSamples); } @Deprecated @Override public DiscreteValue getInputMsg(int portIndex) { return (DiscreteValue)_currentSamples[_model.siblingNumberToArgIndex(portIndex)]; } @Deprecated @Override public Object getOutputMsg(int portIndex) { return getSiblingEdgeState(portIndex).factorToVarMsg.representation(); } /*-------------------------- * STableFactorBase methods */ @Override protected void setTableRepresentation(IFactorTable table) { if (_isDeterministicDirected) { table.setRepresentation(FactorTableRepresentation.DETERMINISTIC); } else { table.setRepresentation(FactorTableRepresentation.DENSE_ENERGY); } } @Override public boolean setVisited(boolean visited) { boolean changed = _visited ^ visited; _visited = visited; return changed; } @SuppressWarnings("null") @Override public GibbsDiscreteEdge getSiblingEdgeState(int siblingIndex) { return (GibbsDiscreteEdge)getSiblingEdgeState_(siblingIndex); } /*----------------- * Private methods */ /** * Same as {@link _currentSamples} with constant entries removed. For use with factor table. * @since 0.08 */ private Value[] samplesForFactorTable() { final Factor factor = _model; Value[] samples = _currentSamples; if (factor.hasConstants()) { // FIXME Constant - can we avoid doing this copy? samples = new Value[factor.getSiblingCount()]; for (int i = samples.length; --i>=0;) { samples[i] = _currentSamples[factor.siblingNumberToArgIndex(i)]; } } return samples; } }