/*******************************************************************************
* Copyright 2012-2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs;
import static java.util.Objects.*;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
import com.analog.lyric.util.misc.Matlab;
/**
* Real solver factor under Gibbs solver.
*
* @since 0.07
* @author Christopher Barber
*/
@SuppressWarnings("deprecation") // TODO: remove when SRealFactor removed
public class GibbsRealFactor extends SRealFactor implements ISolverFactorGibbs
{
protected Value [] _currentSamples = new Value[0];
// private Object[] _scratchValues;
protected boolean _isDeterministicDirected;
private int _topologicalOrder = 0;
/**
* True if output samples in {@link #_currentSamples} have been computed.
*/
private boolean _outputsValid = false;
private boolean _visited = false;
public GibbsRealFactor(Factor factor, GibbsSolverGraph parent)
{
super(factor, parent);
_isDeterministicDirected = _model.getFactorFunction().isDeterministicDirected();
}
@Override
public void initialize()
{
super.initialize();
_outputsValid = false;
_currentSamples = _model.fillInArgumentValues(_parent, _currentSamples);
}
@Override
public @Nullable GibbsSolverEdge<?> createEdge(EdgeState edge)
{
return null;
}
@Override
public void doUpdateEdge(int outPortNum)
{
// The Gibbs solver doesn't directly update factors, but the equivalent is instead done calls from variables
// This is ignored and doesn't throw an error so that a custom schedule that updates factors won't cause a problem
}
@Override
protected void doUpdate()
{
// The Gibbs solver doesn't directly update factors, but the equivalent is instead done calls from variables
// This is ignored and doesn't throw an error so that a custom schedule that updates factors won't cause a problem
}
@Override
public ISolverVariableGibbs getSibling(int edge)
{
return (ISolverVariableGibbs)super.getSibling(edge);
}
@Override
public GibbsSolverGraph getParentGraph()
{
return (GibbsSolverGraph)_parent;
}
@Override
public void updateEdgeMessage(EdgeState modelEdge, GibbsSolverEdge<?> solverEdge)
{
final Factor factor = _model;
final int outPortNum = modelEdge.getFactorToVariableEdgeNumber();
final int outIndex = factor.siblingNumberToArgIndex(outPortNum);
Value outValue = _currentSamples[outIndex];
if (outValue.getDomain().isDiscrete())
{
// This edge connects to a discrete variable, so send an output message
// This method only considers the current conditional values, and does not propagate
// to any other variables (unlike get ConditionalPotential)
// This should only be called if this factor is not a deterministic directed factor
final FactorFunction factorFunction = factor.getFactorFunction();
final Value[] values = _currentSamples.clone();
outValue = outValue.clone();
values[outIndex] = outValue;
double[] outputMsgs = ((DiscreteMessage)solverEdge.factorToVarMsg).representation();
for (int i = outputMsgs.length; --i>=0;)
{
outValue.setIndex(i);
outputMsgs[i] = factorFunction.evalEnergy(values); // Messages to discrete variables are energy values
}
}
}
@Override
public double getPotential()
{
if (_isDeterministicDirected)
return 0;
final Value[] inputMsgs = _currentSamples;
if (inputMsgs.length > 0)
{
final double energy = _model.getFactorFunction().evalEnergy(inputMsgs);
if (energy != energy) // Faster isNaN
return Double.POSITIVE_INFINITY;
return energy;
}
else
return Double.POSITIVE_INFINITY;
}
@Matlab
@Override
public double getPotential(Object[] inputs)
{
return _model.evalEnergy(inputs);
}
@Override
public final int getTopologicalOrder()
{
return _topologicalOrder;
}
@Override
public final void setTopologicalOrder(int order)
{
_topologicalOrder = order;
}
@SuppressWarnings("null")
@Override
public void updateNeighborVariableValue(int variableIndex, Value oldValue)
{
final int argIndex = _model.siblingNumberToArgIndex(variableIndex);
((GibbsSolverGraph)getRootSolverGraph()).scheduleDeterministicDirectedUpdate(this, argIndex, oldValue);
}
@Override
public void updateNeighborVariableValuesNow(@Nullable Collection<IndexedValue> oldValues)
{
// Compute the output values of the deterministic factor function from the input values
final Factor factor = _model;
final FactorFunction function = factor.getFactorFunction();
int[] directedTo = factor.getDirectedTo();
final Value[] inputMsgs = requireNonNull(_currentSamples);
final SolverNodeMapping solvers = requireNonNull(getParentGraph()).getSolverMapping();
if (oldValues != null && _outputsValid)
{
AtomicReference<int[]> changedOutputsHolder = new AtomicReference<int[]>();
Value[] values = function.updateDeterministicToCopy(inputMsgs, oldValues, changedOutputsHolder);
int[] changedOutputs = changedOutputsHolder.get();
if (changedOutputs != null)
{
if (factor.hasConstants())
{
for (int i = changedOutputs.length; --i>=0;)
{
// Translate from factor arg index back to edge index, even though
// we will just switch back again below.
changedOutputs[i] = factor.argIndexToSiblingNumber(changedOutputs[i]);
}
}
directedTo = changedOutputs;
}
// Update the directed-to variables with the computed values
if (directedTo != null)
{
for (int to : directedTo)
{
final int outputIndex = factor.siblingNumberToArgIndex(to);
Variable variable = requireNonNull(factor.getSibling(to));
Value newValue = values[outputIndex];
((ISolverVariableGibbs)solvers.getSolverVariable(variable)).setCurrentSample(newValue);
}
}
}
else
{
Value[] values = function.evalDeterministicToCopy(inputMsgs);
_outputsValid = true;
// Update the directed-to variables with the computed values
if (directedTo != null)
{
// Full update
for (int to : directedTo)
{
final int outputIndex = factor.siblingNumberToArgIndex(to);
Variable variable = requireNonNull(factor.getSibling(to));
Value newValue = values[outputIndex];
((ISolverVariableGibbs)solvers.getSolverVariable(variable)).setCurrentSample(newValue);
}
}
}
}
@Deprecated
@Override
public Value getInputMsg(int portIndex)
{
return _currentSamples[portIndex];
}
@Override
public boolean setVisited(boolean visited)
{
boolean changed = _visited ^ visited;
_visited = visited;
return changed;
}
@SuppressWarnings("null")
@Override
public GibbsSolverEdge<?> getSiblingEdgeState(int siblingIndex)
{
return (GibbsSolverEdge<?>)getSiblingEdgeState_(siblingIndex);
}
}