/*******************************************************************************
* Copyright (c) 2013 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation
******************************************************************************/
package org.eclipse.recommenders.jayes.inference.jtree;
import static org.eclipse.recommenders.jayes.util.Pair.newPair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import org.eclipse.recommenders.internal.jayes.util.UnionFind;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.util.Graph;
import org.eclipse.recommenders.jayes.util.OrderIgnoringPair;
import org.eclipse.recommenders.jayes.util.Pair;
/**
 * Computes the separator sets ("sepsets") of a junction tree and links its clusters into a tree.
 * <p>
 * For every unordered pair of clusters a candidate sepset (the intersection of the two clusters) is
 * built. The candidates are ordered by a heuristic (more shared variables first, then smaller table
 * size), and a maximum spanning tree over the clusters is then selected greedily using a union-find
 * structure (Kruskal's algorithm).
 */
public class SepsetComputer {

    /**
     * Computes the sepsets that connect the clusters of the given junction tree into a tree.
     * <p>
     * Side effect: for every selected sepset an edge between the two cluster indices is added to
     * {@code junctionTree.getGraph()}.
     *
     * @param junctionTree
     *            the junction tree whose clusters are connected; its graph is modified
     * @param net
     *            the network the variable ids inside the clusters refer to; used to look up outcome
     *            counts for the table-size tie-break
     * @return the selected sepsets, each paired with the unordered pair of cluster indices it connects
     * @throws IllegalStateException
     *             if the candidate sepsets are exhausted before every vertex of the graph is connected
     */
    public List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> computeSepsets(JunctionTree junctionTree, BayesNet net) {
        final List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> candidates = enumerateCandidateSepSets(junctionTree
                .getClusters());
        Collections.sort(candidates, new SepsetComparator(net));
        return computeMaxSpanningTree(junctionTree.getGraph(), candidates);
    }

    /**
     * Builds one candidate sepset for every unordered pair of distinct clusters.
     *
     * @param clusters
     *            the clusters, each a list of variable ids; not modified
     * @return a mutable list pairing each cluster index pair {@code (i, j)} with {@code i < j} with the
     *         intersection of the two clusters (possibly empty), whose elements keep the order of
     *         cluster {@code j}
     */
    private List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> enumerateCandidateSepSets(List<List<Integer>> clusters) {
        final List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> sepSets = new ArrayList<Pair<OrderIgnoringPair<Integer>, List<Integer>>>();
        final ListIterator<List<Integer>> it = clusters.listIterator();
        while (it.hasNext()) {
            final int index1 = it.nextIndex();
            final List<Integer> clique1 = it.next();
            final ListIterator<List<Integer>> remainingIt = clusters.listIterator(index1 + 1);
            while (remainingIt.hasNext()) {
                final int index2 = remainingIt.nextIndex();
                // work on a copy so retainAll does not modify the cluster itself
                final List<Integer> intersection = new ArrayList<Integer>(remainingIt.next());
                intersection.retainAll(clique1);
                sepSets.add(newPair(new OrderIgnoringPair<Integer>(index1, index2), intersection));
            }
        }
        return sepSets;
    }

    /**
     * Greedily selects {@code graph.numberOfVertices() - 1} candidates that form a spanning tree,
     * taking candidates in the given order and skipping any whose endpoints are already connected.
     * Each selected candidate is also added as an edge to {@code graph}.
     *
     * @param graph
     *            the graph whose vertices are the cluster indices; edges are added to it
     * @param sortedCandidateSepSets
     *            the candidates, best first
     * @return the selected sepsets, in selection order
     * @throws IllegalStateException
     *             if the candidates run out before the tree spans all vertices
     */
    private List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> computeMaxSpanningTree(Graph graph,
            final List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> sortedCandidateSepSets) {
        final ArrayDeque<Pair<OrderIgnoringPair<Integer>, List<Integer>>> pq = new ArrayDeque<Pair<OrderIgnoringPair<Integer>, List<Integer>>>(
                sortedCandidateSepSets);
        final int vertexCount = graph.numberOfVertices();
        final UnionFind[] sets = UnionFind.createArray(vertexCount);
        final List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> leftSepSets = new ArrayList<Pair<OrderIgnoringPair<Integer>, List<Integer>>>();
        while (leftSepSets.size() < vertexCount - 1) {
            final Pair<OrderIgnoringPair<Integer>, List<Integer>> sep = pq.poll();
            if (sep == null) {
                // presumably the graph has more vertices than there are clusters; fail loudly
                throw new IllegalStateException("Candidate sepsets exhausted after selecting " + leftSepSets.size()
                        + " of " + (vertexCount - 1) + " required edges");
            }
            final int vertex1 = sep.getFirst().getFirst();
            final int vertex2 = sep.getFirst().getSecond();
            // skip candidates that would close a cycle
            final boolean bothEndsInSameTree = sets[vertex1].find() == sets[vertex2].find();
            if (!bothEndsInSameTree) {
                sets[vertex1].merge(sets[vertex2]);
                leftSepSets.add(sep);
                graph.addEdge(vertex1, vertex2);
            }
        }
        return leftSepSets;
    }

    /**
     * Orders candidate sepsets best first for the spanning-tree construction.
     */
    private static final class SepsetComparator implements Comparator<Pair<OrderIgnoringPair<Integer>, List<Integer>>> {
        private final BayesNet net;

        public SepsetComparator(BayesNet net) {
            this.net = net;
        }

        // heuristic: choose the sepSet with the most variables first;
        // if equal, choose the one with the smallest table size
        @Override
        public int compare(final Pair<OrderIgnoringPair<Integer>, List<Integer>> sepSet1,
                final Pair<OrderIgnoringPair<Integer>, List<Integer>> sepSet2) {
            // arguments swapped: larger sepsets sort first
            final int compareNumberOfVariables = compare(sepSet2.getSecond().size(), sepSet1.getSecond().size());
            if (compareNumberOfVariables != 0) {
                return compareNumberOfVariables;
            }
            final long tableSize1 = getTableSize(sepSet1.getSecond());
            final long tableSize2 = getTableSize(sepSet2.getSecond());
            return compare(tableSize1, tableSize2);
        }

        /**
         * Returns the product of the outcome counts of the given variables. The product is
         * accumulated in a {@code long} so that large sepsets do not silently overflow.
         */
        private long getTableSize(final List<Integer> cluster) {
            long tableSize = 1;
            for (final int id : cluster) {
                tableSize *= net.getNode(id).getOutcomeCount();
            }
            return tableSize;
        }

        // overflow-safe three-way comparison; subtracting the operands can overflow and would
        // break the Comparator contract
        private static int compare(final long l1, final long l2) {
            return l1 < l2 ? -1 : (l1 == l2 ? 0 : 1);
        }
    }
}