/** * Copyright (c) 2011 Michael Kutschke. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * Michael Kutschke - initial API and implementation. */ package org.eclipse.recommenders.jayes.inference.junctionTree; import static org.eclipse.recommenders.jayes.util.Pair.newPair; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.ListIterator; import org.eclipse.recommenders.internal.jayes.util.UnionFind; import org.eclipse.recommenders.jayes.BayesNet; import org.eclipse.recommenders.jayes.BayesNode; import org.eclipse.recommenders.jayes.util.Graph; import org.eclipse.recommenders.jayes.util.Graph.Edge; import org.eclipse.recommenders.jayes.util.Pair; import org.eclipse.recommenders.jayes.util.triangulation.GraphElimination; import org.eclipse.recommenders.jayes.util.triangulation.IEliminationHeuristic; public class JunctionTreeBuilder { private IEliminationHeuristic heuristic; public static JunctionTreeBuilder forHeuristic(IEliminationHeuristic heuristic) { return new JunctionTreeBuilder(heuristic); } protected JunctionTreeBuilder(IEliminationHeuristic heuristic) { this.heuristic = heuristic; } public JunctionTree buildJunctionTree(BayesNet net) { JunctionTree junctionTree = new JunctionTree(new Graph()); junctionTree.setClusters(triangulateGraphAndFindCliques(buildMoralGraph(net), weightNodesByOutcomes(net), heuristic)); junctionTree.setSepSets(computeSepsets(junctionTree, net)); return junctionTree; } private Graph buildMoralGraph(BayesNet net) { Graph moral = new Graph(); moral.initialize(net.getNodes().size()); for (final BayesNode node : net.getNodes()) { addMoralEdges(moral, node); } return moral; } private void addMoralEdges(Graph moral, final BayesNode node) { final ListIterator<BayesNode> it = node.getParents().listIterator(); while (it.hasNext()) { final BayesNode parent = it.next(); final ListIterator<BayesNode> remainingParentsIt = node.getParents().listIterator(it.nextIndex()); while (remainingParentsIt.hasNext()) { // connect parents final BayesNode otherParent = remainingParentsIt.next(); moral.addEdge(parent.getId(), otherParent.getId()); } moral.addEdge(node.getId(), parent.getId()); } } private List<List<Integer>> triangulateGraphAndFindCliques(Graph graph, double[] weights, IEliminationHeuristic eliminationHeuristic) { GraphElimination triangulate = new GraphElimination(graph, weights, eliminationHeuristic); final List<List<Integer>> cliques = new ArrayList<List<Integer>>(); for (List<Integer> nextClique : triangulate) { if (!containsSuperset(cliques, nextClique)) { cliques.add(nextClique); } } return cliques; } private double[] weightNodesByOutcomes(BayesNet net) { double[] weights = new double[net.getNodes().size()]; for (BayesNode node : net.getNodes()) { weights[node.getId()] = Math.log(node.getOutcomeCount()); // using these weights is the same as minimizing the resulting cluster factor size // which is given by the product of the variable outcome counts. } return weights; } private boolean containsSuperset(final Collection<? extends Collection<Integer>> sets, final Collection<Integer> set) { boolean isSubsetOfOther = false; for (final Collection<Integer> superset : sets) { if (superset.containsAll(set)) { isSubsetOfOther = true; break; } } return isSubsetOfOther; } private List<Pair<Edge, List<Integer>>> computeSepsets(JunctionTree junctionTree, BayesNet net) { final List<Pair<Edge, List<Integer>>> candidates = enumerateCandidateSepSets(junctionTree.getClusters()); Collections.sort(candidates, new SepsetComparator(net)); return computeMaxSpanningTree(junctionTree.getGraph(), candidates); } private List<Pair<Edge, List<Integer>>> enumerateCandidateSepSets(List<List<Integer>> clusters) { final List<Pair<Edge, List<Integer>>> sepSets = new ArrayList<Pair<Edge, List<Integer>>>(); final ListIterator<List<Integer>> it = clusters.listIterator(); while (it.hasNext()) { final List<Integer> clique1 = it.next(); final ListIterator<List<Integer>> remainingIt = clusters.listIterator(it.nextIndex()); while (remainingIt.hasNext()) { // generate sepSets final List<Integer> clique2 = new ArrayList<Integer>(remainingIt.next()); clique2.retainAll(clique1); sepSets.add(newPair(new Edge(it.nextIndex() - 1, remainingIt.nextIndex() - 1), clique2)); } } return sepSets; } private List<Pair<Edge, List<Integer>>> computeMaxSpanningTree(Graph graph, final List<Pair<Edge, List<Integer>>> sortedCandidateSepSets) { final ArrayDeque<Pair<Edge, List<Integer>>> pq = new ArrayDeque<Pair<Edge, List<Integer>>>( sortedCandidateSepSets); final int vertexCount = graph.getAdjacency().size(); final UnionFind[] sets = UnionFind.createArray(vertexCount); final List<Pair<Edge, List<Integer>>> leftSepSets = new ArrayList<Pair<Edge, List<Integer>>>(); while (leftSepSets.size() < (vertexCount - 1)) { final Pair<Edge, List<Integer>> sep = pq.poll(); final boolean bothEndsInSameTree = sets[sep.getFirst().getFirst()].find() == sets[sep.getFirst() .getSecond()].find(); if (!bothEndsInSameTree) { sets[sep.getFirst().getFirst()].merge(sets[sep.getFirst().getSecond()]); leftSepSets.add(sep); graph.addEdge(sep.getFirst().getFirst(), sep.getFirst().getSecond()); } } return leftSepSets; } private final class SepsetComparator implements Comparator<Pair<Edge, List<Integer>>> { private final BayesNet net; public SepsetComparator(BayesNet net) { this.net = net; } // heuristic: choose sepSet with most variables first, // if equal, choose the on with least table size @Override public int compare(final Pair<Edge, List<Integer>> sepSet1, final Pair<Edge, List<Integer>> sepSet2) { final int compareNumberOfVariables = compare(sepSet1.getSecond().size(), sepSet2.getSecond().size()); if (compareNumberOfVariables != 0) { return -compareNumberOfVariables; } final int tableSize1 = getTableSize(sepSet1.getSecond()); final int tableSize2 = getTableSize(sepSet2.getSecond()); return compare(tableSize1, tableSize2); } private int getTableSize(final List<Integer> cluster) { int tableSize = 1; for (final int id : cluster) { tableSize *= net.getNode(id).getOutcomeCount(); } return tableSize; } private int compare(final int i1, final int i2) { return i1 - i2; } } }