/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.transform; import static java.util.Objects.*; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Random; import java.util.Set; import net.jcip.annotations.Immutable; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.BinaryHeap; import com.analog.lyric.collect.IHeap; import com.analog.lyric.collect.IHeap.IEntry; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.factors.FactorBase; import com.analog.lyric.dimple.model.factors.FactorList; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableList; import com.google.common.collect.Iterators; import com.google.common.collect.Sets; /** * Computes a variable elimination order for a factor graph using a greedy * algorithm over various cost functions. 
* <p> * A variable elimination order is an ordering of the variables in a graph * that is intended to minimize the cost of various exact inference algorithms * that are based on it. While the exact inference cost will depend on the * algorithm, in general, orderings that minimize the cost of the classic * variable elimination algorithm are also good for other algorithms as * well, including the construction of junction trees, finding loop cuts * for loop cut conditioning, and for constructing graph partition trees * for recursive conditioning schemes. * <p> * The problem of finding an optimal ordering is NP-hard, but the heuristic * greedy approach implemented by this class has been shown to produce good * results with reasonable time complexity. The algorithm implemented here * is as follows: * <p> * <ol> * <li>Build a variable to variable adjacency list representation for * the variables in the factor graph. Note that this representation will * be modified during the execution of the algorithm so the actual model * representation cannot be used directly. * * <li>Pick a cost function that will be used to order variables in the * graph. The specific cost functions will be described below. * * <li>Build a priority queue (heap) containing all of the variables ordered according * to cost, with the minimum cost at the front. * * <li>While the priority queue is not empty: * <ol> * <li>Remove the variable with the lowest cost from the queue. * * <li>Add the variable to the end of the variable elimination order * * <li>Connect all of the variable's neighbors with each other by adding * edges as necessary. * * <li>Remove the variable from the graph by removing it from its neighbors * adjacency sets. * * <li>Update the priority queue as appropriate to reflect the changes to * the value of the cost function resulting from the changes to the graph. 
* </ol> * </ol> * <p> * There are four standard cost functions that are described in the literature * and supported by this implementation: * <ul> * <li>{@link VariableCost#MIN_NEIGHBORS} * <li>{@link VariableCost#WEIGHTED_MIN_NEIGHBORS} * <li>{@link VariableCost#MIN_FILL} * <li>{@link VariableCost#WEIGHTED_MIN_FILL} * </ul> * Users can implement additional cost functions by subclassing {@link CostFunction}. * <p> * NOTE: this implementation currently does not handle models that contain non-Discrete variables * unless they have fixed values and {@link #usesConditioning()} is true. It also * and it does not take into account the contents of the factor tables. * * @author Christopher Barber * @since 0.05 */ public class VariableEliminator { /** * Describes the standard built-in variable cost functions supported by {@link VariableEliminator}. * See members for details. */ public static enum VariableCost { /** * Cost is the number of neighboring variables that have not yet * been eliminated. * @see MinNeighbors */ MIN_NEIGHBORS(new MinNeighbors()), /** * Cost is the product of the domain cardinalities of the neighboring variables * that have not yet been eliminated. * @see MinWeight */ WEIGHTED_MIN_NEIGHBORS(new MinWeight()), /** * Cost is the number of edges that would be introduced between neighboring variables * if this variable were to be eliminated at this step. * @see MinFill */ MIN_FILL(new MinFill()), /** * Cost is the sum of the weights of the edges that would be introduced between neighboring variables * if this variable were to be eliminated at this step, where the edge weights are the products of * the domain cardinalities of the variables connected by the edge. 
* @see WeightedMinFill */ WEIGHTED_MIN_FILL(new WeightedMinFill()); private final CostFunction _costFunction; private VariableCost(CostFunction costFunction) { _costFunction = costFunction; _costFunction._type = this; } public CostFunction function() { return _costFunction; } /** * Converts an array of {@link VariableCost} to a corresponding array of {@link CostFunction}. */ public static CostFunction[] toFunctions(VariableCost[] costFunctions) { final int nFunctions = costFunctions.length; final CostFunction[] functions = new CostFunction[nFunctions]; for (int i = 0; i < nFunctions; ++i) { functions[i] = costFunctions[i].function(); } return functions; } } /*------- * State */ private final FactorGraph _model; private final @Nullable Random _rand; /** * If true, then variables with fixed values will be eliminated first and will be * considered to be disjoint from the rest of the graph. */ private final boolean _useConditioning; /** * The number of variables in the model. Used to preallocate capacity for data structures. */ private int _nVariables; /*-------------- * Construction */ /** * Initialize for given model. * <p> * Invokes {@link #VariableEliminator(FactorGraph, boolean, Random)} with new {@link Random} instance. */ public VariableEliminator(FactorGraph model, boolean useConditioning) { this(model, useConditioning, new Random()); } /** * Initialize for given model and using given random number * generator. * @param useConditioning sets value of {@link #usesConditioning()} * @param rand is the random number generator used to randomly * break ties for variables with the same cost. If null, then * ties will be broken deterministically by favoring the variable * with the lower id ({@link Variable#getId()}), which is useful * for testing. 
*/ public VariableEliminator(FactorGraph model, boolean useConditioning, @Nullable Random rand) { _model = model; _rand = rand; _useConditioning = useConditioning; _nVariables = model.getVariableCount(); } /*--------- * Methods */ /** * Computes a variable elimination order by iteratively retrying using one or more cost functions * and choosing the best fit according to the specified threshold statistics. * <p> * This function builds an eliminator for specified {@code mode} and {@code useConditioning} attribute. * It then iteratively up to {@code nAttempts} times picks a cost function at random from {@code costFunctions} * and uses it to build an {@link OrderIterator} from which it generates an ordering. After each iteration, * the global statistics (from the {@linkplain OrderIterator#getStats() getStats()} method on the iterator) * are compared against the best statistics so far using the statistic's * {@linkplain VariableEliminator.Stats#compareTo compareTo} method to determine whether to keep the ordering. * If the stats at any point satisfy the specified threshold values (as determined by the * {@linkplain VariableEliminator.Stats#meetsThreshold meetsThreshold method} then the function will return * immediately. * <p> * For example, the following call will generate an order by conditioning out any fixed value variables, * and will randomly try weighted min neighbors or weighted min fill cost functions up to ten iterations * and returning the first one to achieve a max clique cardinality of no more than 42, otherwise returns * the order with the best max clique cardinality: * * <pre>{@code * Ordering order = generateStochastically(fg, true, 10, * new Stats().maxCliqueCardinality(42), * VariableCost.WEIGHTED_MIN_NEIGHBORS, VariableCost.WEIGHTED_MIN_FILL) * }</pre> * <p> * @param model is the graph for which the eliminator order is being computed. 
* @param useConditioning specifies whether to use conditioning (see {@link #usesConditioning()} * @param nAttempts is the number of potential iteration orders to compute. If not a positive value, * then each cost function will be tried once deterministically. * @param threshold specifies which statistics should be used to evaluate the goodness of a given * ordering (see {@link VariableEliminator.Stats#compareTo}) and also threshold values for each * statistic that will terminate the function before all {@code nAttempts} have been tried * (see {@link VariableEliminator.Stats#meetsThreshold}). Only statistics with non-negative threshold * values will be considered. * @param costFunctions is a list of cost functions to be used. If empty, all will be tried. * * @return the variable elimination order that best satisfied the {@code threshold} statistics. */ public static Ordering generate( FactorGraph model, boolean useConditioning, int nAttempts, Stats threshold, VariableCost ... costFunctions) { return generate(model, useConditioning, nAttempts, threshold, VariableCost.toFunctions(costFunctions)); } /** * A more general version of {@link #generate(FactorGraph, boolean, int, Stats, VariableCost...)} * but accepting {@link CostFunction} objects, which allows for user-defined cost functions. */ public static Ordering generate( FactorGraph model, boolean useConditioning, int nAttempts, Stats threshold, CostFunction ... costFunctions) { final boolean deterministic = nAttempts <= 0; final VariableEliminator eliminator = deterministic? new VariableEliminator(model, useConditioning, null) : new VariableEliminator(model, useConditioning); return generate(eliminator, nAttempts, threshold, costFunctions); } /** * Invokes {@link #generate(FactorGraph, boolean, int, Stats, VariableCost...)} with * all standard {@link VariableCost} functions. 
* <p> * @since 0.05 */ public static Ordering generate( FactorGraph model, boolean useConditioning, int nAttempts, Stats threshold) { return generate(model, useConditioning, nAttempts, threshold, VariableCost.values()); } /** * Computes a variable elimination order by iteratively retrying using one or more cost functions * and choosing the best fit according to the specified threshold statistics. * <p> * Same as {@link #generate(FactorGraph, boolean, int, Stats, VariableCost...)} but uses provided * eliminator instead of building a new one. */ public static Ordering generate( VariableEliminator eliminator, int nAttempts, Stats threshold, VariableCost ... costFunctions) { return generate(eliminator, nAttempts, threshold, VariableCost.toFunctions(costFunctions)); } /** * Computes a variable elimination order by iteratively retrying using one or more cost functions * and choosing the best fit according to the specified threshold statistics. * <p> * Same as {@link #generate(FactorGraph, boolean, int, Stats)} but uses provided * eliminator instead of building a new one. */ public static Ordering generate( VariableEliminator eliminator, int nAttempts, Stats threshold) { return generate(eliminator, nAttempts, threshold, VariableCost.values()); } /** * Computes a variable elimination order by iteratively retrying using one or more cost functions * and choosing the best fit according to the specified threshold statistics. * <p> * Same as {@link #generate(FactorGraph, boolean, int, Stats, CostFunction...)} but uses provided * eliminator instead of building a new one. */ public static Ordering generate( VariableEliminator eliminator, int nAttempts, Stats threshold, CostFunction ... costFunctions) { final boolean deterministic = nAttempts <= 0; if (costFunctions.length == 0) { costFunctions = VariableCost.toFunctions(VariableCost.values()); } final int nFunctions = costFunctions.length; // Cumulative distribution function for choosing cost function. 
Initially // set to uniform weights. final double[] functionCDF = new double[nFunctions]; { final double increment = 1.0 / nFunctions; double cumProb = increment; for (int i = 0; i < nFunctions; ++i) { functionCDF[i] = cumProb; cumProb += increment; } } final long[] timePerFunction = new long[nFunctions]; long totalTime = 0; if (deterministic) { nAttempts = nFunctions; } ArrayList<Variable> curList = new ArrayList<Variable>(eliminator._nVariables); ArrayList<Variable> bestList = new ArrayList<Variable>(eliminator._nVariables); Stats bestStats = null; Random rand = eliminator.getRandomizer(); if (rand == null) { rand = new Random(); } for (int attempt = 0; attempt < nAttempts; ++attempt) { // Pick a cost function int costIndex = 0; if (nFunctions > 1) { if (deterministic) { costIndex = attempt; } else { costIndex = Arrays.binarySearch(functionCDF, rand.nextDouble()); if (costIndex < 0) { costIndex = -costIndex - 1; } costIndex = Math.min(costIndex, nFunctions - 1); } } CostFunction cost = costFunctions[costIndex]; // Run variable elimination final long beforeNS = System.nanoTime(); OrderIterator iterator = eliminator.orderIterator(cost); Iterators.addAll(curList, iterator); final long elapsedNS = System.nanoTime() - beforeNS; timePerFunction[costIndex] += elapsedNS; totalTime += elapsedNS; // Compare stats Stats curStats = iterator.getStats(); if (curStats.addedEdges() == 0) { bestStats = curStats; bestList = curList; break; } if (bestStats == null || curStats.compareTo(bestStats, threshold) < 0) { ArrayList<Variable> tmp = curList; curList = bestList; curList.clear(); bestList = tmp; bestStats = curStats; if (bestStats.meetsThreshold(threshold)) { break; } } // Update functionCDF based on timings to favor cheaper cost function. // TODO: give bonus weight to functions that improved the stats. 
if (nFunctions > 1 && !deterministic) { final double normalizer = (double)totalTime * (nFunctions - 1); double cumProb = 0.0; for (int i = 0; i < nFunctions; ++i) { functionCDF[i] = cumProb += (totalTime - timePerFunction[i]) / normalizer; } } } if (bestStats == null) { bestStats = new Stats(null, 0); } return new Ordering(bestList, bestStats); } /** * The model for which ordering can be computed. */ public FactorGraph getModel() { return _model; } /** * The randomizer used to break ties between variables with the same cost. * When null, ties are broken deterministically. * * @see #VariableEliminator(FactorGraph, boolean, Random) */ public @Nullable Random getRandomizer() { return _rand; } /** * Returns an iterator to produce the variable ordering for the given cost function. * This may be invoked multiple times with different cost functions. When {@link #getRandomizer()} * is non-null, then running with the same cost function can produce different orderings. */ public OrderIterator orderIterator(VariableCost cost) { return orderIterator(cost.function()); } /** * Returns an iterator to produce the variable ordering for the given cost function. * This may be invoked multiple times with different cost functions. When {@link #getRandomizer()} * is non-null, then running with the same cost function can produce different orderings. */ public OrderIterator orderIterator(CostFunction cost) { return new OrderIterator(this, cost); } /** * True if eliminator takes into account variables that are conditioned with * a fixed value. If true, then such variables will be eliminated first and * will be treated as if they have no siblings. Value is set during construction. */ public boolean usesConditioning() { return _useConditioning; } /*---------- * Ordering */ /** * Holds a variable elimination ordering along with statistics for its derivation. * * @since 0.05 * @author Christopher Barber * @see VariableEliminator#generate(VariableEliminator, int, Stats, VariableCost...) 
*/ @Immutable public static class Ordering { public final ArrayList<Variable> variables; public final Stats stats; Ordering(ArrayList<Variable> variables, Stats stats) { this.variables = variables; this.stats = stats; } } /*---------------- * OrderIterator */ /** * Produces a variable elimination order based on a given variable cost function. * <p> * Once iterator has terminated (i.e. {@link #hasNext()} is false), you can use * {@link #getStats()} to access statistics that can be used to measure the goodness * of the resulting ordering. * <p> * @see VariableEliminator#orderIterator(VariableCost) */ public static class OrderIterator implements Iterator<Variable> { private final VariableEliminator _eliminator; private final CostFunction _costFunction; private final IHeap<Var> _heap; private final Stats _stats; /*-------------- * Construction */ private OrderIterator(VariableEliminator eliminator, CostFunction costFunction) { _eliminator = eliminator; _costFunction = costFunction; _stats = new Stats(costFunction, 0); final List<Var> adjacencyList = eliminator.buildAdjacencyList(_stats); final int size = adjacencyList.size(); final IHeap<Var> heap = _heap = new BinaryHeap<Var>(size); for (Var var : adjacencyList) { var._heapEntry = heap.offer(var, var.adjustedCost(_costFunction)); } } /*------------------ * Iterator methods */ @Override public boolean hasNext() { return !_heap.isEmpty(); } @Override public @Nullable Variable next() { final CostFunction costFunction = _costFunction; final IHeap<Var> heap = _heap; Var var = heap.poll(); if (var == null) { return null; } // Remove variable from graph final boolean isConditioned =_eliminator.isConditioned(var._variable); if (isConditioned) { _stats.addConditionedVariable(); } long cliqueCardinality = isConditioned ? 
1 : var.cardinality(); for (VarLink link = var._neighborList._next; link.hasVar(); link = link._next) { final Var neighbor = link.var(); neighbor.removeNeighbor(var); cliqueCardinality *= neighbor.cardinality(); } _stats.addClique(var, cliqueCardinality); // Add edges between remaining neighbors for (VarLink link1 = var._neighborList._next; link1.hasVar(); link1 = link1._next) { final Var neighbor1 = link1.var(); for (VarLink link2 = link1._next; link2.hasVar(); link2 = link2._next) { final Var neighbor2 = link2.var(); if (neighbor1.addNeighbor(neighbor2)) { neighbor2.addNeighbor(neighbor1); // Update added edge statistics _stats.addEdgeWeight(neighbor1.cardinality() * neighbor2.cardinality()); } } } // Update priorities if (costFunction.neighborsOnly()) { heap.deferOrderingForBulkChange(var.nNeighbors()); for (VarLink link = var._neighborList._next; link.hasVar(); link = link._next) { final Var neighbor = link.var(); IEntry<Var> heapEntry = neighbor._heapEntry; if (heapEntry != null) { heap.changePriority(heapEntry, neighbor.adjustedCost(costFunction)); } } } else { Set<Var> changeSet = new HashSet<Var>(); for (VarLink link1 = var._neighborList._next; link1.hasVar(); link1 = link1._next) { final Var neighbor = link1.var(); changeSet.add(neighbor); for (VarLink link2 = neighbor._neighborList._next; link2.hasVar(); link2 = link2._next) { changeSet.add(link2.var()); } } heap.deferOrderingForBulkChange(changeSet.size()); for (Var change : changeSet) { IEntry<Var> heapEntry = change._heapEntry; if (heapEntry != null) { heap.changePriority(heapEntry, change.adjustedCost(costFunction)); } } } return var._variable; } /** * Not supported. * @throws UnsupportedOperationException */ @Override public void remove() { throw new UnsupportedOperationException(); } /*--------------- * Local methods */ /** * The number of variables left to be returned by calls to {@link #next()}. */ public int size() { return _heap.size(); } /** * Identifies cost evaluator used by this iterator. 
*/ public @Nullable CostFunction getCostEvaluator() { return _stats.cost(); } /** * The {@link VariableEliminator} that created this iterator. */ public VariableEliminator getEliminator() { return _eliminator; } /** * Incrementally updated statistics for the elimination order, which can be used to * measure the relative goodness of the resulting ordering. */ public Stats getStats() { return _stats; } } // OrderIterator /*------------------- * Elimination stats */ /** * Elimination quality statistics computed for an ordering. * <p> * @see OrderIterator#getStats() */ public static class Stats implements Cloneable { private final @Nullable CostFunction _cost; private int _addedEdges; private long _addedEdgeWeight; private int _conditionedVariables; private int _factorsWithDuplicateVariables; private int _maxClique; private long _maxCliqueCardinality; private int _mergedFactors; private int _variablesWithDuplicateEdges; /*-------------- * Construction */ /** * All values are initialized to -1. 
*/ public Stats() { this(null, -1); } private Stats(@Nullable CostFunction costFunction, int value) { _cost = costFunction; _addedEdges = value; _addedEdgeWeight = value; _conditionedVariables = value; _factorsWithDuplicateVariables = value; _maxClique = value; _maxCliqueCardinality = value; _mergedFactors = value; _variablesWithDuplicateEdges = value; } public Stats(Stats that) { _cost = that._cost; _addedEdges = that._addedEdges; _addedEdgeWeight = that._addedEdgeWeight; _conditionedVariables = that._conditionedVariables; _factorsWithDuplicateVariables = that._factorsWithDuplicateVariables; _maxClique = that._maxClique; _maxCliqueCardinality = that._maxCliqueCardinality; _mergedFactors = that._mergedFactors; _variablesWithDuplicateEdges = that._variablesWithDuplicateEdges; } @Override public Stats clone() { return new Stats(); } /*-------------------- * Evaluation methods */ /** * False if statistics indicates that the original graph does not need to be transformed * to do efficient exact inference. * <p> * True if {@link #addedEdges()},{@link #conditionedVariables()}, {@link #mergedFactors()}, * {@link #factorsWithDuplicateVariables()}, and * {@link #variablesWithDuplicateEdges()} are all zero. */ public boolean alreadyGoodForFastExactInference() { return _addedEdges == 0 && _conditionedVariables == 0 && _factorsWithDuplicateVariables == 0 && _variablesWithDuplicateEdges == 0 && _mergedFactors == 0; } /** * Returns -1/0/1 if these stats are deemed better than/same as/worse than {@code other} stats given specified * threshold definition. Attributes for which {@code threshold} has a negative value * will not be considered (other than that the {@code threshold} attributes are ignored). * Attributes are compared in the following order: * <ol> * <li>{@link #maxCliqueCardinality()} * <li>{@link #addedEdgeWeight()} * <li>{@link #maxCliqueSize()} * <li>{@link #addedEdges()} * </ol> * The first of these that are not equal will be used for the comparison. 
*/ public int compareTo(Stats other, Stats threshold) { long diff = 0; if (threshold._maxCliqueCardinality >= 0) { diff = _maxCliqueCardinality - other._maxCliqueCardinality; } if (diff == 0 && threshold._addedEdgeWeight >= 0) { diff = _addedEdgeWeight - other._addedEdgeWeight; } if (diff == 0 && threshold._maxClique >= 0) { diff = _maxClique - other._maxClique; } if (diff == 0 && threshold._addedEdges >= 0) { diff = _addedEdges - other._addedEdges; } return Long.signum(diff); } /** * True if these statistics satisfy the given threshold statistics. * <p> * Specifically, compares the values of the following attributes: * <ul> * <li>{@link #addedEdges()} * <li>{@link #addedEdgeWeight()} * <li>{@link #maxCliqueSize()} * <li>{@link #maxCliqueCardinality()} * </ul> * If for each these attribute of {@code threshold} that have a non-negative value, the * current object has value that is less than or equal to the threshold value, then the * threshold is satisfied. */ public boolean meetsThreshold(Stats threshold) { return (threshold._addedEdges < 0 || threshold._addedEdges >= _addedEdges) && (threshold._addedEdgeWeight < 0 || threshold._addedEdgeWeight >= _addedEdgeWeight) && (threshold._maxClique < 0 || threshold._maxClique >= _maxClique) && (threshold._maxCliqueCardinality < 0 || threshold._maxCliqueCardinality >= _maxCliqueCardinality) ; } /*------------ * Attributes */ /** * The number of edges that were added during the execution of the algorithm. */ public int addedEdges() { return _addedEdges; } /** * Sets value of {@link #addedEdges()} and returns this object. */ public Stats addedEdges(int edges) { _addedEdges = edges; return this; } /** * The total weight of edges that were added during the execution of the algorithm where * the weight is defined as the product of the cardinality of its variables. */ public long addedEdgeWeight() { return _addedEdgeWeight; } /** * Sets value of {@link #addedEdgeWeight()} and returns this object. 
*/ public Stats addedEdgeWeight(long weight) { _addedEdgeWeight = weight; return this; } /** * The number of variables that were eliminated by conditioning. That is, the number of variables * with a fixed value when the variable eliminator is using conditioning. * <p> * Note: this attribute is not used by {@link #compareTo} or {@link #meetsThreshold}. * <p> * @see Variable#hasFixedValue() * @see VariableEliminator#usesConditioning() */ public int conditionedVariables() { return _conditionedVariables; } /** * Sets value of {@link #conditionedVariables()} and returns this object. */ public Stats conditionedVariables(int n) { _conditionedVariables = n; return this; } /** * The cost function used to generate these stats, if from {@link OrderIterator}. */ public @Nullable CostFunction cost() { return _cost; } /** * The number of factors with more than one edge to the same variable. * <p> * Note: this attribute is not used by {@link #compareTo} or {@link #meetsThreshold}. */ public int factorsWithDuplicateVariables() { return _factorsWithDuplicateVariables; } /** * Sets value of {@link #factorsWithDuplicateVariables()} and returns this object. */ public Stats factorsWithDuplicateVariables(int n) { _factorsWithDuplicateVariables = n; return this; } /** * Returns the size of the largest clique induced by the execution of the algorithm. * The clique size is determined when a variable is eliminated and is equivalent to * the number of non-eliminated neighbors of the variable plus one (for the variable itself). */ public int maxCliqueSize() { return _maxClique; } /** * Sets value of {@link #maxCliqueSize()} and returns this object. */ public Stats maxCliqueSize(int size) { _maxClique = size; return this; } /** * Returns the cardinality of the largest clique induced by the execution of the algorithm. * Like {@link #maxCliqueSize()} but instead of the number of variables in the clique, this * is based on the product of the cardinality of the variables in the clique. 
*/ public long maxCliqueCardinality() { return _maxCliqueCardinality; } /** * Sets value of {@link #maxCliqueCardinality()} and returns this object. */ public Stats maxCliqueCardinality(long cardinality) { _maxCliqueCardinality = cardinality; return this; } /** * The number of factors that would need to be merged into other factors. */ public int mergedFactors() { return _mergedFactors; } /** * Sets value of {@link #mergedFactors()} and returns this object. */ public Stats mergedFactors(int n) { _mergedFactors = n; return this; } /** * The number of variables that are connected to another variable through more than one factor. * <p> * Note: this attribute is not used by {@link #compareTo} or {@link #meetsThreshold}. */ public int variablesWithDuplicateEdges() { return _variablesWithDuplicateEdges; } /** * Sets value of {@link #variablesWithDuplicateEdges()} and returns this object. */ public Stats variablesWithDuplicateEdges(int n) { _variablesWithDuplicateEdges = n; return this; } /*----------------- * Private methods */ private void addEdgeWeight(long weight) { ++_addedEdges; _addedEdgeWeight += weight; } private void addClique(Var var, long cardinality) { final int size = 1 + var.nNeighbors(); _maxClique = Math.max(_maxClique, size); _maxCliqueCardinality = Math.max(_maxCliqueCardinality, cardinality); final Variable variable = var._variable; final int nFactors = variable.getSiblingCount(); if (nFactors > 1) { // If there is more than one factor whose variables are wholly contained by this clique, // they will need to be merged. 
final Set<Variable> variables = Sets.newHashSetWithExpectedSize(size); variables.add(variable); for (VarLink link = var._neighborList._next; link.hasVar(); link = link._next) { variables.add(link.var()._variable); } int nCliqueFactors = 0; nextFactor: for (int i = 0; i < nFactors; ++i) { final Factor factor = variable.getSibling(i); for (int j = 0, nFactorVars = factor.getSiblingCount(); j < nFactorVars; ++j) { if (!variables.contains(factor.getSibling(j))) { // Factor is not entirely contained by this clique. continue nextFactor; } } ++nCliqueFactors; } if (nCliqueFactors > 1) { _mergedFactors += nCliqueFactors; } } } private void addConditionedVariable() { ++_conditionedVariables; } private void addFactorWithDuplicateVars(FactorBase factor) { ++_factorsWithDuplicateVariables; } private void addVariableWithDuplicateEdges(Variable variable) { ++_variablesWithDuplicateEdges; } } // Stats /** * Holds information about a single variable for use by variable eliminator. * <p> * Public methods are available for use by {@link CostFunction} implementations. */ public static class Var { final Variable _variable; final VarLink _neighborList = new VarLink(); final Map<Var, VarLink> _neighborMap; /** * Pointer to heap entry for this object for use in efficient reprioritization. */ @Nullable IEntry<Var> _heapEntry = null; /** * Can be set to a value in the range [0.0 and 1.0) to be used by * Prioritizer to break to randomly order elements with the same * integer priority. 
*/ final double _incrementalCost; final boolean _isConditioned; /*-------------- * Construction */ private Var(Variable variable, double incrementalCost, boolean isConditioned) { _variable = variable; _incrementalCost = incrementalCost; _neighborMap = new HashMap<Var, VarLink>(variable.getSiblingCount()); _isConditioned = isConditioned; } /*---------------- * Object methods */ @Override public String toString() { return _variable.getName(); } private boolean addNeighbor(Var neighbor) { if (neighbor != this && !_neighborMap.containsKey(neighbor)) { VarLink link = new VarLink(neighbor); _neighborMap.put(neighbor, link); link.insertBefore(_neighborList); return true; } return false; } private double adjustedCost(CostFunction costFunction) { return costFunction.cost(this) + _incrementalCost; } /** * The cardinality of the underlying variable's domain, assumed to be discrete. */ public int cardinality() { return requireNonNull(_variable.getDomain().asDiscrete()).size(); } /** * Start of linked list of variable neighbors. */ public VarLink firstNeighbor() { return _neighborList._next; } /** * True if {@code other} variable neighbors this one (i.e. if both are connected to the same factor). */ public boolean isAdjacent(Var other) { return _neighborMap.containsKey(other); } /** * True if conditioning has been enabled and the variable has a fixed value. */ public boolean isConditioned() { return _isConditioned; } /** * The number of neighbor variables. */ public int nNeighbors() { return _neighborMap.size(); } /** * The underlying variable. */ public Variable variable() { return _variable; } private void removeNeighbor(Var neighbor) { _neighborMap.remove(neighbor).remove(); } } /** * A node in a linked list of {@link Var} entries. 
*/ public static final class VarLink { private final @Nullable Var _var; private VarLink _prev = this; private VarLink _next = this; VarLink(Var info) { _var = info; } VarLink() { _var = null; } public boolean hasVar() { return _var != null; } /** * Refers to the next link. */ public VarLink next() { return _next; } /** * The {@link Var} object for this link. */ public Var var() { return Objects.requireNonNull(_var); } void insertBefore(VarLink next) { _next = next; _prev = next._prev; next._prev = this; _prev._next = this; } void remove() { _prev._next = _next; _next._prev = _prev; _next = this; _prev = this; } } /*------------------------------- * Cost function implementations */ public static abstract class CostFunction implements Serializable { private static final long serialVersionUID = 1L; private @Nullable VariableCost _type = null; protected CostFunction() { } /** * Replace with canonical instance if there is one when deserializing. * @since 0.07 */ protected Object readResolve() { VariableCost type = _type; return type != null ? type.function() : this; } final double cost(Var var) { final int nNeighbors = var.nNeighbors(); // It is always better to first eliminate variables connected by no more // than one edge because their removal will not expand the tree width. switch (nNeighbors) { case 0: // Return -2 if conditioned, or -1 for other variables with no edges // (which is unlikely). This ensures that conditioned variables will // always come first in the elimination order. return var._isConditioned ? -2 : -1; case 1: // Return 0 if there is only one edge. return 0; default: break; } return computeCost(var); } /** * Computes cost in range [0.0,infinity] for {@code var}. Lower cost variables will be * eliminated before higher cost ones. */ public abstract double computeCost(Var var); /** * True if evaluation only depends on immediate neighbors. 
*/ public abstract boolean neighborsOnly(); /** * If this is a standard built-in cost function, returns its corresponding descriptor, * otherwise returns null. */ public final @Nullable VariableCost type() { return _type; } } /** * Cost is the number of neighbors of the variable in the current graph. * <p> * Get instance from {@link VariableCost#MIN_NEIGHBORS}. */ public static class MinNeighbors extends CostFunction { private static final long serialVersionUID = 1L; private MinNeighbors() { } @Override public double computeCost(Var var) { return var.nNeighbors(); } @Override public boolean neighborsOnly() { return true; } } /** * Cost is the product of the domain cardinalities of all of the neighboring * variables in the current graph. * <p> * Get instance from {@link VariableCost#WEIGHTED_MIN_NEIGHBORS}. */ public static class MinWeight extends CostFunction { private static final long serialVersionUID = 1L; private MinWeight() { } @Override public double computeCost(Var var) { double weight = 1.0; for (VarLink link = var._neighborList._next; link.hasVar(); link = link._next) { weight *= link.var().cardinality(); } return weight; } @Override public boolean neighborsOnly() { return true; } } /** * Cost is the number of edges that would be added if this variable were to be eliminated * from the current graph, i.e the number of unique neighbor variable pairs that are not * already adjacent to each other. * <p> * Get instance from {@link VariableCost#MIN_FILL}. 
*/ public static class MinFill extends CostFunction { private static final long serialVersionUID = 1L; private MinFill() { } @Override public double computeCost(Var var) { double count = 0.0; for (VarLink link1 = var._neighborList._next; link1.hasVar(); link1 = link1._next) { final Var neighbor1 = link1.var(); for (VarLink link2 = link1._next; link2.hasVar(); link2 = link2._next) { final Var neighbor2 = link2.var(); if (!neighbor1.isAdjacent(neighbor2)) { ++count; } } } return count; } @Override public boolean neighborsOnly() { return false; } } /** * Similar to {@link MinFill} but instead of counting edges that would be added, it * counts the sum of the weights of added edges where the weight is the product of * the domain cardinalities at each end. * <p> * Get instance from {@link VariableCost#WEIGHTED_MIN_FILL}. */ public static class WeightedMinFill extends CostFunction { private static final long serialVersionUID = 1L; private WeightedMinFill() { } @Override public double computeCost(Var var) { double weight = 0.0; for (VarLink link1 = var._neighborList._next; link1.hasVar(); link1 = link1._next) { final Var neighbor1 = link1.var(); for (VarLink link2 = link1._next; link2.hasVar(); link2 = link2._next) { final Var neighbor2 = link2.var(); if (!neighbor1.isAdjacent(neighbor2)) { weight += neighbor1.cardinality() * neighbor2.cardinality(); } } } return weight; } @Override public boolean neighborsOnly() { return false; } } /*----------------- * Private methods */ private List<Var> buildAdjacencyList(Stats stats) { final List<Var> list = new LinkedList<Var>(); final VariableList variables = _model.getVariables(); final Map<Variable,Var> map = new LinkedHashMap<Variable,Var>(variables.size()); for (Variable variable : variables) { if (!variable.getDomain().isDiscrete() && !isConditioned(variable)) { throw new DimpleException("VariableEliminator cannot handle non-discrete variable '%s'", variable); } Var var = new Var(variable, generateCostIncrement(variable), 
isConditioned(variable)); map.put(variable, var); list.add(var); variable.clearMarked(); } final FactorList factors = _model.getFactors(); for (Factor factor : factors) { factor.clearMarked(); } Set<Factor> factorsWithDuplicateVars = new HashSet<Factor>(); Set<Variable> variablesWithDuplicateEdges = new HashSet<Variable>(); for (Var var : map.values()) { if (var._isConditioned) continue; final Variable variable = var._variable; for (int fi = 0, nFactors = variable.getSiblingCount(); fi < nFactors; ++fi) { final Factor factor = variable.getSibling(fi); if (factor.isMarked()) { factorsWithDuplicateVars.add(factor); continue; } factor.setMarked(); for (int vi = 0, nVariables = factor.getSiblingCount(); vi < nVariables; ++vi) { final Variable neighborVariable = factor.getSibling(vi); if (neighborVariable == variable) continue; final Var neighborVar = map.get(neighborVariable); if (neighborVar._isConditioned) continue; if (neighborVariable.isMarked()) { variablesWithDuplicateEdges.add(variable); } else { neighborVariable.setMarked(); var.addNeighbor(map.get(neighborVariable)); } } } // Reset marks for visited factors and variables. for (int fi = 0, nFactors = variable.getSiblingCount(); fi < nFactors; ++fi) { final Factor factor = variable.getSibling(fi); factor.clearMarked(); } for (VarLink link = var._neighborList._next; link.hasVar(); link = link._next) { link.var()._variable.clearMarked(); } } for (Factor factor : factorsWithDuplicateVars) { stats.addFactorWithDuplicateVars(factor); } for (Variable variable : variablesWithDuplicateEdges) { stats.addVariableWithDuplicateEdges(variable); } return list; } /** * Generates a cost-increment in the range [0, 1) to break ties between * variables with same integer cost. 
*/
	private double generateCostIncrement(Variable variable)
	{
		final Random generator = _rand;

		if (generator != null)
		{
			// A random generator is configured: draw the increment from it.
			return generator.nextDouble();
		}

		// No generator: derive a repeatable increment from the variable's graph tree id.
		// NOTE(review): the quotient is below 1.0 only when the id is non-negative and
		// less than Integer.MAX_VALUE -- confirm against the range of getGraphTreeId().
		final double id = (double)variable.getGraphTreeId();
		return id / (double)Integer.MAX_VALUE;
	}

	/**
	 * True if conditioning is enabled for this eliminator and {@code variable} has a fixed value.
	 */
	private boolean isConditioned(Variable variable)
	{
		if (!_useConditioning)
		{
			return false;
		}
		return variable.hasFixedValue();
	}
}