package edu.stanford.nlp.loglinear.inference;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.*;
/**
* Created on 8/11/15.
* @author keenon
* <p>
* This is instantiated once per model, so that it can keep caches of important stuff like messages and
* local factors during many game playing sample steps. It assumes that the model that is passed in is by-reference,
* and that it can change between inference calls in small ways, so that caching of some results is worthwhile.
*/
public class CliqueTree {
/** A logger for this class */
private static final Redwood.RedwoodChannels log = Redwood.channels(CliqueTree.class);

// The model to run inference over. Held by reference, so callers may mutate it between inference calls.
private GraphicalModel model;
// A deep clone of the weights passed to the constructor; used to dot-product with factor features.
private ConcatVector weights;

// This is the metadata key for the model to store an observed value for a variable, as an int
public static final String VARIABLE_OBSERVED_VALUE = "inference.CliqueTree.VARIABLE_OBSERVED_VALUE";

// Compile-time switch: when true, cliques and messages from previous inference runs are cached and reused
private static final boolean CACHE_MESSAGES = true;
/**
 * Creates an Inference object for a given model and set of weights.
 * <p>
 * The object sticks around to facilitate caching as an eventual optimization, for the case where models change
 * in minor ways and inference is required several times. Work is done lazily, so is left until actual inference
 * is requested.
 *
 * @param model the model to be computed over, subject to change in the future
 * @param weights the weights to dot product with model features to get log-linear factors; cloned internally so
 * that no changes to the weights vector will be reflected by the CliqueTree. If you want to change
 * the weights, you must create a new CliqueTree.
 */
public CliqueTree(GraphicalModel model, ConcatVector weights) {
  // Clone defensively so later external mutation of 'weights' cannot affect this tree
  this.weights = weights.deepClone();
  this.model = model;
}
/**
 * Simple value container for the results of a marginal computation: the per-variable marginals, the global
 * partition function, and the per-factor joint marginal tables.
 */
public static class MarginalResult {
  // Indexed first by variable, then by variable assignment
  public double[][] marginals;
  public double partitionFunction;
  // One joint marginal table per factor in the model (identity-keyed by the inference code that builds it)
  public Map<GraphicalModel.Factor, TableFactor> jointMarginals;

  public MarginalResult(double[][] marginals, double partitionFunction, Map<GraphicalModel.Factor, TableFactor> jointMarginals) {
    this.jointMarginals = jointMarginals;
    this.partitionFunction = partitionFunction;
    this.marginals = marginals;
  }
}
/**
 * Runs full sum-product inference. This assumes that factors represent joint probabilities. Joint marginals and
 * the partition function are computed as well, which is what gradient-based training needs.
 *
 * @return global marginals
 */
public MarginalResult calculateMarginals() {
  final boolean withJointMarginalsAndPartition = true;
  return messagePassing(MarginalizationMethod.SUM, withJointMarginalsAndPartition);
}
/**
 * Calculates marginals, but skips the extra work only needed for gradient descent: joint marginals and the
 * partition function. That makes it much faster, and thus appropriate for gameplayer style work, where many
 * samples need to be drawn against the same marginals.
 *
 * @return an array, indexed first by variable, then by variable assignment, of global probability
 */
public double[][] calculateMarginalsJustSingletons() {
  // Drop everything but the singleton marginals from the message-passing result
  return messagePassing(MarginalizationMethod.SUM, false).marginals;
}
/**
 * Runs max-product inference. This assumes that factors represent joint probabilities.
 *
 * @return an array, indexed by variable, of maximum likelihood assignments
 */
public int[] calculateMAP() {
  double[][] maxMarginals = messagePassing(MarginalizationMethod.MAX, false).marginals;
  int[] best = new int[maxMarginals.length];
  for (int v = 0; v < best.length; v++) {
    double[] scores = maxMarginals[v];
    if (scores != null) {
      // Plain argmax; ties break toward the lowest assignment index
      int argMax = 0;
      for (int a = 1; a < scores.length; a++) {
        if (scores[a] > scores[argMax]) {
          argMax = a;
        }
      }
      best[v] = argMax;
    }
    // If there is no factor touching an observed variable, the resulting MAP won't reference the variable
    // observation since message passing won't touch the variable index, so force it from metadata here
    Map<String, String> metadata = model.getVariableMetaDataByReference(v);
    if (metadata.containsKey(VARIABLE_OBSERVED_VALUE)) {
      best[v] = Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE));
    }
  }
  return best;
}
////////////////////////////////////////////////////////////////////////////
// PRIVATE IMPLEMENTATION
////////////////////////////////////////////////////////////////////////////
/** How to collapse variables out of messages: SUM for marginal inference, MAX for MAP inference. */
private enum MarginalizationMethod {
SUM, // sum-product
MAX  // max-product
}
// OPTIMIZATION:
// cache the creation of TableFactors, to avoid redundant dot products
private IdentityHashMap<GraphicalModel.Factor, CachedFactorWithObservations> cachedFactors = new IdentityHashMap<>();

