/**
* Copyright (c) 2011 Michael Kutschke. All rights reserved. This program and the accompanying
* materials are made available under the terms of the Eclipse Public License v1.0 which accompanies
* this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html Contributors:
* Michael Kutschke - initial API and implementation.
*/
package org.eclipse.recommenders.jayes.inference.jtree;
import static org.eclipse.recommenders.jayes.util.Pair.newPair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.eclipse.recommenders.internal.jayes.util.ArrayUtils;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.factor.AbstractFactor;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.factor.arraywrapper.IArrayWrapper;
import org.eclipse.recommenders.jayes.inference.AbstractInferer;
import org.eclipse.recommenders.jayes.util.Graph;
import org.eclipse.recommenders.jayes.util.MathUtils;
import org.eclipse.recommenders.jayes.util.NumericalInstabilityException;
import org.eclipse.recommenders.jayes.util.OrderIgnoringPair;
import org.eclipse.recommenders.jayes.util.Pair;
import org.eclipse.recommenders.jayes.util.sharing.CanonicalArrayWrapperManager;
import org.eclipse.recommenders.jayes.util.sharing.CanonicalIntArrayManager;
import org.eclipse.recommenders.jayes.util.triangulation.MinFillIn;
@SuppressWarnings("deprecation")
public class JunctionTreeAlgorithm extends AbstractInferer {

    // multiplicative identity in normal scale
    private static final double ONE = 1.0;
    // multiplicative identity in log scale (log(1.0) == 0.0)
    private static final double ONE_LOG = 0.0;

    // separator factor for each junction tree edge; the key ignores edge direction
    protected Map<OrderIgnoringPair<Integer>, AbstractFactor> sepSets;
    // graph structure of the junction tree; vertices are cluster ids
    protected Graph junctionTree;
    // clique potential of each cluster, indexed by cluster id
    protected AbstractFactor[] nodePotentials;
    // prepared index mappings used during message passing, keyed by ordered cluster pairs
    protected Map<Pair<Integer, Integer>, int[]> preparedMultiplications;
    // mapping from variables to clusters that contain them
    protected int[][] concernedClusters;
    // per variable, the cluster potential used to answer queries for it
    // (setQueryFactors picks the one with the smallest value table)
    protected AbstractFactor[] queryFactors;
    // per variable, prepared mapping to extract its marginal from its query factor
    protected int[][] preparedQueries;
    // per variable, whether beliefs[variable] is currently up to date
    protected boolean[] isBeliefValid;
    // initial factor values, replayed before each propagation pass
    protected List<Pair<AbstractFactor, IArrayWrapper>> initializations;
    // per cluster, the variables for which it is the designated query factor
    protected int[][] queryFactorReverseMapping;
    // used for computing evidence collection skip
    protected Set<Integer> clustersHavingEvidence;
    // per variable, whether it is observed under the current evidence
    protected boolean[] isObserved;
    // temporary buffer sized to the largest separator, used during message passing
    protected double[] scratchpad;
    // strategy used by setNetwork to construct the junction tree
    protected JunctionTreeBuilder junctionTreeBuilder = JunctionTreeBuilder.forHeuristic(new MinFillIn());
public void setJunctionTreeBuilder(JunctionTreeBuilder bldr) {
this.junctionTreeBuilder = bldr;
}
    @Override
    public double[] getBeliefs(final BayesNode node) {
        // lazily propagate evidence; beliefsValid is a field inherited from
        // AbstractInferer (presumably cleared when evidence changes — not visible here)
        if (!beliefsValid) {
            beliefsValid = true;
            updateBeliefs();
        }
        final int nodeId = node.getId();
        // per-node lazy extraction: the marginal is only read out of the query
        // factor when this node's beliefs are actually requested
        if (!isBeliefValid[nodeId]) {
            isBeliefValid[nodeId] = true;
            if (!evidence.containsKey(node)) {
                validateBelief(nodeId);
            } else {
                // observed node: belief is a point mass on the observed outcome
                Arrays.fill(beliefs[nodeId], 0);
                beliefs[nodeId][node.getOutcomeIndex(evidence.get(node))] = 1;
            }
        }
        return super.getBeliefs(node);
    }
    /**
     * Recomputes beliefs[nodeId] by marginalizing the node's query factor down to the node's
     * variable and normalizing the result.
     */
    private void validateBelief(final int nodeId) {
        final AbstractFactor f = queryFactors[nodeId];
        // TODO change beliefs to ArrayWrappers
        f.sumPrepared(new DoubleArrayWrapper(beliefs[nodeId]), preparedQueries[nodeId]);
        if (f.isLogScale()) {
            // query factor computed in log scale: convert the marginal back to normal scale
            MathUtils.exp(beliefs[nodeId]);
        }
        try {
            beliefs[nodeId] = MathUtils.normalize(beliefs[nodeId]);
        } catch (final IllegalArgumentException exception) {
            // normalize rejects the marginal (e.g. it collapsed numerically); surface
            // this as a dedicated exception with a hint at the log-scale remedy
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + evidence
                    + " and node : " + nodeId
                    + ", consider using logarithmic scale computation (configurable in FactorFactory)", exception);
        }
    }
