/** * Copyright (c) 2011 Michael Kutschke. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * Michael Kutschke - initial API and implementation. */ package org.eclipse.recommenders.tests.jayes; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import java.io.BufferedReader; import java.io.InputStreamReader; import java.io.Reader; import java.nio.CharBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; import org.eclipse.recommenders.jayes.BayesNet; import org.eclipse.recommenders.jayes.BayesNode; import org.eclipse.recommenders.jayes.factor.AbstractFactor; import org.eclipse.recommenders.jayes.factor.FactorFactory; import org.eclipse.recommenders.jayes.inference.IBayesInferer; import org.eclipse.recommenders.jayes.inference.junctionTree.JunctionTreeAlgorithm; import org.eclipse.recommenders.jayes.io.XMLBIFReader; import org.eclipse.recommenders.jayes.testgen.TestCase; import org.eclipse.recommenders.jayes.testgen.TestcaseDeserializer; import org.eclipse.recommenders.jayes.testgen.scenario.impl.SampledScenarioGenerator; import org.eclipse.recommenders.tests.jayes.lbp.LoopyBeliefPropagation; import org.eclipse.recommenders.tests.jayes.util.NetExamples; import org.junit.Test; public class JunctionTreeTest { private static final double TOLERANCE = 0.01; private static final double SMALL_TOLERANCE = 0.00001; @Test public void testInference1() { BayesNet net = NetExamples.testNet1(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); IBayesInferer inference = new JunctionTreeAlgorithm(); inference.addEvidence(a, "false"); inference.addEvidence(b, "lu"); inference.setNetwork(net); IBayesInferer compare = new LoopyBeliefPropagation(); compare.setNetwork(net); compare.addEvidence(a, "false"); compare.addEvidence(b, "lu"); for (BayesNode n : net.getNodes()) assertArrayEquals(compare.getBeliefs(n), inference.getBeliefs(n), 0.01); } @Test public void testLogScale() { BayesNet net = NetExamples.testNet1(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); JunctionTreeAlgorithm inferer = new JunctionTreeAlgorithm(); inferer.getFactory().setUseLogScale(true); inferer.addEvidence(a, "false"); inferer.addEvidence(b, "lu"); inferer.setNetwork(net); IBayesInferer compare = new LoopyBeliefPropagation(); compare.setNetwork(net); compare.addEvidence(a, "false"); compare.addEvidence(b, "lu"); for (BayesNode n : net.getNodes()) assertArrayEquals(compare.getBeliefs(n), inferer.getBeliefs(n), TOLERANCE); } @Test public void testMixedScale() { BayesNet net = NetExamples.testNet1(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); JunctionTreeAlgorithm inferer = new JunctionTreeAlgorithm(); // this will make the a,b,c clique log scale but the // c,d clique normal inferer.setFactorFactory(new FactorFactory() { @Override protected boolean getUseLogScale(AbstractFactor f) { return f.getDimensions().length > 2; } }); inferer.addEvidence(a, "false"); inferer.addEvidence(b, "lu"); inferer.setNetwork(net); IBayesInferer compare = new LoopyBeliefPropagation(); compare.setNetwork(net); compare.addEvidence(a, "false"); compare.addEvidence(b, "lu"); for (BayesNode n : net.getNodes()) assertArrayEquals(compare.getBeliefs(n), inferer.getBeliefs(n), 0.01); } @Test public void testFailedCase1() { BayesNet net = NetExamples.testNet1(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); BayesNode c = net.getNode("c"); JunctionTreeAlgorithm inferer = new JunctionTreeAlgorithm(); inferer.setNetwork(net); Map<BayesNode, String> evidence = new HashMap<BayesNode, String>(); evidence.put(a, "false"); evidence.put(c, "true"); inferer.setEvidence(evidence); assertEquals(0.22, inferer.getBeliefs(b)[0], 0.01); } @Test public void testUnconnected() { BayesNet net = NetExamples.unconnectedNet(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); IBayesInferer inference = new JunctionTreeAlgorithm(); inference.addEvidence(a, "false"); inference.addEvidence(b, "true"); inference.setNetwork(net); IBayesInferer compare = new LoopyBeliefPropagation(); compare.setNetwork(net); compare.addEvidence(a, "false"); compare.addEvidence(b, "true"); for (BayesNode n : net.getNodes()) assertArrayEquals(inference.getBeliefs(n), compare.getBeliefs(n), 0.01); } @Test public void testSparseFactors() { BayesNet net = NetExamples.sparseNet(); BayesNode a = net.getNode("a"); BayesNode b = net.getNode("b"); IBayesInferer inference = new JunctionTreeAlgorithm(); inference.addEvidence(a, "false"); inference.addEvidence(b, "lu"); inference.setNetwork(net); IBayesInferer compare = new LoopyBeliefPropagation(); compare.setNetwork(net); compare.addEvidence(a, "false"); compare.addEvidence(b, "lu"); for (BayesNode n : net.getNodes()) assertArrayEquals(compare.getBeliefs(n), inference.getBeliefs(n), 0.01); } @Test public void testLargerScaleCorrectness() throws Exception { getClass().getClassLoader(); BayesNet net = new XMLBIFReader().read(getClass().getClassLoader().getResourceAsStream("JPanel.xml")); TestcaseDeserializer deser = new TestcaseDeserializer(net); Reader rdr = new BufferedReader(new InputStreamReader(getClass().getClassLoader().getResourceAsStream( "testcases_JPanel.json"))); StringBuffer buf = new StringBuffer(); CharBuffer cbuff = CharBuffer.allocate(1024); while (rdr.read(cbuff) != -1) { cbuff.flip(); buf.append(cbuff); cbuff.clear(); } rdr.close(); List<TestCase> testcases = deser.deserialize(buf.toString()); JunctionTreeAlgorithm algo = new JunctionTreeAlgorithm(); algo.setNetwork(net); for (TestCase tc : testcases) { algo.setEvidence(tc.evidence); for (BayesNode node : net.getNodes()) { assertArrayEquals(tc.beliefs.get(node), algo.getBeliefs(node), SMALL_TOLERANCE); } } } @Test public void testLargerScaleCorrectnessAB() throws Exception { BayesNet net = new XMLBIFReader().read(getClass().getClassLoader().getResourceAsStream("JPanel.xml")); SampledScenarioGenerator testgen = new SampledScenarioGenerator(); testgen.setNetwork(net); testgen.seed(1337); testgen.setEvidenceRate(0.5); JunctionTreeAlgorithm a = new JunctionTreeAlgorithm(); a.setNetwork(net); JunctionTreeAlgorithm b = new JunctionTreeAlgorithm(); b.getFactory().setFloatingPointType(float.class); b.setNetwork(net); for (int i = 0; i < 1000; i++) { Map<BayesNode, String> testcase = testgen.testcase(); b.setEvidence(testcase); a.setEvidence(testcase); for (BayesNode node : net.getNodes()) { assertArrayEquals(a.getBeliefs(node), b.getBeliefs(node), SMALL_TOLERANCE); } } } }