/**
 * Cache entry pairing a computed TableFactor with the observations it was built under, so that a cached factor
 * is only reused when the model's observations for its variables have not changed.
 */
private static class CachedFactorWithObservations {
TableFactor cachedFactor;      // the factor, with observations already applied
int[] observations;            // observed value per neighbor variable, or -1 where unobserved
boolean impossibleObservation; // true when the observations left no assignment with value > 0
}

// OPTIMIZATION:
// cache the last list of factors, and the last set of messages passed, in case we can recycle some
private TableFactor[] cachedCliqueList;
private TableFactor[][] cachedMessages;
private boolean[][] cachedBackwardPassedMessages;
/**
 * Does tree shaped message passing. The algorithm calls for first passing down to the leaves, then passing back up
 * to the root.
 *
 * @param marginalize the method for marginalization, controls MAP or marginals
 * @param includeJointMarginalsAndPartition whether to also compute joint marginals and the partition function;
 *                                          passing false skips that work, which is much faster when only
 *                                          singleton marginals are needed
 * @return the marginal messages
 */
private MarginalResult messagePassing(MarginalizationMethod marginalize, boolean includeJointMarginalsAndPartition) {
  // Using the behavior of brute force factor multiplication as ground truth, the desired
  // outcome of marginal calculation with an impossible factor is a uniform probability dist.,
  // since we have a resulting factor of all 0s. That is of course assuming that normalizing
  // all 0s gives you uniform, which is not real math, but that's a useful tolerance to include, so we do.
  boolean impossibleObservationMade = false;

  // Message passing will look at fully observed cliques as non-entities, but their
  // log-likelihood (the log-likelihood of the single observed value) is still relevant for the
  // partition function.
  double partitionFunction = 1.0;
  if (includeJointMarginalsAndPartition) {
    outer:
    for (GraphicalModel.Factor f : model.factors) {
      for (int n : f.neigborIndices) {
        if (!model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE)) continue outer;
      }
      // Every variable of this factor is observed: fold the single observed assignment's value into the
      // partition function
      int[] assignment = new int[f.neigborIndices.length];
      for (int i = 0; i < f.neigborIndices.length; i++) {
        assignment[i] = Integer.parseInt(model.getVariableMetaDataByReference(f.neigborIndices[i]).get(VARIABLE_OBSERVED_VALUE));
      }
      double assignmentValue = f.featuresTable.getAssignmentValue(assignment).get().dotProduct(weights);
      if (Double.isInfinite(assignmentValue)) {
        impossibleObservationMade = true;
      } else {
        partitionFunction *= Math.exp(assignmentValue);
      }
    }
  }

  // Create the cliques by multiplying out table factors
  // TODO:OPT This could be made more efficient by observing first, then dot product
  List<TableFactor> cliquesList = new ArrayList<>();
  Map<Integer, GraphicalModel.Factor> cliqueToFactor = new HashMap<>();
  int numFactorsCached = 0;
  for (GraphicalModel.Factor f : model.factors) {
    boolean allObserved = true;
    int maxVar = 0;
    for (int n : f.neigborIndices) {
      if (!model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE)) allObserved = false;
      if (n > maxVar) maxVar = n;
    }
    if (allObserved) continue;

    TableFactor clique = null;

    // Retrieve cache if exists and none of the observations have changed
    if (cachedFactors.containsKey(f)) {
      CachedFactorWithObservations obs = cachedFactors.get(f);
      boolean allConsistent = true;
      for (int i = 0; i < f.neigborIndices.length; i++) {
        int n = f.neigborIndices[i];
        if (model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE) &&
            (obs.observations[i] == -1 ||
                Integer.parseInt(model.getVariableMetaDataByReference(n).get(VARIABLE_OBSERVED_VALUE)) != obs.observations[i])) {
          allConsistent = false;
          break;
        }
        // NOTE: This disqualifies lots of stuff for some reason...
        if (!model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE) && (obs.observations[i] != -1)) {
          allConsistent = false;
          break;
        }
      }
      if (allConsistent) {
        clique = obs.cachedFactor;
        numFactorsCached++;
        if (obs.impossibleObservation) {
          impossibleObservationMade = true;
        }
      }
    }

    // Otherwise make a new cache
    if (clique == null) {
      int[] observations = new int[f.neigborIndices.length];
      for (int i = 0; i < observations.length; i++) {
        Map<String, String> metadata = model.getVariableMetaDataByReference(f.neigborIndices[i]);
        if (metadata.containsKey(VARIABLE_OBSERVED_VALUE)) {
          int value = Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE));
          observations[i] = value;
        } else {
          observations[i] = -1;
        }
      }
      clique = new TableFactor(weights, f, observations);

      CachedFactorWithObservations cache = new CachedFactorWithObservations();
      cache.cachedFactor = clique;
      cache.observations = observations;

      // Check for an impossible observation
      boolean nonZeroValue = false;
      for (int[] assignment : clique) {
        if (clique.getAssignmentValue(assignment) > 0) {
          nonZeroValue = true;
          break;
        }
      }
      if (!nonZeroValue) {
        impossibleObservationMade = true;
        cache.impossibleObservation = true;
      }

      cachedFactors.put(f, cache);
    }

    cliqueToFactor.put(cliquesList.size(), f);
    cliquesList.add(clique);
  }
  TableFactor[] cliques = cliquesList.toArray(new TableFactor[cliquesList.size()]);

  // If we made any impossible observations, we can just return a uniform distribution for all the variables that
  // weren't observed, since that's the semantically correct thing to do (our 'probability' is broken at this
  // point).
  if (impossibleObservationMade) {
    int maxVar = 0;
    for (TableFactor c : cliques) {
      for (int i : c.neighborIndices) if (i > maxVar) maxVar = i;
    }

    double[][] result = new double[maxVar + 1][];
    for (TableFactor c : cliques) {
      for (int i = 0; i < c.neighborIndices.length; i++) {
        result[c.neighborIndices[i]] = new double[c.getDimensions()[i]];
        for (int j = 0; j < result[c.neighborIndices[i]].length; j++) {
          result[c.neighborIndices[i]][j] = 1.0 / result[c.neighborIndices[i]].length;
        }
      }
    }

    // Create a bunch of uniform joint marginals, constrained by observations, and fill up the joint marginals
    // with them
    Map<GraphicalModel.Factor, TableFactor> jointMarginals = new IdentityHashMap<>();
    if (includeJointMarginalsAndPartition) {
      for (GraphicalModel.Factor f : model.factors) {
        TableFactor uniformZero = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());

        for (int[] assignment : uniformZero) {
          uniformZero.setAssignmentValue(assignment, 0.0);
        }

        jointMarginals.put(f, uniformZero);
      }
    }

    return new MarginalResult(result, 1.0, jointMarginals);
  }

  // Find the largest contained variable, so that we can size arrays appropriately
  int maxVar = 0;
  for (GraphicalModel.Factor fac : model.factors) {
    for (int i : fac.neigborIndices) if (i > maxVar) maxVar = i;
  }

  // Indexed by (start-clique, end-clique), this array will remain mostly null in most graphs
  TableFactor[][] messages = new TableFactor[cliques.length][cliques.length];

  // OPTIMIZATION:
  // check if we've only added one factor since the last time we ran marginal inference. If that's the case, we
  // can use the new factor as the root, all the messages passed in from the leaves will not have changed. That
  // means we can cut message passing computation in half.
  boolean[][] backwardPassedMessages = new boolean[cliques.length][cliques.length];
  int forceRootForCachedMessagePassing = -1;
  int[] cachedCliquesBackPointers = null;
  if (CACHE_MESSAGES && (numFactorsCached == cliques.length - 1) && (numFactorsCached > 0)) {
    cachedCliquesBackPointers = new int[cliques.length];

    // Sometimes we'll have cached versions of the factors, but they're from inference steps a long time ago, so we
    // don't get consistent backpointers to our cache of factors. This is a flag to indicate if this happens.
    boolean backPointersConsistent = true;

    // Calculate the correspondence between the old cliques list and the new cliques list
    for (int i = 0; i < cliques.length; i++) {
      cachedCliquesBackPointers[i] = -1;
      for (int j = 0; j < cachedCliqueList.length; j++) {
        if (cliques[i] == cachedCliqueList[j]) {
          cachedCliquesBackPointers[i] = j;
          break;
        }
      }
      if (cachedCliquesBackPointers[i] == -1) {
        if (forceRootForCachedMessagePassing != -1) {
          backPointersConsistent = false;
          break;
        }
        forceRootForCachedMessagePassing = i;
      }
    }
    if (!backPointersConsistent) forceRootForCachedMessagePassing = -1;
  }

  // Create the data structures to hold the tree pattern
  boolean[] visited = new boolean[cliques.length];
  int numVisited = 0;
  int[] visitedOrder = new int[cliques.length];

  int[] parent = new int[cliques.length];
  for (int i = 0; i < parent.length; i++) parent[i] = -1;

  // Figure out which cliques are connected to which trees. This is important for calculating the partition
  // function later, since each tree will converge to its own partition function by multiplication, and we will
  // need to multiply the partition function of each of the trees to get the global one.
  int[] trees = new int[cliques.length];

  // Forward pass, record a BFS forest pattern that we can use for message passing
  int treeIndex = -1;
  boolean[] seenVariable = new boolean[maxVar + 1];
  while (numVisited < cliques.length) {
    treeIndex++;
    // Pick the largest connected graph remaining as the root for message passing
    int root = -1;
    // OPTIMIZATION: if there's a forced root for message passing (a node that we just added) then make it the
    // root
    if (CACHE_MESSAGES && forceRootForCachedMessagePassing != -1 && !visited[forceRootForCachedMessagePassing]) {
      root = forceRootForCachedMessagePassing;
    } else {
      for (int i = 0; i < cliques.length; i++) {
        if (!visited[i] &&
            (root == -1 || cliques[i].neighborIndices.length > cliques[root].neighborIndices.length)) {
          root = i;
        }
      }
    }
    assert (root != -1);

    Queue<Integer> toVisit = new ArrayDeque<>();
    toVisit.add(root);
    boolean[] toVisitArray = new boolean[cliques.length];
    toVisitArray[root] = true;

    while (toVisit.size() > 0) {
      int cursor = toVisit.poll();
      // toVisitArray[cursor] = false;
      trees[cursor] = treeIndex;

      if (visited[cursor]) {
        log.info("Visited contains: " + cursor);
        log.info("Visited: " + Arrays.toString(visited));
        log.info("To visit: " + toVisit);
      }
      assert (!visited[cursor]);
      visited[cursor] = true;
      visitedOrder[numVisited] = cursor;
      for (int i : cliques[cursor].neighborIndices) seenVariable[i] = true;
      numVisited++;

      childLoop:
      for (int i = 0; i < cliques.length; i++) {
        if (i == cursor) continue;
        if (i == parent[cursor]) continue;
        if (domainsOverlap(cliques[cursor], cliques[i])) {
          // Make sure that for every variable that we've already seen somewhere in the graph, if it's
          // in the child, it's in the parent. Otherwise we'll break the property of continuous
          // transmission of information about variables through messages.
          childNeighborLoop:
          for (int child : cliques[i].neighborIndices) {
            if (seenVariable[child]) {
              for (int j : cliques[cursor].neighborIndices) {
                if (j == child) {
                  continue childNeighborLoop;
                }
              }
              // If we get here it means that this clique is not good as a child, since we can't pass
              // it all the information it needs from other elements of the tree
              continue childLoop;
            }
          }
          if (parent[i] == -1 && !visited[i]) {
            if (!toVisitArray[i]) {
              toVisit.add(i);
              toVisitArray[i] = true;
              for (int j : cliques[i].neighborIndices) seenVariable[j] = true;
            }
            parent[i] = cursor;
          }
        }
      }
    }

    // No cycles in the tree
    assert (parent[root] == -1);
  }
  assert (numVisited == cliques.length);

  // Backward pass, run the visited list in reverse
  for (int i = numVisited - 1; i >= 0; i--) {
    int cursor = visitedOrder[i];
    if (parent[cursor] == -1) continue;
    backwardPassedMessages[cursor][parent[cursor]] = true;

    // OPTIMIZATION:
    // if these conditions are met we can avoid calculating the message, and instead retrieve from the cache,
    // since they should be the same
    if (CACHE_MESSAGES
        && forceRootForCachedMessagePassing != -1
        && cachedCliquesBackPointers[cursor] != -1
        && cachedCliquesBackPointers[parent[cursor]] != -1
        && cachedMessages[cachedCliquesBackPointers[cursor]][cachedCliquesBackPointers[parent[cursor]]] != null
        && cachedBackwardPassedMessages[cachedCliquesBackPointers[cursor]][cachedCliquesBackPointers[parent[cursor]]]) {
      messages[cursor][parent[cursor]] =
          cachedMessages[cachedCliquesBackPointers[cursor]][cachedCliquesBackPointers[parent[cursor]]];
    } else {
      // Calculate the message to the clique's parent, given all incoming messages so far
      TableFactor message = cliques[cursor];
      for (int k = 0; k < cliques.length; k++) {
        if (k == parent[cursor]) continue;
        if (messages[k][cursor] != null) {
          message = message.multiply(messages[k][cursor]);
        }
      }
      messages[cursor][parent[cursor]] = marginalizeMessage(message, cliques[parent[cursor]].neighborIndices, marginalize);

      // Invalidate any cached outgoing messages
      if (CACHE_MESSAGES
          && forceRootForCachedMessagePassing != -1
          && cachedCliquesBackPointers[parent[cursor]] != -1) {
        for (int k = 0; k < cachedCliqueList.length; k++) {
          cachedMessages[cachedCliquesBackPointers[parent[cursor]]][k] = null;
        }
      }
    }
  }

  // Forward pass, run the visited list forward
  for (int i = 0; i < numVisited; i++) {
    int cursor = visitedOrder[i];
    for (int j = 0; j < cliques.length; j++) {
      if (parent[j] != cursor) continue;

      TableFactor message = cliques[cursor];
      for (int k = 0; k < cliques.length; k++) {
        if (k == j) continue;
        if (messages[k][cursor] != null) {
          message = message.multiply(messages[k][cursor]);
        }
      }
      messages[cursor][j] = marginalizeMessage(message, cliques[j].neighborIndices, marginalize);
    }
  }

  // OPTIMIZATION:
  // cache the messages, and the current list of cliques
  if (CACHE_MESSAGES) {
    cachedCliqueList = cliques;
    cachedMessages = messages;
    cachedBackwardPassedMessages = backwardPassedMessages;
  }

  // Calculate final marginals for each variable
  double[][] marginals = new double[maxVar + 1][];

  // Include observed variables as deterministic
  for (GraphicalModel.Factor fac : model.factors) {
    for (int i = 0; i < fac.neigborIndices.length; i++) {
      int n = fac.neigborIndices[i];
      if (model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE)) {
        double[] deterministic = new double[fac.featuresTable.getDimensions()[i]];
        int assignment = Integer.parseInt(model.getVariableMetaDataByReference(n).get(VARIABLE_OBSERVED_VALUE));
        // BUGFIX: this check used to be a strict '>', which let assignment == deterministic.length slip through
        // and throw a raw ArrayIndexOutOfBoundsException below instead of this informative exception. Valid
        // assignments are 0 .. deterministic.length - 1.
        if (assignment >= deterministic.length) {
          throw new IllegalStateException("Variable " + n + ": Can't have as assignment (" + assignment + ") that is out of bounds for dimension size (" + deterministic.length + ")");
        }
        deterministic[assignment] = 1.0;
        marginals[n] = deterministic;
      }
    }
  }

  Map<GraphicalModel.Factor, TableFactor> jointMarginals = new IdentityHashMap<>();
  if (marginalize == MarginalizationMethod.SUM && includeJointMarginalsAndPartition) {
    boolean[] partitionIncludesTrees = new boolean[treeIndex + 1];
    double[] treePartitionFunctions = new double[treeIndex + 1];

    for (int i = 0; i < cliques.length; i++) {
      TableFactor convergedClique = cliques[i];
      for (int j = 0; j < cliques.length; j++) {
        if (i == j) continue;
        if (messages[j][i] == null) continue;
        convergedClique = convergedClique.multiply(messages[j][i]);
      }

      // Calculate the partition function when we're calculating marginals
      // We need one contribution per tree in our forest graph
      if (!partitionIncludesTrees[trees[i]]) {
        partitionIncludesTrees[trees[i]] = true;
        treePartitionFunctions[trees[i]] = convergedClique.valueSum();
        partitionFunction *= treePartitionFunctions[trees[i]];
      } else {
        // This is all just an elaborate assert
        // Check that our partition function is the same as the trees we're attached to, or within 0.1%, for numerical reasons.
        // Sometimes the partition function will explode in value, which can make a non-%-based assert worthless here
        if (assertsEnabled() && !TableFactor.USE_EXP_APPROX) {
          double valueSum = convergedClique.valueSum();
          if (Double.isFinite(valueSum) && Double.isFinite(treePartitionFunctions[trees[i]])) {
            if (Math.abs(treePartitionFunctions[trees[i]] - valueSum) >= 1.0e-3 * treePartitionFunctions[trees[i]]) {
              log.info("Different partition functions for tree " + trees[i] + ": ");
              log.info("Pre-existing for tree: " + treePartitionFunctions[trees[i]]);
              log.info("This clique for tree: " + valueSum);
            }
            assert (Math.abs(treePartitionFunctions[trees[i]] - valueSum) < 1.0e-3 * treePartitionFunctions[trees[i]]);
          }
        }
      }

      // Calculate the factor this clique corresponds to, and put in an entry for joint marginals
      GraphicalModel.Factor f = cliqueToFactor.get(i);
      assert (f != null);
      if (!jointMarginals.containsKey(f)) {
        int[] observedAssignments = getObservedAssignments(f);

        // Collect back pointers and check if this factor matches the clique we're using
        int[] backPointers = new int[observedAssignments.length];
        int cursor = 0;
        for (int j = 0; j < observedAssignments.length; j++) {
          if (observedAssignments[j] == -1) {
            backPointers[j] = cursor;
            cursor++;
          }
          // This is not strictly necessary but will trigger array OOB exception if things go wrong, so is nice
          else backPointers[j] = -1;
        }

        double sum = convergedClique.valueSum();

        TableFactor jointMarginal = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());

        // OPTIMIZATION:
        // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
        // we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
        Iterator<int[]> fastPassByReferenceIterator = convergedClique.fastPassByReferenceIterator();
        int[] assignment = fastPassByReferenceIterator.next();
        while (true) {
          if (backPointers.length == assignment.length) {
            jointMarginal.setAssignmentValue(assignment, convergedClique.getAssignmentValue(assignment) / sum);
          } else {
            int[] jointAssignment = new int[backPointers.length];
            for (int j = 0; j < jointAssignment.length; j++) {
              if (observedAssignments[j] != -1) jointAssignment[j] = observedAssignments[j];
              else jointAssignment[j] = assignment[backPointers[j]];
            }
            jointMarginal.setAssignmentValue(jointAssignment, convergedClique.getAssignmentValue(assignment) / sum);
          }

          // Set the assignment arrays correctly
          if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
          else break;
        }

        jointMarginals.put(f, jointMarginal);
      }

      boolean anyNull = false;
      for (int j = 0; j < convergedClique.neighborIndices.length; j++) {
        int k = convergedClique.neighborIndices[j];
        if (marginals[k] == null) {
          anyNull = true;
        }
      }

      if (anyNull) {
        double[][] cliqueMarginals = null;
        switch (marginalize) {
          case SUM:
            cliqueMarginals = convergedClique.getSummedMarginals();
            break;
          case MAX:
            cliqueMarginals = convergedClique.getMaxedMarginals();
            break;
        }
        for (int j = 0; j < convergedClique.neighborIndices.length; j++) {
          int k = convergedClique.neighborIndices[j];
          if (marginals[k] == null) {
            marginals[k] = cliqueMarginals[j];
          }
        }
      }
    }
  }
  // If we don't care about joint marginals, we can be careful about not calculating more cliques than we need to,
  // by explicitly sorting by which cliques are most profitable to calculate over. In this way we can avoid, in
  // the case of a chain CRF, calculating almost half the joint factors.
  else {
    // First do a pass where we only calculate all-null neighbors
    for (int i = 0; i < cliques.length; i++) {
      boolean allNull = true;
      for (int k : cliques[i].neighborIndices) {
        if (marginals[k] != null) allNull = false;
      }

      if (allNull) {
        TableFactor convergedClique = cliques[i];
        for (int j = 0; j < cliques.length; j++) {
          if (i == j) continue;
          if (messages[j][i] == null) continue;
          convergedClique = convergedClique.multiply(messages[j][i]);
        }

        double[][] cliqueMarginals = null;
        switch (marginalize) {
          case SUM:
            cliqueMarginals = convergedClique.getSummedMarginals();
            break;
          case MAX:
            cliqueMarginals = convergedClique.getMaxedMarginals();
            break;
        }
        for (int j = 0; j < convergedClique.neighborIndices.length; j++) {
          int k = convergedClique.neighborIndices[j];
          if (marginals[k] == null) {
            marginals[k] = cliqueMarginals[j];
          }
        }
      }
    }

    // Now we calculate any remaining cliques with any non-null variables
    for (int i = 0; i < cliques.length; i++) {
      boolean anyNull = false;
      for (int j = 0; j < cliques[i].neighborIndices.length; j++) {
        int k = cliques[i].neighborIndices[j];
        if (marginals[k] == null) {
          anyNull = true;
        }
      }

      if (anyNull) {
        TableFactor convergedClique = cliques[i];
        for (int j = 0; j < cliques.length; j++) {
          if (i == j) continue;
          if (messages[j][i] == null) continue;
          convergedClique = convergedClique.multiply(messages[j][i]);
        }

        double[][] cliqueMarginals = null;
        switch (marginalize) {
          case SUM:
            cliqueMarginals = convergedClique.getSummedMarginals();
            break;
          case MAX:
            cliqueMarginals = convergedClique.getMaxedMarginals();
            break;
        }
        for (int j = 0; j < convergedClique.neighborIndices.length; j++) {
          int k = convergedClique.neighborIndices[j];
          if (marginals[k] == null) {
            marginals[k] = cliqueMarginals[j];
          }
        }
      }
    }
  }

  // Add any factors to the joint marginal map that were fully observed and so didn't get cliques
  if (marginalize == MarginalizationMethod.SUM && includeJointMarginalsAndPartition) {
    for (GraphicalModel.Factor f : model.factors) {
      if (!jointMarginals.containsKey(f)) {
        // This implies that every variable in the factor is observed. If that's the case, we need to construct
        // a one hot TableFactor representing the deterministic distribution.
        TableFactor deterministicJointMarginal = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());
        int[] observedAssignment = getObservedAssignments(f);
        for (int i : observedAssignment) assert (i != -1);
        deterministicJointMarginal.setAssignmentValue(observedAssignment, 1.0);
        jointMarginals.put(f, deterministicJointMarginal);
      }
    }
  }

  return new MarginalResult(marginals, partitionFunction, jointMarginals);
}
/**
 * Reads the observed value for each neighbor variable of a factor out of the model metadata.
 *
 * @param f the factor whose neighbors should be checked
 * @return one entry per neighbor: the observed value, or -1 if the variable is unobserved
 */
