/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.beliefs.bayes;
import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;
import org.junit.Before;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.drools.beliefs.bayes.GraphTest.addNode;
import static org.drools.beliefs.bayes.GraphTest.bitSet;
import static org.junit.Assert.assertEquals;
/**
* This class tests that the iteration order for collect and distribute evidence is correct.
* It tests from 4 different positions on the same network. First individually for collect and then distribute, and then through globalUpdate request
* Then it calls the globalUpdate that recurses the network and calls globalUpdate for each clique
*/
public class GlobalUpdateTest {
Graph<BayesVariable> graph = new BayesNetwork();
GraphNode x0 = addNode(graph);
// 0
// |
// 3_2__1
// | |
// 4 5
// |
// 7__6__8
JunctionTreeClique n0 = new JunctionTreeClique(0, graph, bitSet("1"));
JunctionTreeClique n1 = new JunctionTreeClique(1, graph, bitSet("1"));
JunctionTreeClique n2 = new JunctionTreeClique(2, graph, bitSet("1"));
JunctionTreeClique n3 = new JunctionTreeClique(3, graph, bitSet("1"));
JunctionTreeClique n4 = new JunctionTreeClique(4, graph, bitSet("1"));
JunctionTreeClique n5 = new JunctionTreeClique(5, graph, bitSet("1"));
JunctionTreeClique n6 = new JunctionTreeClique(6, graph, bitSet("1"));
JunctionTreeClique n7 = new JunctionTreeClique(7, graph, bitSet("1"));
JunctionTreeClique n8 = new JunctionTreeClique(8, graph, bitSet("1"));
JunctionTree tree;
BayesInstance bayesInstance;
final List<String> messageResults = new ArrayList<String>();
final List<String> globalUpdateResults = new ArrayList<String>();
@Before
public void startUp() {
int i = 0;
List<JunctionTreeSeparator> list = new ArrayList<JunctionTreeSeparator>();
connectChildren(graph, n0, list, n1);
connectChildren(graph, n1, list, n2, n5);
connectChildren(graph, n2, list, n3, n4);
connectChildren(graph, n5, list, n6);
connectChildren(graph, n6, list, n7, n8);
tree = new JunctionTree(graph, n0, new JunctionTreeClique[]{n0, n1, n2, n3, n4, n5, n6, n7, n8}, list.toArray(new JunctionTreeSeparator[list.size()]));
bayesInstance = new BayesInstance(tree);
bayesInstance.setPassMessageListener(new PassMessageListener() {
@Override
public void beforeProjectAndAbsorb(JunctionTreeClique sourceNode, JunctionTreeSeparator sep, JunctionTreeClique targetNode, double[] oldSeparatorPotentials) {
// System.out.print("\"" + sourceNode.getId() + ":" + targetNode.getId() + "\", ");
messageResults.add(sourceNode.getId() + ":" + targetNode.getId());
}
@Override
public void afterProject(JunctionTreeClique sourceNode, JunctionTreeSeparator sep, JunctionTreeClique targetNode, double[] oldSeparatorPotentials) {
}
@Override
public void afterAbsorb(JunctionTreeClique sourceNode, JunctionTreeSeparator sep, JunctionTreeClique targetNode, double[] oldSeparatorPotentials) {
}
});
bayesInstance.setGlobalUpdateListener(new GlobalUpdateListener() {
@Override
public void beforeGlobalUpdate(CliqueState clique) {
globalUpdateResults.add("" + clique.getJunctionTreeClique().getId());
}
@Override
public void afterGlobalUpdate(CliqueState clique) {
}
});
}
@Test
public void testCollectFromRootClique() {
bayesInstance.collectEvidence(n0);
assertEquals(asList("3:2", "4:2", "2:1", "7:6", "8:6", "6:5", "5:1", "1:0" ), messageResults);
}
@Test
public void testCollectFromMidTipClique() {
bayesInstance.collectEvidence(n4);
assertEquals( asList( "0:1", "7:6", "8:6", "6:5", "5:1", "1:2", "3:2", "2:4" ), messageResults);
}
@Test
public void testCollectFromEndTipClique() {
bayesInstance.collectEvidence(n7);
assertEquals( asList( "0:1", "3:2", "4:2", "2:1", "1:5", "5:6", "8:6", "6:7" ), messageResults);
}
@Test
public void testCollectFromMidClique() {
bayesInstance.collectEvidence(n5);
assertEquals( asList( "0:1", "3:2", "4:2", "2:1", "1:5", "7:6", "8:6", "6:5" ), messageResults);
}
@Test
public void testDistributeFromRootClique() {
bayesInstance.distributeEvidence(n0);
assertEquals( asList( "0:1", "1:2", "2:3", "2:4", "1:5", "5:6", "6:7", "6:8" ), messageResults);
}
@Test
public void testDistributeFromMidTipClique() {
bayesInstance.distributeEvidence(n4);
assertEquals( asList( "4:2", "2:1", "1:0", "1:5", "5:6", "6:7", "6:8", "2:3" ), messageResults);
}
@Test
public void testDistributeFromEndTipClique() {
bayesInstance.distributeEvidence(n7);
assertEquals( asList( "7:6", "6:5", "5:1", "1:0", "1:2", "2:3", "2:4", "6:8" ), messageResults);
}
@Test
public void testDistributeFromMidClique() {
bayesInstance.distributeEvidence(n5);
assertEquals( asList( "5:1", "1:0", "1:2", "2:3", "2:4", "5:6", "6:7", "6:8" ), messageResults);
}
@Test
public void testGlobalUpdateFromRootClique() {
bayesInstance.globalUpdate(n0);
assertEquals( asList( "3:2", "4:2", "2:1", "7:6", "8:6", "6:5", "5:1", "1:0", //n0
"0:1", "1:2", "2:3", "2:4", "1:5", "5:6", "6:7", "6:8" //n0
), messageResults);
assertEquals( asList("0"), globalUpdateResults);
}
@Test
public void testGlobalUpdateFromMidTipClique() {
bayesInstance.globalUpdate(n4);
assertEquals( asList( "0:1", "7:6", "8:6", "6:5", "5:1", "1:2", "3:2", "2:4", //n4
"4:2", "2:1", "1:0", "1:5", "5:6", "6:7", "6:8", "2:3" //n4
), messageResults);
assertEquals( asList("4"), globalUpdateResults);
}
@Test
public void testGlobalUpdateFromEndTipClique() {
bayesInstance.globalUpdate(n7);
assertEquals( asList( "0:1", "3:2", "4:2", "2:1", "1:5", "5:6", "8:6", "6:7", //n7
"7:6", "6:5", "5:1", "1:0", "1:2", "2:3", "2:4", "6:8" //n7
), messageResults);
assertEquals( asList("7"), globalUpdateResults);
}
@Test
public void testGlobalUpdateFromMidClique() {
bayesInstance.globalUpdate(n5);
assertEquals( asList( "0:1", "3:2", "4:2", "2:1", "1:5", "7:6", "8:6", "6:5", //n5
"5:1", "1:0", "1:2", "2:3", "2:4", "5:6", "6:7", "6:8" //n5
), messageResults);
assertEquals( asList("5"), globalUpdateResults);
}
@Test
public void testDistributeFromGlobalUpdate() {
bayesInstance.globalUpdate();
assertEquals( asList( "3:2", "4:2", "2:1", "7:6", "8:6", "6:5", "5:1", "1:0", //n0
"0:1", "1:2", "2:3", "2:4", "1:5", "5:6", "6:7", "6:8" //n0
// "0:1", "3:2", "4:2", "2:1", "7:6", "8:6", "6:5", "5:1", //n1
// "1:0", "1:2", "2:3", "2:4", "1:5", "5:6", "6:7", "6:8", //n1
// "0:1", "7:6", "8:6", "6:5", "5:1", "1:2", "3:2", "4:2", //n2
// "2:1", "1:0", "1:5", "5:6", "6:7", "6:8", "2:3", "2:4", //n2
// "0:1", "7:6", "8:6", "6:5", "5:1", "1:2", "4:2", "2:3", //n3
// "3:2", "2:1", "1:0", "1:5", "5:6", "6:7", "6:8", "2:4", //n3
// "0:1", "7:6", "8:6", "6:5", "5:1", "1:2", "3:2", "2:4", //n4
// "4:2", "2:1", "1:0", "1:5", "5:6", "6:7", "6:8", "2:3", //n4
// "0:1", "3:2", "4:2", "2:1", "1:5", "7:6", "8:6", "6:5", //n5
// "5:1", "1:0", "1:2", "2:3", "2:4", "5:6", "6:7", "6:8", //n5
// "0:1", "3:2", "4:2", "2:1", "1:5", "5:6", "7:6", "8:6", //n6
// "6:5", "5:1", "1:0", "1:2", "2:3", "2:4", "6:7", "6:8", //n6
// "0:1", "3:2", "4:2", "2:1", "1:5", "5:6", "8:6", "6:7", //n7
// "7:6", "6:5", "5:1", "1:0", "1:2", "2:3", "2:4", "6:8", //n7
// "0:1", "3:2", "4:2", "2:1", "1:5", "5:6", "7:6", "6:8", //n8
// "8:6", "6:5", "5:1", "1:0", "1:2", "2:3", "2:4", "6:7" //n8
), messageResults);
// assertEquals( asList( "0", "1", "2", "3", "4", "5", "6", "7", "8"), globalUpdateResults);
assertEquals( asList( "0" ), globalUpdateResults);
}
public void testGlobalUpdate() {
bayesInstance.globalUpdate();
}
public List asList(String... items) {
List<String> list = new ArrayList<String>();
for ( String s : items ) {
list.add( s );
}
return list;
}
public void connectChildren(Graph<BayesVariable> graph, JunctionTreeClique parent, List list, JunctionTreeClique... children) {
for ( JunctionTreeClique child : children ) {
list.add( new JunctionTreeSeparator(list.size(), parent, child, bitSet("0"), graph) );
}
}
}