/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.transform;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.BinaryHeap;
import com.analog.lyric.collect.IHeap;
import com.analog.lyric.collect.SkipSet;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Uniform;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTableIterator;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.AddedJointDiscreteVariable;
import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.AddedJointVariable;
import com.analog.lyric.dimple.model.transform.VariableEliminator.CostFunction;
import com.analog.lyric.dimple.model.transform.VariableEliminator.Ordering;
import com.analog.lyric.dimple.model.transform.VariableEliminator.Stats;
import com.analog.lyric.dimple.model.transform.VariableEliminator.VariableCost;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableList;
import com.analog.lyric.dimple.options.BPOptions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.ObjectArrays;
import com.google.common.collect.SetMultimap;
/**
* This class implements the junction tree transformation on an input model to create
* a semantically equivalent version of the model that is singly connected and thus
* can be used for exact inference using belief propagation.
* <p>
* The algorithm as implemented in this class is as follows:
*
* <ol>
* <li>Determine a variable elimination order.
* <li>Makes a copy of the original model and creates a mapping from old to new nodes.
* Creates initial version of the {@link JunctionTreeTransformMap}.
* <li>If {@link #useConditioning()} is true, then any variables that have a fixed value
* will be disconnected in the new graph using {@link Factor#removeFixedVariables()} and
* will be recorded in the transform map.
*
* <li>Organize the new graph into "cliques" of mutually connected variables using the variable
* elimination order. For each variable in order, find all neighboring variables that have
* not already been "eliminated" create a new clique containing the variable and its neighbors
* and create a temporary factor that connects the neighbors (but not the variable itself) to each
* other. All non-temporary factors that are connected to the variable and that have not been assigned to a
* previous clique are assigned to the new clique. After all variables have been "eliminated", remove any
* temporary factors from the new graph.
*
* <li>Use a modified version of Prim's algorithm to build a max spanning tree over the cliques where the edge weight
* is the number of variables in common between two cliques. Ties are broken in favor of edges with lower
* joint variable cardinality in order to favor smaller messages. When a new edge contains all of the variables
* of one of the cliques, then instead of adding the edge, the two cliques are merged into one.
*
* <li>For each clique, create a new factor that connects all of its variables by combining all of the
* factors that have been assigned to it using {@link FactorGraph#join(Variable[], Factor...)}.
*
* <li>Create half-edges for variables that are in only one clique and therefore will not be in any edge
* created during spanning tree construction.
*
* <li>For each clique that has any multi-variable edge, create a new variable for each such edge and
* rewrite the factor to connect to the new edge variables. This will not increase the number of entries
* in the underlying factor table but will require it to be converted to a sparse representation and possibly
* reordered. It is possible for multiple edge variables for the same combinations of original variables to
* exist in the same graph. At the end of inference, they should all of the same beliefs but may differ before
* that.
*
* <li>The previous step may orphan some variables from the graph because they are subsumed by one or more new
* joint variables. Find any such variables and reconnect to the graph by adding a new deterministic factor that
* marginalizes out the variable value from the smallest joint variable that contains it.
* </ol>
*
* <h2>References</h2>
* <ul>
* <li>David Barber.
* <a href="http://www.cs.ucl.ac.uk/staff/d.barber/brml/">
* Bayesian Reasoning and Machine Learning.</a>
* Chapter 6.
*
* <li>Daphne Koller & Nir Friedman.
* <a href="http://mitpress.mit.edu/books/probabilistic-graphical-models">
* Probabilistic Graphical Models: <em>Principals and Techniques</em></a>
* Chapter 10.
* </ul>
* <p>
*
* @since 0.05
* @author Christopher Barber
*/
public class JunctionTreeTransform
{
/*-------
* State
*/
/**
* Default value of {@link #maxTransformationAttempts()}
*/
public static final int DEFAULT_MAX_TRANSFORMATION_ATTEMPTS = 10;
private int _nEliminationAttempts = DEFAULT_MAX_TRANSFORMATION_ATTEMPTS;
private boolean _useConditioning = false;
private CostFunction[] _costFunctions = {};
private Random _rand = new Random();
/**
* Orders variables by id.
*/
private static final Comparator<Variable> _variableComparator = Variable.orderById;
/*--------------
* Construction
*/
public JunctionTreeTransform()
{
}
/*---------
* Options
*/
/**
* The random number generator used by this transformer. This is only used when determining
* the variable elimination ordering and is passed to the underlying {@link VariableEliminator}.
* @see #random(Random)
*/
public Random random()
{
return _rand;
}
/**
* Sets {@link #random()} to specified generator.
* <p>
* Only intended for use in testing to allow for reproduction of test results from a known seed.
*
* @return this
*/
public JunctionTreeTransform random(Random rand)
{
_rand = rand;
return this;
}
/**
* If true, then the transformation will condition out any variables that have a fixed value.
* This will produce a more efficient graph but will prevent it from being reused if the fixed
* value changes.
* <p>
* False by default.
* @see #useConditioning(boolean)
*/
public boolean useConditioning()
{
return _useConditioning;
}
/**
* Sets {@link #useConditioning()} to specified value.
* @return this
*/
public JunctionTreeTransform useConditioning(boolean value)
{
_useConditioning = value;
return this;
}
/**
* The cost functions used by {@link VariableEliminator} to determine the variable
* elimination ordering. If empty (the default), then all of the standard {@link VariableCost}
* functions will be tried.
*
* @see #variableEliminatorCostFunctions(VariableEliminator.CostFunction...)
* @see #variableEliminatorCostFunctions(VariableEliminator.VariableCost...)
*/
public CostFunction[] variableEliminatorCostFunctions()
{
return _costFunctions.clone();
}
/**
* Sets {@link #variableEliminatorCostFunctions()} to specified value.
* @return this
* @see #variableEliminatorCostFunctions(VariableEliminator.VariableCost...)
*/
public JunctionTreeTransform variableEliminatorCostFunctions(CostFunction ... costFunctions)
{
_costFunctions = costFunctions.clone();
return this;
}
/**
* Sets {@link #variableEliminatorCostFunctions()} to specified value.
* @return this
* @see #variableEliminatorCostFunctions(VariableEliminator.CostFunction...)
*/
public JunctionTreeTransform variableEliminatorCostFunctions(VariableCost ... costFunctions)
{
_costFunctions = VariableCost.toFunctions(costFunctions);
return this;
}
/**
* Specifies the maximum number of times to attempt to determine an optimal junction tree
* transformation.
* <p>
* This is the number of iterations of the {@link VariableEliminator} algorithm when attempting
* to determine the variable elimination ordering that determines the junction tree
* transofmration Each iteration will pick a cost function from
* {@link #variableEliminatorCostFunctions()} at random and will randomize the order of
* variables that have equivalent costs. A higher number of iterations may produce a better
* ordering.
* <p>
* Default value is specified by {@link #DEFAULT_MAX_TRANSFORMATION_ATTEMPTS}.
* <p>
*
* @see #maxTransformationAttempts(int)
*/
public int maxTransformationAttempts()
{
return _nEliminationAttempts;
}
/**
* Sets {@link #maxTransformationAttempts()} to the specified value.
* @return this
*/
public JunctionTreeTransform maxTransformationAttempts(int attempts)
{
_nEliminationAttempts = attempts;
return this;
}
/*------------------------------
* Inner implementation classes
*/
private static class Clique
{
/**
* Factors that make up the clique.
*/
private Factor[] _factors;
/**
* Variables in clique
*/
private Discrete[] _variables;
private final SkipSet<CliqueEdge> _edges;
private boolean _hasMultiVariableEdge = false;
/**
* Merged factor for the clique. May be null if all of the original factors
* were already incorporated into another clique.
*/
private @Nullable Factor _mergedFactor = null;
//
// Temporary spanning tree state
//
private boolean _inSpanningTree = false;
private @Nullable IHeap.IEntry<Clique> _heapEntry = null;
private @Nullable CliqueEdge _bestEdge = null;
/*--------------
* Construction
*/
private Clique(List<Discrete> variables, List<Factor> factors)
{
_factors = factors.toArray(new Factor[factors.size()]);
_variables = variables.toArray(new Discrete[variables.size()]);
_edges = SkipSet.create();
Arrays.sort(_variables, _variableComparator);
}
/*----------------
* Object methods
*/
@Override
public String toString()
{
return Arrays.toString(_variables);
}
/*---------
* Methods
*/
/**
* Merges variables and factors from {@code absorbee} into this clique and removes
* {@code absorbee} from {@code varToCliques}.
*
* @param absorbee must not have any edges assigned to it.
* @param varToCliques
* @return true if variables were added to this clique
*/
private boolean absorbClique(Clique absorbee, SetMultimap<Discrete, Clique> varToCliques)
{
assert(absorbee._edges.isEmpty());
boolean variablesAdded = false;
for (Discrete var : absorbee._variables)
{
varToCliques.remove(var, absorbee);
varToCliques.put(var, this);
}
if (absorbee._variables.length > _variables.length)
{
_variables = absorbee._variables;
variablesAdded = true;
}
absorbee._variables = new Discrete[0];
_factors = ObjectArrays.concat(_factors, absorbee._factors, Factor.class);
absorbee._factors = new Factor[0];
return variablesAdded;
}
private void addEdge(CliqueEdge edge)
{
_hasMultiVariableEdge |= edge._variables.length > 1;
_edges.add(edge);
}
private void addToMap(SetMultimap<Discrete, Clique> map)
{
for (Discrete variable : _variables)
{
map.put(variable, this);
}
}
private int indexOfVariable(Discrete variable)
{
return Arrays.binarySearch(_variables, variable, _variableComparator);
}
private boolean joinMultivariateEdges()
{
if (!_hasMultiVariableEdge)
return false;
final int nEdges = _edges.size();
final CliqueEdge[] edges = _edges.toArray(new CliqueEdge[nEdges]);
// Build mapping from new to old indices.
final Discrete[] newVariables = new Discrete[nEdges];
final DiscreteDomain[] newDomains = new DiscreteDomain[nEdges];
final int[][] newFromOld = new int[nEdges][];
final int[][] scratchIndices = new int[nEdges][];
for (int edgei = 0; edgei < nEdges; ++edgei)
{
final CliqueEdge edge = edges[edgei];
final int nEdgeVars = edge._variables.length;
final int[] a = new int[nEdgeVars];
final Discrete jointVar = requireNonNull(edge._jointVariable);
newVariables[edgei] = jointVar;
newDomains[edgei] = jointVar.getDiscreteDomain();
for (int vari = 0; vari < nEdgeVars; ++vari)
{
final Discrete edgeVar = edge._variables[vari];
a[vari] = indexOfVariable(edgeVar);
}
newFromOld[edgei] = a;
if (nEdgeVars > 1)
{
scratchIndices[edgei] = new int[nEdgeVars];
}
}
// Compute new factor table energies and indices.
final Factor mergedFactor = requireNonNull(_mergedFactor);
final IFactorTable oldFactorTable = requireNonNull(mergedFactor.getFactorTable());
final int nEntries = oldFactorTable.countNonZeroWeights();
final int[][] indices = new int[nEntries][];
final double[] energies = new double[nEntries];
final IFactorTableIterator oldIter = oldFactorTable.iterator();
int si = 0;
while (oldIter.advance())
{
final int[] oldIndices = oldIter.indicesUnsafe();
final int[] newIndices = new int[nEdges];
for (int edgei = 0; edgei < nEdges; ++edgei)
{
final CliqueEdge edge = edges[edgei];
final int nEdgeVars = edge._variables.length;
final int[] map = newFromOld[edgei];
if (nEdgeVars == 1)
{
newIndices[edgei] = oldIndices[map[0]];
}
else
{
final int[] scratch = scratchIndices[edgei];
for (int vari = 0; vari < nEdgeVars; ++vari)
{
scratch[vari] = oldIndices[map[vari]];
}
newIndices[edgei] = ((JointDiscreteDomain<?>)newDomains[edgei]).getIndexFromIndices(scratch);
}
}
energies[si] = oldIter.energy();
indices[si] = newIndices;
++si;
}
// Create the new table
final IFactorTable newFactorTable = FactorTable.create(newDomains);
newFactorTable.setEnergiesSparse(indices, energies);
// Remove the old factor
FactorGraph graph = requireNonNull(mergedFactor.getParentGraph());
graph.remove(mergedFactor);
// Create new factor attached to edge variables
_mergedFactor = graph.addFactor(newFactorTable, newVariables);
return true;
}
private boolean updateBestEdge(CliqueEdge incomingEdge)
{
CliqueEdge bestEdge = _bestEdge;
if (bestEdge == null || bestEdge._weight < incomingEdge._weight)
{
_bestEdge = incomingEdge;
return true;
}
return false;
}
/**
* The number of variables in the clique including the eliminated variable.
*/
private int size()
{
return _variables.length;
}
}
/**
* Represents an edge between two cliques/factors. Will be represented as a single
* variable in the new graph.
*/
private static class CliqueEdge implements Comparable<CliqueEdge>
{
private final @Nullable Clique _from;
private final Clique _to;
/**
* Variables that are transmitted across this edge.
*/
private final Discrete[] _variables;
/**
* The variable representing this edge in the new graph.
*/
private @Nullable Discrete _jointVariable;
/**
* Weight is # of variables on edge plus the reciprocal of the joint cardinality.
* This will favor the edge with the most variables and smallest cardinality.
*/
private final double _weight;
private CliqueEdge(@Nullable Clique from, Clique to, Discrete ... variables)
{
_from = from;
_to = to;
_variables = variables;
long cardinality = 1;
for (Discrete variable : variables)
{
cardinality *= variable.getDomain().size();
}
_weight = variables.length + 1 / (double)cardinality;
if (variables.length == 1)
{
_jointVariable = variables[0].asDiscreteVariable();
}
assert(ArrayUtil.isSorted(variables, _variableComparator));
}
@Override
public String toString()
{
return String.format("%s =%s=> %s", _from, Arrays.toString(_variables), _to);
}
@Override
@NonNullByDefault(false)
public int compareTo(CliqueEdge that)
{
return _variableComparator.compare(this._jointVariable, that._jointVariable);
}
/**
* @return true if cliques connected by this edge could be merged because one is a subset of the other.
*/
private boolean isMergeable()
{
// Since the edge contains all variables that are in common between the two cliques,
// we only need to compare the lengths to see if there is a subset relationship.
final Clique from = _from;
final int size = _variables.length;
return from != null && (_to._variables.length == size || from._variables.length == size);
}
private @Nullable AddedJointVariable<?> makeJointVariable(FactorGraph targetModel)
{
final Discrete[] edgeVars = _variables;
final int nEdgeVars = edgeVars.length;
AddedJointVariable<?> addedVar = null;
if (nEdgeVars > 1)
{
// Create joint variable for edge
final DiscreteDomain[] edgeDomains = new DiscreteDomain[edgeVars.length];
for (int i = 0; i < nEdgeVars; ++i)
{
edgeDomains[i] = edgeVars[i].getDomain();
}
final Discrete jointVar = new Discrete(DiscreteDomain.joint(edgeDomains));
StringBuilder jointName = new StringBuilder();
for (int i = 0; i < nEdgeVars; ++i)
{
if (i > 0)
{
jointName.append("+");
}
jointName.append(edgeVars[i].getName());
}
jointVar.setName(jointName.toString());
targetModel.addVariables(jointVar);
_jointVariable = jointVar;
addedVar = new AddedJointDiscreteVariable(jointVar, edgeVars);
}
return addedVar;
}
}
/*---------
* Methods
*/
/**
* Build junction tree transformation.
* <p>
* @see #transform(FactorGraph, ArrayList)
* @see #transform(FactorGraph, VariableEliminator.Ordering)
*/
public JunctionTreeTransformMap transform(FactorGraph model)
{
return transform(model, buildEliminationOrder(model));
}
/**
* Build junction tree transformation using a specified variable elimination ordering.
* @param eliminationOrder is an ordering of the variables in the {@code model}. It must include
* every variable exactly once.
* <p>
* @see #transform(FactorGraph)
* @see #transform(FactorGraph, VariableEliminator.Ordering)
*/
public JunctionTreeTransformMap transform(FactorGraph model, ArrayList<Variable> eliminationOrder)
{
// Validate variables
final VariableList variables = model.getVariables();
if (eliminationOrder.size() != model.getVariableCount() || !variables.containsAll(eliminationOrder))
{
throw new IllegalArgumentException("Elimination order does not specify same variables as the model");
}
Stats stats = new Stats();
if (_useConditioning)
{
// Make sure conditioned variables are at front of the ordering.
Collections.sort(eliminationOrder, new Comparator<Variable>() {
@Override
@NonNullByDefault(false)
public int compare(Variable var1, Variable var2)
{
return var1.hasFixedValue() ? (var2.hasFixedValue() ? 0 : -1) : (var2.hasFixedValue() ? 1 : 0);
}
});
int nConditioned = 0;
for (Variable var : eliminationOrder)
{
if (var.hasFixedValue())
{
++nConditioned;
}
else
{
break;
}
}
stats.conditionedVariables(nConditioned);
}
if (model.isForest())
{
// If not a forest, then at least two factors would have to be merged.
stats.mergedFactors(2);
}
return transform(model, new Ordering(eliminationOrder, stats));
}
/**
* Build junction tree transformation using a specified variable elimination ordering.
* <p>
* @param eliminationOrder is a valid variable ordering for this graph created by {@link VariableEliminator}
* on this {@code model}.
*
* @see #transform(FactorGraph)
* @see #transform(FactorGraph, ArrayList)
*/
public JunctionTreeTransformMap transform(FactorGraph model, Ordering eliminationOrder)
{
// 1) Determine an elimination order
final Stats orderStats = eliminationOrder.stats;
if (orderStats.alreadyGoodForFastExactInference())
{
// If elimination order introduces no edges, graph is already a tree. Done.
return JunctionTreeTransformMap.identity(model);
}
if (orderStats.factorsWithDuplicateVariables() > 0)
{
// FIXME - support duplicate variables in JunctionTreeTransform
throw DimpleException.unsupported("factors with duplicate variables");
}
// 2) Make copy of the factor graph
final ArrayList<Variable> variables = eliminationOrder.variables;
final int nVariables = variables.size();
final int nFactors = model.getFactorCount();
final BiMap<Object,Object> old2new = HashBiMap.create(nVariables * 2);
final FactorGraph targetModel = model.copyRoot(old2new);
targetModel.unsetOption(BPOptions.scheduler); // Don't use copied scheduler
// Make copied factors undirected.
for (Factor factor : targetModel.getFactors())
{
factor.setUndirected();
}
final JunctionTreeTransformMap transformMap = JunctionTreeTransformMap.create(model, targetModel);
for (Entry<Object,Object> entry : old2new.entrySet())
{
final Object source = entry.getKey();
if (source instanceof Variable)
{
transformMap.addVariableMapping((Variable)source, Objects.requireNonNull((Variable)entry.getValue()));
}
}
// 3) Disconnect conditioned variables from other variables in new graph
disconnectConditionedVariables(eliminationOrder, transformMap);
// 4) Create cliques using variable elimination order
final List<Clique> cliques = createCliques(eliminationOrder, transformMap);
final SetMultimap<Discrete, Clique> varToCliques =
HashMultimap.create(nVariables, nVariables/Math.max(1,nFactors));
for (Clique clique : cliques)
{
clique.addToMap(varToCliques);
}
// 5) Use Prim's algorithm to build max spanning tree over clique graph where the edge weight
// is the number of variables in common between the two cliques along each edge.
final List<CliqueEdge> multiVariateEdges = formSpanningTree(transformMap, cliques, varToCliques);
// 6) Merge factors in cliques
for (Clique clique : cliques)
{
if (clique._variables.length > 0)
{
clique._mergedFactor = targetModel.join(clique._variables, clique._factors);
}
}
// 7) Rewrite factors with multivariate edges
for (Clique clique : cliques)
{
clique.joinMultivariateEdges();
for (Factor cliqueFactor : clique._factors)
{
Factor sourceFactor = (Factor) old2new.inverse().get(cliqueFactor);
transformMap.addFactorMapping(sourceFactor, Objects.requireNonNull(clique._mergedFactor));
}
}
// 8) Find and reconnect orphaned variables.
reconnectOrphanVariables(targetModel, multiVariateEdges);
return transformMap;
}
//-----------------
// Private methods
//
private Ordering buildEliminationOrder(FactorGraph model)
{
// Find max cardinality of existing factors - we can't do better than that.
int maxCardinality = 0;
for (Factor factor : model.getFactors())
{
if (factor.hasFactorTable())
{
maxCardinality = Math.max(maxCardinality, factor.getFactorTable().getDomainIndexer().getCardinality());
}
}
VariableEliminator.Stats threshold = new VariableEliminator.Stats().maxCliqueCardinality(maxCardinality);
VariableEliminator eliminator = new VariableEliminator(model, _useConditioning, _rand);
return VariableEliminator.generate(eliminator, _nEliminationAttempts, threshold, _costFunctions);
}
private int disconnectConditionedVariables(Ordering eliminationOrder, JunctionTreeTransformMap transformMap)
{
final int nConditioned = _useConditioning ? eliminationOrder.stats.conditionedVariables() : 0;
if (nConditioned > 0)
{
// Build list of factors that need to be modified
final Map<Factor, Factor> factors = new LinkedHashMap<Factor, Factor>();
// Conditioned variables are guaranteed to be at the head of the list.
// Conditioned variables aren't necessarily discrete but the remaining ones
// should be discrete.
for (int i = 0; i < nConditioned; ++i)
{
Variable sourceVariable = eliminationOrder.variables.get(i);
Variable variable = transformMap.sourceToTargetVariable(sourceVariable);
transformMap.addConditionedVariable(sourceVariable);
for (int j = 0, endj = variable.getSiblingCount(); j < endj; ++j)
{
final Factor factor = variable.getSibling(j);
factors.put(factor, factor);
}
}
for (Factor factor : factors.values())
{
factor.removeFixedVariables();
}
}
return nConditioned;
}
private List<Clique> createCliques(Ordering eliminationOrder, JunctionTreeTransformMap transformMap)
{
final List<Clique> cliques = new LinkedList<Clique>();
final ArrayList<Variable> variables = eliminationOrder.variables;
final int nVariables = variables.size();
final int nConditioned = eliminationOrder.stats.conditionedVariables();
final FactorGraph targetModel = transformMap.target();
final List<Factor> temporaryFactors = Lists.newLinkedList();
for (int vari = nConditioned; vari < nVariables; ++vari)
{
final Discrete var = (Discrete) transformMap.sourceToTargetVariable(variables.get(vari));
// Mark variable as "eliminated". Note that there is no need to clear the mark
// at the start because all of the variables are newly created.
var.setMarked();
final List<Factor> cliqueFactors = new LinkedList<Factor>();
final List<Discrete> cliqueVars = new LinkedList<Discrete>();
final int nVarFactors = var.getSiblingCount();
for (int fi = nVarFactors; --fi>=0;)
{
final Factor neighborFactor = var.getSibling(fi);
if (!neighborFactor.isMarked())
{
cliqueFactors.add(neighborFactor);
// Mark factor to indicate it has been assigned to a clique.
neighborFactor.setMarked();
}
for (int vi = neighborFactor.getSiblingCount(); --vi>=0;)
{
// Non-Discrete variables should already have been removed during conditioning step.
final Discrete neighborVar = neighborFactor.getSibling(vi).asDiscreteVariable();
if (!neighborVar.isMarked() && !neighborVar.wasVisited())
{
cliqueVars.add(neighborVar);
neighborVar.setVisited();
}
}
}
for (Discrete cliqueVar : cliqueVars)
{
cliqueVar.clearVisited();
}
// Add temporary factor to connect remaining variables in clique to each other.
if (cliqueVars.size() > 1 && nVarFactors > 1)
{
final Factor temporaryFactor = targetModel.addFactor(Uniform.INSTANCE, cliqueVars.toArray());
temporaryFactor.setMarked();
temporaryFactors.add(temporaryFactor);
}
cliqueVars.add(var);
Clique clique = new Clique(cliqueVars, cliqueFactors);
cliques.add(clique);
}
for (Factor temporaryFactor : temporaryFactors)
{
targetModel.remove(temporaryFactor);
}
return cliques;
}
private List<CliqueEdge> formSpanningTree(
JunctionTreeTransformMap transformMap,
List<Clique> cliques,
SetMultimap<Discrete, Clique> varToCliques)
{
final int nCliques = cliques.size();
assert(nCliques > 0);
final IHeap<Clique> heap = BinaryHeap.create(nCliques);
heap.deferOrderingForBulkAdd(nCliques);
Clique maxClique = null;
for (Clique clique : cliques)
{
if (maxClique == null || clique.size() > maxClique.size())
{
maxClique = clique;
}
clique._heapEntry = heap.offer(clique, Double.POSITIVE_INFINITY);
}
requireNonNull(maxClique);
heap.changePriority(requireNonNull(maxClique._heapEntry), Double.NEGATIVE_INFINITY);
// Edges with more than one variable
List<CliqueEdge> multiVariateEdges = Lists.newArrayListWithCapacity(nCliques - 1);
final FactorGraph targetModel = transformMap.target();
for (Clique clique; (clique = heap.poll()) != null;)
{
clique._inSpanningTree = true;
clique._heapEntry = null;
final CliqueEdge addedEdge = clique._bestEdge;
if (addedEdge != null)
{
final Clique prevClique = addedEdge._from;
if (addedEdge.isMergeable())
{
if (requireNonNull(prevClique).absorbClique(clique, varToCliques))
{
clique = prevClique;
}
else
{
clique = null;
}
}
else
{
final AddedJointVariable<?> addedVar = addedEdge.makeJointVariable(targetModel);
if (addedVar != null)
{
multiVariateEdges.add(addedEdge);
transformMap.addDeterministicVariable(addedVar);
}
requireNonNull(prevClique).addEdge(addedEdge);
clique.addEdge(addedEdge);
}
}
if (clique != null)
{
for (CliqueEdge edge : edgesNotInTree(clique, varToCliques))
{
final Clique targetClique = edge._to;
if (targetClique.updateBestEdge(edge))
{
// Use negative weight because IHeap implements a min heap.
heap.changePriority(Objects.requireNonNull(targetClique._heapEntry), -edge._weight);
}
}
}
}
// Add half-edges for variables that are in only one clique and therefore won't be in any
// edge created in the previous step.
for (Discrete var : varToCliques.keySet())
{
final Set<Clique> cliquesForVar = varToCliques.get(var);
if (cliquesForVar.size() == 1)
{
final Clique clique = Iterables.getOnlyElement(cliquesForVar);
clique.addEdge(new CliqueEdge(null, clique, var));
}
}
return multiVariateEdges;
}
private void reconnectOrphanVariables(FactorGraph targetModel, List<CliqueEdge> multiVariateEdges)
{
// Find orphaned variables and map each to the smallest joint variable that subsumes it
final Map<Discrete, Tuple2<Discrete,Integer>> orphanVarToJointVar = Maps.newLinkedHashMap();
for (CliqueEdge edge : multiVariateEdges)
{
final Discrete jointVar = Objects.requireNonNull(edge._jointVariable);
for (int vari = 0, nVars = edge._variables.length; vari < nVars; ++vari)
{
final Discrete var = edge._variables[vari];
if (!var.wasVisited())
{
var.setVisited();
if (var.getSiblingCount() == 0)
{
final Tuple2<Discrete,Integer> tuple = orphanVarToJointVar.get(var);
if (tuple == null || tuple.first.getDomain().size() > jointVar.getDomain().size())
{
orphanVarToJointVar.put(var, Tuple2.create(jointVar, vari));
}
}
}
}
}
// Create marginal factors that connects each orphan variable to a joint variable.
//
// TODO: if two or more orphaned variables are attached to the same joint variable, it
// can be expressed using a single factor instead of one per variable
for (Entry<Discrete,Tuple2<Discrete,Integer>> entry : orphanVarToJointVar.entrySet())
{
final Discrete orphan = entry.getKey();
final Discrete joint = entry.getValue().first;
final int subindex= entry.getValue().second;
final JointDiscreteDomain<?> jointd = (JointDiscreteDomain<?>)joint.getDomain();
Factor factor = targetModel.addFactor(FactorTable.createMarginal(subindex, jointd), orphan, joint);
factor.setDirectedTo(new int[] { 0 });
}
}
/**
* Computes list of edges from {@code clique} to other cliques that are not yet in the
* spanning tree.
*
* @param clique
* @param varToCliques is a mapping from variable to the cliques that contain it.
*/
private List<CliqueEdge> edgesNotInTree(Clique clique, SetMultimap<Discrete, Clique> varToCliques)
{
final SetMultimap<Clique, Discrete> cliqueToCommonVars = LinkedHashMultimap.create();
// NOTE: because clique._variables is sorted the variables in the edge list will also be sorted.
for (Discrete variable : clique._variables)
{
for (Clique neighbor : varToCliques.get(variable))
{
if (!neighbor._inSpanningTree)
{
cliqueToCommonVars.put(neighbor, variable);
}
}
}
List<CliqueEdge> edges = new ArrayList<CliqueEdge>(cliqueToCommonVars.size());
for (Clique neighbor : cliqueToCommonVars.keySet())
{
final Set<Discrete> commonVars = cliqueToCommonVars.get(neighbor);
edges.add(new CliqueEdge(clique, neighbor, ArrayUtil.copy(Discrete.class, commonVars)));
}
return edges;
}
}