@Override
protected void updateBeliefs() {
Arrays.fill(isBeliefValid, false);
doUpdateBeliefs();
}
    private void doUpdateBeliefs() {
        // 1. apply the current evidence as selections on all affected potentials
        incorporateAllEvidence();
        int propagationRoot = findPropagationRoot();
        // 2. restore the initial factor values (the selections set above stay in place)
        replayFactorInitializations();
        // 3. two-phase message passing from/to the root, skipping subtrees that
        //    provably cannot contribute (see skipCollection/skipDistribution)
        collectEvidence(propagationRoot, skipCollection(propagationRoot));
        distributeEvidence(propagationRoot, skipDistribution(propagationRoot));
    }
private void replayFactorInitializations() {
for (final Pair<AbstractFactor, IArrayWrapper> init : initializations) {
init.getFirst().copyValues(init.getSecond());
}
}
private void incorporateAllEvidence() {
for (Pair<AbstractFactor, IArrayWrapper> init : initializations) {
init.getFirst().resetSelections();
}
clustersHavingEvidence.clear();
Arrays.fill(isObserved, false);
for (BayesNode n : evidence.keySet()) {
incorporateEvidence(n);
}
}
private void incorporateEvidence(final BayesNode node) {
int n = node.getId();
isObserved[n] = true;
// get evidence to all concerned factors (includes home cluster)
for (final Integer concernedCluster : concernedClusters[n]) {
nodePotentials[concernedCluster].select(n, node.getOutcomeIndex(evidence.get(node)));
clustersHavingEvidence.add(concernedCluster);
}
}
private int findPropagationRoot() {
int propagationRoot = 0;
for (BayesNode n : evidence.keySet()) {
propagationRoot = concernedClusters[n.getId()][0];
}
return propagationRoot;
}
/**
* checks which nodes need not be processed during collectEvidence (because of preprocessing). These are those nodes
* without evidence which are leaves or which only have non-evidence descendants
*
* @param root
* the node to start the check from
* @return a set of the nodes not needing a call of collectEvidence
*/
private Set<Integer> skipCollection(final int root) {
final Set<Integer> skipped = new HashSet<Integer>(nodePotentials.length);
recursiveSkipCollection(root, new HashSet<Integer>(nodePotentials.length), skipped);
return skipped;
}
private void recursiveSkipCollection(final int node, final Set<Integer> visited, final Set<Integer> skipped) {
visited.add(node);
boolean areAllDescendantsSkipped = true;
for (final int neighbor : junctionTree.getNeighbors(node)) {
if (!visited.contains(neighbor)) {
recursiveSkipCollection(neighbor, visited, skipped);
if (!skipped.contains(neighbor)) {
areAllDescendantsSkipped = false;
}
}
}
if (areAllDescendantsSkipped && !clustersHavingEvidence.contains(node)) {
skipped.add(node);
}
}
/**
* checks which nodes do not need to be visited during evidence distribution. These are exactly those nodes which
* are
* <ul>
* <li>not the query factor of a non-evidence variable</li>
* <li>AND have no descendants that cannot be skipped</li>
* </ul>
*
* @param distNode
* @return
*/
private Set<Integer> skipDistribution(final int distNode) {
final Set<Integer> skipped = new HashSet<Integer>(nodePotentials.length);
recursiveSkipDistribution(distNode, new HashSet<Integer>(nodePotentials.length), skipped);
return skipped;
}
private void recursiveSkipDistribution(final int node, final Set<Integer> visited, final Set<Integer> skipped) {
visited.add(node);
boolean areAllDescendantsSkipped = true;
for (final Integer neighbor : junctionTree.getNeighbors(node)) {
if (!visited.contains(neighbor)) {
recursiveSkipDistribution(neighbor, visited, skipped);
if (!skipped.contains(neighbor)) {
areAllDescendantsSkipped = false;
}
}
}
if (areAllDescendantsSkipped && !isQueryFactorOfUnobservedVariable(node)) {
skipped.add(node);
}
}
private boolean isQueryFactorOfUnobservedVariable(final int node) {
for (int i : queryFactorReverseMapping[node]) {
if (!isObserved[i]) {
return true;
}
}
return false;
}
private void collectEvidence(final int cluster, final Set<Integer> marked) {
marked.add(cluster);
for (final int n : junctionTree.getNeighbors(cluster)) {
if (!marked.contains(n)) {
collectEvidence(n, marked);
messagePass(n, cluster);
}
}
}
private void distributeEvidence(final int cluster, final Set<Integer> marked) {
marked.add(cluster);
for (final int n : junctionTree.getNeighbors(cluster)) {
if (!marked.contains(n)) {
messagePass(cluster, n);
distributeEvidence(n, marked);
}
}
}
    /**
     * Passes a message from cluster v1 to cluster v2 over their separator: the separator is
     * updated to the marginal of v1's potential, and v2's potential is multiplied with the
     * ratio of new to old separator values (difference in log scale).
     */
    private void messagePass(final int v1, int v2) {
        OrderIgnoringPair<Integer> sepSetEdge = new OrderIgnoringPair<Integer>(v1, v2);
        final AbstractFactor sepSet = sepSets.get(sepSetEdge);
        if (!needMessagePass(sepSet)) {
            return;
        }
        // keep a copy of the separator's previous values in the scratchpad; the
        // update factor is new/old (or new - old when both ends are log scale)
        final IArrayWrapper newSepValues = sepSet.getValues();
        System.arraycopy(newSepValues.toDoubleArray(), 0, scratchpad, 0, newSepValues.length());
        // marginalize the sending cluster's potential onto the separator variables.
        // NOTE(review): this relies on OrderIgnoringPair.getFirst() returning the
        // constructor-order first element (v1), matching the (v2, v1) key built in
        // prepareSepsetMultiplications — confirm in OrderIgnoringPair
        final int[] preparedOp = preparedMultiplications.get(Pair.newPair(v2, v1));
        nodePotentials[sepSetEdge.getFirst()].sumPrepared(newSepValues, preparedOp);
        if (isOnlyFirstLogScale(sepSetEdge)) {
            // sender is log scale but receiver is not: convert the marginal to normal scale
            MathUtils.exp(newSepValues);
        }
        if (areBothEndsLogScale(sepSetEdge)) {
            // log scale: division becomes subtraction
            MathUtils.secureSubtract(newSepValues.toDoubleArray(), scratchpad, scratchpad);
        } else {
            MathUtils.secureDivide(newSepValues.toDoubleArray(), scratchpad, scratchpad);
        }
        if (isOnlySecondLogScale(sepSetEdge)) {
            // receiver is log scale: logarithmize the update factor before multiplying
            MathUtils.log(scratchpad);
        }
        // TODO scratchpad -> ArrayWrapper
        nodePotentials[sepSetEdge.getSecond()].multiplyPrepared(new DoubleArrayWrapper(scratchpad),
                preparedMultiplications.get(Pair.newPair(v1, v2)));
    }
