/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jdt.annotation.NonNull;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ReleasableArrayIterator;
import com.analog.lyric.collect.ReleasableIterable;
import com.analog.lyric.collect.ReleasableIterator;
import com.analog.lyric.collect.UnmodifiableReleasableIterator;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Variable;
import net.jcip.annotations.Immutable;
import net.jcip.annotations.NotThreadSafe;
/**
* Represents the neighbors of a {@link ISolverVariableGibbs} that need to be included in the variable's
* sample score. This class is only created when the contents differ from the immediate siblings of
* the variable. This can happen in two ways:
* <ul>
* <li>The same factor shows up more than once in the sibling list (and should not be double counted).
* <li>The variable is an input to one or more deterministic directed factors whose outputs should be
* included in the score (and in turn their deterministic dependents recursively). In this case, the
* adjacent deterministic factors are also stored in this instance.
* </ul>
*/
@Immutable
public final class GibbsNeighbors implements ReleasableIterable<ISolverNodeGibbs>
{
/*-------
* State
*/
private final ISolverNodeGibbs[] _neighbors;
private final GibbsSolverGraph _rootSolverGraph;
/**
* Contains list of directed deterministic factors that are directed from the
* starting variable. null if none.
*/
private final @Nullable FactorWork[] _adjacentDependentFactors;
/*--------------
* Construction
*/
private GibbsNeighbors(ISolverNodeGibbs[] neighbors, @Nullable FactorWork[] immediateDependentFactors, GibbsSolverGraph rootSolverGraph)
{
_neighbors = neighbors;
_adjacentDependentFactors = immediateDependentFactors;
_rootSolverGraph = rootSolverGraph;
}
/**
* Creates a neighbor list for scoring samples of {@code svar}.
*
* @return null if the neighbors are the same as the node's immediate siblings.
*/
public static @Nullable GibbsNeighbors create(ISolverVariableGibbs svar)
{
final Variable var = svar.getModelObject();
final int nSiblings = var.getSiblingCount();
// Neighbors at front of list, other visited nodes at end. The counter indicates
// where the boundary is.
final Deque<ISolverNodeGibbs> visited = new LinkedList<ISolverNodeGibbs>();
visited.addLast(svar);
svar.setVisited(true);
// Counter of neighbors.
int[] counter = new int[1];
// Nodes yet to visit.
// TODO: can we combine this with the visited list?
final Queue<Work> queue = new VarWork(svar, -1).handle(visited, counter, null);
final boolean createList = queue != null || counter[0] != nSiblings;
FactorWork[] adjacentDependentFactors = null;
if (queue != null)
{
ArrayList<FactorWork> adjacentFactors = new ArrayList<FactorWork>(nSiblings);
boolean processingAdjacentFactors = true;
for (Work work = null; (work = queue.poll()) != null;)
{
if (processingAdjacentFactors)
{
// The FactorWork objects at the head of the queue up to the first VarWork
// must be for adjacent factors.
FactorWork factorWork = work.asFactorWork();
if (factorWork == null)
{
processingAdjacentFactors = false;
}
else
{
adjacentFactors.add(factorWork);
}
}
work.handle(visited, counter, queue);
}
adjacentDependentFactors = adjacentFactors.toArray(new FactorWork[adjacentFactors.size()]);
}
if (createList)
{
final int size = counter[0];
ISolverNodeGibbs[] neighbors = new ISolverNodeGibbs[size];
int i = 0;
for (ISolverNodeGibbs node : visited)
{
node.setVisited(false);
if (i < size)
{
neighbors[i++] = node;
}
}
return new GibbsNeighbors(neighbors, adjacentDependentFactors, (GibbsSolverGraph)requireNonNull(svar.getRootSolverGraph()));
}
else
{
for (ISolverNodeGibbs node : visited)
{
node.setVisited(false);
}
return null;
}
}
private abstract static class Work
{
final int _incomingEdge;
private Work(int incomingEdge)
{
_incomingEdge = incomingEdge;
}
protected @Nullable FactorWork asFactorWork() { return null; }
protected abstract @Nullable Queue<Work> handle(Deque<ISolverNodeGibbs> visited, int[] counter, Queue<Work> queue);
}
private static final class VarWork extends Work
{
private final ISolverVariableGibbs _varNode;
private VarWork(ISolverVariableGibbs varNode, int incomingEdge)
{
super(incomingEdge);
_varNode = varNode;
}
@Override
protected @Nullable Queue<Work> handle(Deque<ISolverNodeGibbs> visited, int[] counterHolder,
@Nullable Queue<Work> queue)
{
final Variable variable = _varNode.getModelObject();
final int nSiblings = variable.getSiblingCount();
int counter = counterHolder[0];
for (int edge = 0; edge < nSiblings; ++edge)
{
if (edge == _incomingEdge)
continue;
final EdgeState edgeState = variable.getSiblingEdgeState(edge);
final ISolverFactorGibbs sfactor = _varNode.getSibling(edge);
final Factor factor = sfactor.getModelObject();
int reverseEdge;
if (factor.getFactorFunction().isDeterministicDirected() &&
!factor.isDirectedTo(reverseEdge = edgeState.getFactorToVariableEdgeNumber()))
{
// Do not mark deterministic directed factors as visited because we may
// need to visit them again from a different input variable and may get
// different outputs. See FactorWork.handle()
if (queue == null)
{
queue = new LinkedList<Work>();
}
queue.add(new FactorWork(sfactor, reverseEdge));
}
else if (sfactor.setVisited(true))
{
visited.addFirst(sfactor);
counter++;
}
}
counterHolder[0] = counter;
return queue;
}
}
private static final class FactorWork extends Work
{
private final ISolverFactorGibbs _factorNode;
private FactorWork(ISolverFactorGibbs factorNode, int incomingEdge)
{
super(incomingEdge);
_factorNode = factorNode;
}
@Override
protected FactorWork asFactorWork()
{
return this;
}
@Override
protected Queue<Work> handle(Deque<ISolverNodeGibbs> visited, int[] counterHolder, Queue<Work> queue)
{
final Factor factor = _factorNode.getModelObject();
final FactorFunction function = factor.getFactorFunction();
int[] outputEdges = function.getDirectedToIndicesForInput(factor, _incomingEdge);
if (outputEdges == null)
{
// In this case, all of the outputs will be visited the first time, so
// don't revisit this node if we come to it again from a different input.
if (_factorNode.setVisited(true))
{
visited.addLast(_factorNode);
}
else
{
return queue;
}
outputEdges = factor.getDirectedTo();
}
int counter = counterHolder[0];
if (outputEdges != null)
{
for (int edge : outputEdges)
{
ISolverVariableGibbs svariable = _factorNode.getSibling(edge);
if (svariable.setVisited(true))
{
if (svariable.hasPotential())
{
visited.addFirst(svariable);
counter++;
}
else
{
visited.addLast(svariable);
}
queue.add(new VarWork(svariable, factor.getReverseSiblingNumber(edge)));
}
}
}
counterHolder[0] = counter;
return queue;
}
}
/*------------------------------
* [Releasable]Iterable methods
*/
@Override
public @NonNull ReleasableIterator<ISolverNodeGibbs> iterator()
{
return ReleasableArrayIterator.create(_neighbors);
}
/**
* Returns an iterator that visits the contents of {@code list} if not null, and which otherwise iterates
* over the solver siblings of {@code var}.
*/
static ReleasableIterator<ISolverNodeGibbs> iteratorFor(@Nullable GibbsNeighbors list, ISolverVariableGibbs var)
{
return list != null ? list.iterator() : SimpleIterator.create(var);
}
/*---------------
* Local methods
*/
boolean hasDeterministicDependents()
{
return _adjacentDependentFactors != null;
}
/**
* Update the deterministic outputs that depend on the original variable.
*
* @param oldValue is the previous value of the variable. The new sample value should
* already have been set before this is invoked.
*/
void update(Value oldValue)
{
final FactorWork[] adjacentDependentFactors = _adjacentDependentFactors;
if (adjacentDependentFactors != null)
{
_rootSolverGraph.deferDeterministicUpdates();
ReleasableIterator<FactorWork> dependentFactors = ReleasableArrayIterator.create(adjacentDependentFactors);
while (dependentFactors.hasNext())
{
FactorWork factor = dependentFactors.next();
factor._factorNode.updateNeighborVariableValue(factor._incomingEdge, oldValue);
}
dependentFactors.release();
_rootSolverGraph.processDeferredDeterministicUpdates();
}
}
/*--------------------------
* Iterator implementations
*/
/**
* Iterator that visits immediate solver nodes of source model node.
*/
@NotThreadSafe
private static class SimpleIterator extends UnmodifiableReleasableIterator<ISolverNodeGibbs>
{
private @Nullable ISolverNodeGibbs _snode;
private int _size;
private int _index;
private static final AtomicReference<SimpleIterator> _reusableInstance = new AtomicReference<SimpleIterator>();
static SimpleIterator create(@Nullable ISolverNodeGibbs svar)
{
SimpleIterator iter = _reusableInstance.getAndSet(null);
if (iter == null)
{
iter = new SimpleIterator();
}
iter.reset(svar);
return iter;
}
@Override
public final boolean hasNext()
{
return _index < _size;
}
@Override
public @Nullable ISolverNodeGibbs next()
{
ISolverNodeGibbs snode = _snode;
return snode != null ? snode.getSibling(_index++) : null ;
}
@Override
public void release()
{
_snode = null;
_reusableInstance.lazySet(this);
}
void reset(@Nullable ISolverNodeGibbs svar)
{
_snode = svar;
_size = svar != null ? svar.getSiblingCount() : 0;
_index = 0;
}
}
}