private int[] getObservedAssignments(GraphicalModel.Factor f) {
  int numNeighbors = f.neigborIndices.length;
  int[] assignments = new int[numNeighbors];
  for (int i = 0; i < numNeighbors; i++) {
    Map<String, String> metadata = model.getVariableMetaDataByReference(f.neigborIndices[i]);
    assignments[i] = metadata.containsKey(VARIABLE_OBSERVED_VALUE)
        ? Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE))
        : -1;
  }
  return assignments;
}
/**
 * This is a key step in message passing. When we are calculating a message, we want to marginalize out all variables
 * not relevant to the recipient of the message. This function does that.
 *
 * @param message the message to marginalize
 * @param relevant the variables that are relevant
 * @param marginalize whether to use sum or max marginalization, for marginal or MAP inference
 * @return the marginalized message
 */
private static TableFactor marginalizeMessage(TableFactor message, int[] relevant, MarginalizationMethod marginalize) {
  TableFactor marginalized = message;
  for (int variable : message.neighborIndices) {
    // Keep variables the recipient cares about; collapse everything else out
    boolean keep = false;
    for (int r : relevant) {
      if (r == variable) {
        keep = true;
        break;
      }
    }
    if (keep) continue;
    marginalized = (marginalize == MarginalizationMethod.SUM)
        ? marginalized.sumOut(variable)
        : marginalized.maxOut(variable);
  }
  return marginalized;
}
/**
 * Quick check for whether two factors share any variable in their domains. Since factor neighbor sets are super
 * small, this n^2 scan is fine.
 *
 * @param f1 first factor to compare
 * @param f2 second factor to compare
 * @return whether their domains overlap
 */
private static boolean domainsOverlap(TableFactor f1, TableFactor f2) {
  for (int i = 0; i < f1.neighborIndices.length; i++) {
    int candidate = f1.neighborIndices[i];
    for (int j = 0; j < f2.neighborIndices.length; j++) {
      if (candidate == f2.neighborIndices[j]) {
        return true;
      }
    }
  }
  return false;
}
@SuppressWarnings({"*", "AssertWithSideEffects", "ConstantConditions", "UnusedAssignment"})
private static boolean assertsEnabled() {
  // The assert below only executes when -ea is on, so its intentional side effect flips the flag to true
  // exactly when assertions are enabled.
  boolean enabled = false;
  assert (enabled = true);
  return enabled;
}
}