/*
* we don't get additional information if all variables in the sepSet are observed, so skip message pass
*/
private boolean needMessagePass(final AbstractFactor sepSet) {
for (final int var : sepSet.getDimensionIDs()) {
if (!isObserved[var]) {
return true;
}
}
return false;
}
private boolean isOnlyFirstLogScale(final OrderIgnoringPair<Integer> edge) {
return nodePotentials[edge.getFirst()].isLogScale() && !nodePotentials[edge.getSecond()].isLogScale();
}
private boolean isOnlySecondLogScale(final OrderIgnoringPair<Integer> edge) {
return !nodePotentials[edge.getFirst()].isLogScale() && nodePotentials[edge.getSecond()].isLogScale();
}
    /**
     * Compiles the given network into the junction tree runtime structures. The call order of
     * the steps below is significant: clusters must exist before separators, query factors and
     * prepared mappings are derived from them, and the initial propagation runs last so that
     * storePotentialValues snapshots fully initialized factors.
     */
    @Override
    public void setNetwork(final BayesNet net) {
        super.setNetwork(net);
        initializeFields(net.getNodes().size());
        JunctionTree jtree = buildJunctionTree(net);
        // each node's CPT is assigned a "home" cluster that covers all its dimensions
        Map<AbstractFactor, Integer> homeClusters = computeHomeClusters(net, jtree.getClusters());
        initializeClusterFactors(net, jtree.getClusters(), homeClusters);
        initializeSepsetFactors(jtree.getSepSets());
        determineConcernedClusters();
        setQueryFactors();
        initializePotentialValues();
        multiplyCPTsIntoPotentials(net, homeClusters);
        prepareMultiplications();
        prepareScratch();
        invokeInitialBeliefUpdate();
        // snapshot the propagated potentials for replay on later evidence changes
        storePotentialValues();
    }
@SuppressWarnings("unchecked")
private void determineConcernedClusters() {
concernedClusters = new int[queryFactors.length][];
List<Integer>[] temp = new List[concernedClusters.length];
for (int i = 0; i < temp.length; i++) {
temp[i] = new ArrayList<Integer>();
}
for (int i = 0; i < nodePotentials.length; i++) {
int[] dimensionIDs = nodePotentials[i].getDimensionIDs();
for (final int var : dimensionIDs) {
temp[var].add(i);
}
}
for (int i = 0; i < temp.length; i++) {
concernedClusters[i] = ArrayUtils.toIntArray(temp[i]);
}
}
private void initializeFields(int numNodes) {
isBeliefValid = new boolean[beliefs.length];
Arrays.fill(isBeliefValid, false);
queryFactors = new AbstractFactor[numNodes];
preparedQueries = new int[numNodes][];
sepSets = new HashMap<OrderIgnoringPair<Integer>, AbstractFactor>(numNodes);
preparedMultiplications = new HashMap<Pair<Integer, Integer>, int[]>(numNodes);
initializations = new ArrayList<Pair<AbstractFactor, IArrayWrapper>>();
clustersHavingEvidence = new HashSet<Integer>(numNodes);
isObserved = new boolean[numNodes];
}
private JunctionTree buildJunctionTree(BayesNet net) {
final JunctionTree jtree = junctionTreeBuilder.buildJunctionTree(net);
this.junctionTree = jtree.getGraph();
return jtree;
}
private Map<AbstractFactor, Integer> computeHomeClusters(BayesNet net, final List<List<Integer>> clusters) {
Map<AbstractFactor, Integer> homeClusters = new HashMap<AbstractFactor, Integer>();
for (final BayesNode node : net.getNodes()) {
final int[] nodeAndParents = node.getFactor().getDimensionIDs();
for (final ListIterator<List<Integer>> clusterIt = clusters.listIterator(); clusterIt.hasNext();) {
if (containsAll(clusterIt.next(), nodeAndParents)) {
homeClusters.put(node.getFactor(), clusterIt.nextIndex() - 1);
break;
}
}
}
return homeClusters;
}
private boolean containsAll(List<Integer> list, int[] ints) {
for (int n : ints) {
if (!list.contains(n)) {
return false;
}
}
return true;
}
private void initializeClusterFactors(BayesNet net, final List<List<Integer>> clusters,
Map<AbstractFactor, Integer> homeClusters) {
nodePotentials = new AbstractFactor[clusters.size()];
Map<Integer, List<AbstractFactor>> multiplicationPartners = findMultiplicationPartners(net, homeClusters);
for (final ListIterator<List<Integer>> cliqueIt = clusters.listIterator(); cliqueIt.hasNext();) {
final List<Integer> cluster = cliqueIt.next();
int current = cliqueIt.nextIndex() - 1;
List<AbstractFactor> multiplicationPartnerList = multiplicationPartners.get(current);
final AbstractFactor cliqueFactor = factory.create(cluster,
multiplicationPartnerList == null ? Collections.<AbstractFactor>emptyList()
: multiplicationPartnerList);
nodePotentials[current] = cliqueFactor;
}
}
private Map<Integer, List<AbstractFactor>> findMultiplicationPartners(BayesNet net,
Map<AbstractFactor, Integer> homeClusters) {
Map<Integer, List<AbstractFactor>> potentialMap = new HashMap<Integer, List<AbstractFactor>>();
for (final BayesNode node : net.getNodes()) {
final Integer nodeHome = homeClusters.get(node.getFactor());
if (!potentialMap.containsKey(nodeHome)) {
potentialMap.put(nodeHome, new ArrayList<AbstractFactor>());
}
potentialMap.get(nodeHome).add(node.getFactor());
}
return potentialMap;
}
private void initializeSepsetFactors(final List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> sepSets) {
for (final Pair<OrderIgnoringPair<Integer>, List<Integer>> sep : sepSets) {
this.sepSets.put(sep.getFirst(), factory.create(sep.getSecond(), Collections.<AbstractFactor>emptyList()));
}
}
private void setQueryFactors() {
for (int i = 0; i < queryFactors.length; i++) {
for (final Integer f : concernedClusters[i]) {
final boolean isFirstOrSmallerTable = queryFactors[i] == null
|| queryFactors[i].getValues().length() > nodePotentials[f].getValues().length();
if (isFirstOrSmallerTable) {
queryFactors[i] = nodePotentials[f];
}
}
}
queryFactorReverseMapping = new int[nodePotentials.length][];
for (int i = 0; i < nodePotentials.length; i++) {
List<Integer> queryVars = new ArrayList<Integer>();
for (int var : nodePotentials[i].getDimensionIDs()) {
if (queryFactors[var] == nodePotentials[i]) {
queryVars.add(var);
}
}
queryFactorReverseMapping[i] = ArrayUtils.toIntArray(queryVars);
}
}
private void prepareMultiplications() {
// compress by combining equal prepared statements, thus saving memory
final CanonicalIntArrayManager flyWeight = new CanonicalIntArrayManager();
prepareSepsetMultiplications(flyWeight);
prepareQueries(flyWeight);
}
private void prepareSepsetMultiplications(final CanonicalIntArrayManager flyWeight) {
for (int node = 0; node < nodePotentials.length; node++) {
for (final int n : junctionTree.getNeighbors(node)) {
final int[] preparedMultiplication = nodePotentials[n].prepareMultiplication(sepSets
.get(new OrderIgnoringPair<Integer>(node, n)));
preparedMultiplications.put(Pair.newPair(node, n), flyWeight.getInstance(preparedMultiplication));
}
}
}
private void prepareQueries(final CanonicalIntArrayManager flyWeight) {
for (int i = 0; i < queryFactors.length; i++) {
final AbstractFactor beliefFactor = factory.create(Arrays.asList(i),
Collections.<AbstractFactor>emptyList());
final int[] preparedQuery = queryFactors[i].prepareMultiplication(beliefFactor);
preparedQueries[i] = flyWeight.getInstance(preparedQuery);
}
}
private void prepareScratch() {
int maxSize = 0;
for (AbstractFactor sepSet : sepSets.values()) {
maxSize = Math.max(maxSize, sepSet.getValues().length());
}
scratchpad = new double[maxSize];
}
private void invokeInitialBeliefUpdate() {
collectEvidence(0, new HashSet<Integer>());
distributeEvidence(0, new HashSet<Integer>());
}
private void initializePotentialValues() {
for (final AbstractFactor f : nodePotentials) {
f.fill(f.isLogScale() ? ONE_LOG : ONE);
}
for (final Entry<OrderIgnoringPair<Integer>, AbstractFactor> sepSet : sepSets.entrySet()) {
if (!areBothEndsLogScale(sepSet.getKey())) {
// if one part is log-scale, we transform to non-log-scale
sepSet.getValue().fill(ONE);
} else {
sepSet.getValue().fill(ONE_LOG);
}
}
}
private void multiplyCPTsIntoPotentials(BayesNet net, Map<AbstractFactor, Integer> homeClusters) {
for (final BayesNode node : net.getNodes()) {
final AbstractFactor nodeHome = nodePotentials[homeClusters.get(node.getFactor())];
if (nodeHome.isLogScale()) {
nodeHome.multiplyCompatibleToLog(node.getFactor());
} else {
nodeHome.multiplyCompatible(node.getFactor());
}
}
}
private boolean areBothEndsLogScale(final OrderIgnoringPair<Integer> edge) {
return nodePotentials[edge.getFirst()].isLogScale() && nodePotentials[edge.getSecond()].isLogScale();
}
private void storePotentialValues() {
CanonicalArrayWrapperManager flyweight = new CanonicalArrayWrapperManager();
for (final AbstractFactor pot : nodePotentials) {
initializations.add(newPair(pot, flyweight.getInstance(pot.getValues().clone())));
}
for (final AbstractFactor sep : sepSets.values()) {
initializations.add(newPair(sep, flyweight.getInstance(sep.getValues().clone())));
}
}
}