/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.test; import junit.framework.AssertionFailedError; import junit.framework.Test; import junit.framework.TestCase; import junit.framework.TestSuite; import java.util.*; import java.util.Random; import java.util.logging.Logger; import java.io.IOException; import java.io.StringReader; import java.io.BufferedReader; import cc.mallet.grmm.inference.*; import cc.mallet.grmm.types.*; import cc.mallet.grmm.util.GeneralUtils; import cc.mallet.grmm.util.ModelReader; import cc.mallet.types.Dirichlet; import cc.mallet.types.Matrix; import cc.mallet.types.Matrixn; import cc.mallet.types.tests.TestSerializable; import cc.mallet.util.*; //import cc.mallet.util.Random; import gnu.trove.TDoubleArrayList; /** * Torture tests of inference in GRMM. Well, actually, they're * not all that torturous, but hopefully they're at least * somewhat disconcerting. 
* * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: TestInference.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $ */ public class TestInference extends TestCase { private static Logger logger = MalletLogger.getLogger(TestInference.class.getName()); private static double APPX_EPSILON = 0.15; final public Class[] algorithms = { BruteForceInferencer.class, VariableElimination.class, JunctionTreeInferencer.class, }; final public Class[] appxAlgs = { TRP.class, LoopyBP.class, }; // only used for logJoint test for now final public Class[] allAlgs = { // BruteForceInferencer.class, JunctionTreeInferencer.class, TRP.class, // VariableElimination.class, LoopyBP.class, }; final public Class[] treeAlgs = { TreeBP.class, }; List modelsList; UndirectedModel[] models; FactorGraph[] trees; Factor[][] treeMargs; public TestInference(String name) { super(name); } private static UndirectedModel createChainGraph() { Variable[] vars = new Variable[5]; UndirectedModel model = new UndirectedModel(); try { // Add all variables to model for (int i = 0; i < 5; i++) { vars[i] = new Variable(2); } // Add some links double probs[] = {0.9, 0.1, 0.1, 0.9}; for (int i = 0; i < 4; i++) { Variable[] pair = { vars[i], vars[i + 1], }; TableFactor pot = new TableFactor (pair, probs); model.addFactor (pot); } } catch (Exception e) { e.printStackTrace(); assertTrue(false); } return model; } private static UndirectedModel createTriangle() { Variable[] vars = new Variable[3]; for (int i = 0; i < 3; i++) { vars[i] = new Variable (2); } UndirectedModel model = new UndirectedModel (vars); double[][] pots = new double[][] { { 0.2, 0.8, 0.1, 0.9 }, { 0.7, 0.3, 0.5, 0.5 }, { 0.6, 0.4, 0.8, 0.2 }, { 0.35, 0.65 } }; // double[][] pots = new double[] [] { { model.addFactor (vars[0], vars[1], pots[0]); model.addFactor (vars[1], vars[2], pots[1]); model.addFactor (vars[2], vars[0], pots[2]); TableFactor pot = new TableFactor (new Variable[] { vars[0] }, pots[3]); model.addFactor 
(pot); return model; } private static TableFactor randomEdgePotential(Random r, Variable v1, Variable v2) { int max1 = v1.getNumOutcomes(); int max2 = v2.getNumOutcomes(); Matrix phi = new Matrixn(new int[]{max1, max2}); for (int i = 0; i < v1.getNumOutcomes(); i++) { for (int j = 0; j < v2.getNumOutcomes(); j++) { phi.setValue(new int[]{i, j}, r.nextDouble ()); // rescale(r.nextDouble())); } } return new TableFactor (new Variable[]{v1, v2}, phi); } private static TableFactor randomNodePotential(Random r, Variable v) { int max = v.getNumOutcomes(); Matrix phi = new Matrixn(new int[]{max}); for (int i = 0; i < v.getNumOutcomes(); i++) { phi.setSingleValue(i, rescale(r.nextDouble())); } return new TableFactor (new Variable[]{v}, phi); } // scale d into range 0.2..0.8 private static double rescale(double d) { return 0.2 + 0.6 * d; } private static UndirectedModel createRandomGraph(int numV, int numOutcomes, Random r) { Variable[] vars = new Variable[numV]; for (int i = 0; i < numV; i++) { vars[i] = new Variable(numOutcomes); } UndirectedModel model = new UndirectedModel(vars); for (int i = 0; i < numV; i++) { boolean hasOne = false; for (int j = i + 1; j < numV; j++) { if (r.nextBoolean()) { hasOne = true; model.addFactor (randomEdgePotential (r, vars[i], vars[j])); } } // If vars [i] has no edge potential, add a node potential // To keep things simple, we'll require the potential to be normalized. 
if (!hasOne) { Factor pot = randomNodePotential(r, vars[i]); pot.normalize(); model.addFactor (pot); } } // Ensure exactly one connected component for (int i = 0; i < numV; i++) { for (int j = i + 1; j < numV; j++) { if (!model.isConnected(vars[i], vars[j])) { Factor ptl = randomEdgePotential (r, vars[i], vars[j]); model.addFactor (ptl); } } } return model; } public static UndirectedModel createRandomGrid(int w, int h, int maxOutcomes, Random r) { Variable[][] vars = new Variable[w][h]; UndirectedModel mdl = new UndirectedModel(w * h); for (int i = 0; i < w; i++) { for (int j = 0; j < h; j++) { vars[i][j] = new Variable(r.nextInt(maxOutcomes - 1) + 2); } } for (int i = 0; i < w; i++) { for (int j = 0; j < h; j++) { Factor ptl; if (i < w - 1) { ptl = randomEdgePotential (r, vars[i][j], vars[i + 1][j]); mdl.addFactor (ptl); } if (j < h - 1) { ptl = randomEdgePotential (r, vars[i][j], vars[i][j + 1]); mdl.addFactor (ptl); } } } return mdl; } private UndirectedModel createRandomTree(int nnodes, int maxOutcomes, Random r) { Variable[] vars = new Variable[nnodes]; UndirectedModel mdl = new UndirectedModel(nnodes); for (int i = 0; i < nnodes; i++) { vars[i] = new Variable(r.nextInt(maxOutcomes - 1) + 2); } // Add some random edges for (int i = 0; i < nnodes; i++) { for (int j = i + 1; j < nnodes; j++) { if (!mdl.isConnected(vars[i], vars[j]) && r.nextBoolean()) { Factor ptl = randomEdgePotential (r, vars[i], vars[j]); mdl.addFactor (ptl); } } } // Ensure exactly one connected component for (int i = 0; i < nnodes; i++) { for (int j = i + 1; j < nnodes; j++) { if (!mdl.isConnected(vars[i], vars[j])) { System.out.println ("forced edge: " + i + " " + j); Factor ptl = randomEdgePotential (r, vars[i], vars[j]); mdl.addFactor (ptl); } } } return mdl; } public static List createTestModels() { Random r = new Random(42); // These models are all small so that we can run the brute force // inferencer on them. 
FactorGraph[] mdls = new FactorGraph[]{ createTriangle(), createChainGraph(), createRandomGraph(3, 2, r), createRandomGraph(3, 3, r), createRandomGraph(6, 3, r), createRandomGraph(8, 2, r), createRandomGrid(3, 2, 4, r), createRandomGrid(4, 3, 2, r), }; return new ArrayList(Arrays.asList(mdls)); } public void testUniformJoint () throws Exception { FactorGraph mdl = RandomGraphs.createUniformChain (3); double expected = -Math.log (8); for (int i = 0; i < allAlgs.length; i++) { Inferencer inf = (Inferencer) allAlgs[i].newInstance (); inf.computeMarginals (mdl); for (AssignmentIterator it = mdl.assignmentIterator (); it.hasNext ();) { Assignment assn = it.assignment (); double actual = inf.lookupLogJoint (assn); assertEquals ("Incorrect joint for inferencer "+inf, expected, actual, 1e-5); it.advance (); } } } public void testJointConsistent () throws Exception { for (int i = 0; i < allAlgs.length; i++) { // for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { { int mdlIdx = 13; Inferencer inf = (Inferencer) allAlgs[i].newInstance(); try { FactorGraph mdl = models[mdlIdx]; inf.computeMarginals(mdl); Assignment assn = new Assignment (mdl, new int [mdl.numVariables ()]); assertEquals (Math.log (inf.lookupJoint (assn)), inf.lookupLogJoint (assn), 1e-5); } catch (UnsupportedOperationException e) { // LoopyBP only handles edge ptls logger.warning("Skipping (" + mdlIdx + "," + i + ")\n" + e); throw e; // continue; } } } } public void testFactorizedJoint() throws Exception { Inferencer[][] infs = new Inferencer[allAlgs.length][models.length]; for (int i = 0; i < allAlgs.length; i++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer alg = (Inferencer) allAlgs[i].newInstance(); if (alg instanceof TRP) { ((TRP)alg).setRandomSeed (1231234); } try { alg.computeMarginals(models[mdl]); infs[i][mdl] = alg; } catch (UnsupportedOperationException e) { // LoopyBP only handles edge ptls logger.warning("Skipping (" + mdl + "," + i + ")\n" + e); throw e; // continue; } } } /* 
Ensure that lookupLogJoint() consistent */ int alg1 = 0; // Brute force for (int alg2 = 1; alg2 < allAlgs.length; alg2++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer inf1 = infs[alg1][mdl]; Inferencer inf2 = infs[alg2][mdl]; if ((inf1 == null) || (inf2 == null)) { continue; } Iterator it = models[mdl].assignmentIterator(); while (it.hasNext()) { try { Assignment assn = (Assignment) it.next(); double joint1 = inf1.lookupLogJoint(assn); double joint2 = inf2.lookupLogJoint(assn); logger.finest("logJoint: " + inf1 + " " + inf2 + " Model " + mdl + " Assn: " + assn + " INF1: " + joint1 + "\n" + " INF2: " + joint2 + "\n"); assertTrue("logJoint not equal btwn " + GeneralUtils.classShortName (inf1) + " " + " and " + GeneralUtils.classShortName (inf2) + "\n" + " Model " + mdl + "\n" + " INF1: " + joint1 + "\n" + " INF2: " + joint2 + "\n", Math.abs(joint1 - joint2) < 0.2); double joint3 = inf1.lookupJoint(assn); assertTrue("logJoint & joint not consistent\n " + "Model " + mdl + "\n" + assn, Maths.almostEquals(joint3, Math.exp(joint1))); } catch (UnsupportedOperationException e) { // VarElim doesn't compute log joints. Let it slide logger.warning("Skipping " + inf1 + " -> " + inf2 + "\n" + e); continue; } } } } } public void testMarginals() throws Exception { Factor[][][] joints = new Factor[models.length][][]; Inferencer[] appxInferencers = constructAllAppxInferencers (); int numExactAlgs = algorithms.length; int numAppxAlgs = appxInferencers.length; int numAlgs = numExactAlgs + numAppxAlgs; for (int mdl = 0; mdl < models.length; mdl++) { joints[mdl] = new Factor[numAlgs][]; } /* Query every known graph with every known alg. 
*/ for (int i = 0; i < algorithms.length; i++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer alg = (Inferencer) algorithms[i].newInstance(); logger.fine("Computing marginals for model " + mdl + " alg " + alg); alg.computeMarginals(models[mdl]); joints[mdl][i] = collectAllMarginals (models [mdl], alg); } } logger.fine("Checking that results are consistent..."); /* Now, make sure the exact marginals are consistent for * the same model. */ for (int mdl = 0; mdl < models.length; mdl++) { int maxV = models[mdl].numVariables (); for (int vrt = 0; vrt < maxV; vrt++) { for (int alg1 = 0; alg1 < algorithms.length; alg1++) { for (int alg2 = 0; alg2 < algorithms.length; alg2++) { Factor joint1 = joints[mdl][alg1][vrt]; Factor joint2 = joints[mdl][alg2][vrt]; try { // By the time we get here, a joint is null only if // there was an UnsupportedOperationException. if ((joint1 != null) && (joint2 != null)) { assertTrue(joint1.almostEquals(joint2)); } } catch (AssertionFailedError e) { System.out.println("\n************************************\nTest FAILED\n\n"); System.out.println("Model " + mdl + " Vertex " + vrt); System.out.println("Algs " + alg1 + " and " + alg2 + " not consistent."); System.out.println("MARGINAL from " + alg1); System.out.println(joint1); System.out.println("MARGINAL from " + alg2); System.out.println(joint2); System.out.println("Marginals from " + alg1 + ":"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg1][i]); } System.out.println("Marginals from " + alg2 + ":"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg2][i]); } models[mdl].dump (); throw e; } } } } } // Compare all approximate algorithms against brute force. 
logger.fine("Checking the approximate algorithms..."); int alg2 = 0; // Brute force for (int appxIdx = 0; appxIdx < appxInferencers.length; appxIdx++) { Inferencer alg = appxInferencers [appxIdx]; for (int mdl = 0; mdl < models.length; mdl++) { logger.finer("Running inference alg " + alg + " with model " + mdl); try { alg.computeMarginals(models[mdl]); } catch (UnsupportedOperationException e) { // LoopyBP does not support vertex potentials. // We'll let that slide. if (alg instanceof AbstractBeliefPropagation) { logger.warning("Skipping model " + mdl + " for alg " + alg + "\nInference unsupported."); continue; } else { throw e; } } /* lookup all marginals */ int vrt = 0; int alg1 = numExactAlgs + appxIdx; int maxV = models[mdl].numVariables (); joints[mdl][alg1] = new Factor[maxV]; for (Iterator it = models[mdl].variablesSet ().iterator(); it.hasNext(); vrt++) { Variable var = (Variable) it.next(); logger.finer("Lookup marginal for model " + mdl + " vrt " + var + " alg " + alg); Factor ptl = alg.lookupMarginal(var); joints[mdl][alg1][vrt] = ptl.duplicate(); } for (vrt = 0; vrt < maxV; vrt++) { Factor joint1 = joints[mdl][alg1][vrt]; Factor joint2 = joints[mdl][alg2][vrt]; try { assertTrue(joint1.almostEquals(joint2, APPX_EPSILON)); } catch (AssertionFailedError e) { System.out.println("\n************************************\nAppx Marginal Test FAILED\n\n"); System.out.println("Inferencer: " + alg); System.out.println("Model " + mdl + " Vertex " + vrt); System.out.println(joint1.dumpToString ()); System.out.println(joint2.dumpToString ()); models[mdl].dump (); System.out.println("All marginals:"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg1][i].dumpToString ()); } System.out.println("Correct marginals:"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg2][i].dumpToString ()); } throw e; } } } } System.out.println("Tested " + models.length + " undirected models."); } private Inferencer[] constructAllAppxInferencers () 
throws IllegalAccessException, InstantiationException { List algs = new ArrayList (appxAlgs.length * 2); for (int i = 0; i < appxAlgs.length; i++) { algs.add (appxAlgs[i].newInstance ()); } // Add a few that don't fit algs.add (new TRP ().setMessager (new AbstractBeliefPropagation.SumProductMessageStrategy (0.8))); algs.add (new LoopyBP ().setMessager (new AbstractBeliefPropagation.SumProductMessageStrategy (0.8))); algs.add (new SamplingInferencer (new GibbsSampler (10000), 10000)); algs.add (new SamplingInferencer (new ExactSampler (), 1000)); return (Inferencer[]) algs.toArray (new Inferencer [algs.size ()]); } private Inferencer[] constructMaxProductInferencers () throws IllegalAccessException, InstantiationException { List algs = new ArrayList (); algs.add (JunctionTreeInferencer.createForMaxProduct ()); algs.add (TRP.createForMaxProduct ()); algs.add (LoopyBP.createForMaxProduct ()); return (Inferencer[]) algs.toArray (new Inferencer [algs.size ()]); } private Factor[] collectAllMarginals (FactorGraph mdl, Inferencer alg) { int vrt = 0; int numVertices = mdl.numVariables (); Factor[] collector = new Factor[numVertices]; for (Iterator it = mdl.variablesSet ().iterator(); it.hasNext(); vrt++) { Variable var = (Variable) it.next(); try { collector[vrt] = alg.lookupMarginal(var); assert collector [vrt] != null : "Query returned null for model " + mdl + " vertex " + var + " alg " + alg; } catch (UnsupportedOperationException e) { // Allow unsupported inference to slide with warning logger.warning("Warning: Skipping model " + mdl + " for alg " + alg + "\n Inference unsupported."); } } return collector; } public void testQuery () throws Exception { java.util.Random rand = new java.util.Random (15667); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; int size = rand.nextInt (3) + 2; size = Math.min (size, mdl.varSet ().size ()); Collection vars = CollectionUtils.subset (mdl.variablesSet (), size, rand); Variable[] varArr = 
(Variable[]) vars.toArray (new Variable [0]); Assignment assn = new Assignment (varArr, new int [size]); BruteForceInferencer brute = new BruteForceInferencer(); Factor joint = brute.joint(mdl); double marginal = joint.marginalize(vars).value (assn); for (int algIdx = 0; algIdx < appxAlgs.length; algIdx++) { Inferencer alg = (Inferencer) appxAlgs[algIdx].newInstance(); if (alg instanceof TRP) continue; // trp can't handle disconnected models, which arise during query() double returned = alg.query (mdl, assn); assertEquals ("Failure on model "+mdlIdx+" alg "+alg, marginal, returned, APPX_EPSILON); } } logger.info ("Test testQuery passed."); } // be careful that caching of inference algorithms does not affect results here. public void testSerializable () throws Exception { for (int i = 0; i < algorithms.length; i++) { Inferencer alg = (Inferencer) algorithms[i].newInstance(); testSerializationForAlg (alg); } for (int i = 0; i < appxAlgs.length; i++) { Inferencer alg = (Inferencer) appxAlgs[i].newInstance(); testSerializationForAlg (alg); } Inferencer[] maxAlgs = constructMaxProductInferencers (); for (int i = 0; i < maxAlgs.length; i++) { testSerializationForAlg (maxAlgs [i]); } } private void testSerializationForAlg (Inferencer alg) throws IOException, ClassNotFoundException { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; // Copy the inferencer before calling b/c of random seed issues. 
Inferencer alg2 = (Inferencer) TestSerializable.cloneViaSerialization (alg); alg.computeMarginals(mdl); Factor[] pre = collectAllMarginals (mdl, alg); alg2.computeMarginals (mdl); Factor[] post2 = collectAllMarginals (mdl, alg2); compareMarginals ("Error comparing marginals after serialzation on model "+mdl, pre, post2); } } private void compareMarginals (String msg, Factor[] pre, Factor[] post) { for (int i = 0; i < pre.length; i++) { Factor ptl1 = pre[i]; Factor ptl2 = post[i]; assertTrue (msg + "\n" + ptl1.dumpToString () + "\n" + ptl2.dumpToString (), ptl1.almostEquals (ptl2, 1e-3)); } } // This is really impossible after the change to the factor graph representation // Tests the measurement of numbers of messages sent public void ignoreTestNumMessages () { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models [mdlIdx]; TRP trp = new TRP (); trp.computeMarginals (mdl); int expectedMessages = (mdl.numVariables () - 1) * 2 * trp.iterationsUsed(); assertEquals (expectedMessages, trp.getTotalMessagesSent ()); LoopyBP loopy = new LoopyBP (); loopy.computeMarginals (mdl); expectedMessages = mdl.getEdgeSet().size() * 2 * loopy.iterationsUsed(); assertEquals (expectedMessages, loopy.getTotalMessagesSent ()); } } private UndirectedModel createJtChain() { int numNodes = 4; Variable[] nodes = new Variable[numNodes]; for (int i = 0; i < numNodes; i++) { nodes[i] = new Variable(2); } Factor[] pots = new TableFactor[]{ new TableFactor (new Variable[]{nodes[0], nodes[1]}, new double[]{1, 2, 5, 4}), new TableFactor (new Variable[]{nodes[1], nodes[2]}, new double[]{4, 2, 4, 1}), new TableFactor (new Variable[]{nodes[2], nodes[3]}, new double[]{7, 3, 6, 9}), }; for (int i = 0; i < pots.length; i++) { pots[i].normalize(); } UndirectedModel uGraph = new UndirectedModel(); for (int i = 0; i < numNodes - 1; i++) { uGraph.addFactor (pots[i]); } return uGraph; } private static final int JT_CHAIN_TEST_TREE = 2; private void createTestTrees() { Random r = 
new Random(185); trees = new FactorGraph[] { RandomGraphs.createUniformChain (2), RandomGraphs.createUniformChain (4), createJtChain(), createRandomGrid(5, 1, 3, r), createRandomGrid(6, 1, 2, r), createRandomTree(10, 2, r), createRandomTree(10, 2, r), createRandomTree(8, 3, r), createRandomTree(8, 3, r), }; modelsList.addAll(Arrays.asList(trees)); } private void computeTestTreeMargs() { treeMargs = new Factor[trees.length][]; BruteForceInferencer brute = new BruteForceInferencer(); for (int i = 0; i < trees.length; i++) { FactorGraph mdl = trees[i]; Factor joint = brute.joint(mdl); treeMargs[i] = new Factor[mdl.numVariables ()]; for (Iterator it = mdl.variablesIterator (); it.hasNext();) { Variable var = (Variable) it.next(); treeMargs[i][mdl.getIndex(var)] = joint.marginalize(var); } } } public void testJtConsistency() { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models[mdlIdx]; JunctionTreeInferencer jti = new JunctionTreeInferencer(); JunctionTree jt = jti.buildJunctionTree(mdl); for (Iterator it = jt.getVerticesIterator(); it.hasNext();) { VarSet parent = (VarSet) it.next(); for (Iterator it2 = jt.getChildren(parent).iterator(); it2.hasNext();) { VarSet child = (VarSet) it2.next(); Factor ptl = jt.getSepsetPot(parent, child); Set intersection = parent.intersection (child); assertTrue (intersection.equals (ptl.varSet())); } } } } private void compareTrpJoint(Factor joint, TRP trp) { Assignment assn = null; double prob1 = 0.0, prob2 = 0.0; try { VarSet all = new HashVarSet (joint.varSet()); for (Iterator it = all.assignmentIterator(); it.hasNext();) { assn = (Assignment) it.next(); prob1 = trp.lookupJoint(assn); prob2 = joint.value (assn); // assertTrue (Maths.almostEquals (prob1, prob2)); assertTrue(Math.abs(prob1 - prob2) < 0.01); } } catch (AssertionFailedError e) { System.out.println("*****************************************\nTEST FAILURE in compareTrpJoint"); 
System.out.println("*****************************************\nat"); System.out.println(assn); System.out.println("Expected: " + prob2); System.out.println("TRP: " + prob1); System.out.println("*****************************************\nExpected joint"); System.out.println(joint); System.out.println("*****************************************\nTRP dump"); trp.dump(); throw e; } } public void testTrp() { final UndirectedModel model = createTriangle(); TRP trp = new TRP().setTerminator (new TRP.IterationTerminator(200)); BruteForceInferencer brute = new BruteForceInferencer(); Factor joint = brute.joint(model); trp.computeMarginals(model); // Check joint // DiscretePotential joint = brute.joint (model); compareTrpJoint(joint, trp); // Check all marginals try { for (Iterator it = model.variablesIterator (); it.hasNext();) { Variable var = (Variable) it.next(); Factor marg1 = trp.lookupMarginal(var); Factor marg2 = joint.marginalize (var); assertTrue(marg1.almostEquals(marg2, APPX_EPSILON)); } for (Iterator it = model.factorsIterator(); it.hasNext();) { Factor factor = (Factor) it.next (); Factor marg1 = trp.lookupMarginal (factor.varSet ()); Factor marg2 = joint.marginalize (factor.varSet ()); assertTrue(marg1.almostEquals(marg2, APPX_EPSILON)); } } catch (AssertionFailedError e) { System.out.println("\n*************************************\nTEST FAILURE in compareTrpMargs"); // System.out.println(marg1); // System.out.println(marg2); System.out.println("*************************************\nComplete model:\n\n"); model.dump (); System.out.println("*************************************\nTRP margs:\n\n"); trp.dump(); System.out.println("**************************************\nAll correct margs:\n"); for (Iterator it2 = model.variablesIterator (); it2.hasNext();) { Variable v2 = (Variable) it2.next(); brute.computeMarginals (model); System.out.println(brute.lookupMarginal(v2)); } throw e; } } public void testTrpJoint() { FactorGraph model = createTriangle(); TRP trp = 
new TRP().setTerminator (new TRP.IterationTerminator(25)); trp.computeMarginals(model); // For each assignment to the model, check that // TRP.lookupLogJoint and TRP.lookupJoint are consistent VarSet all = new HashVarSet (model.variablesSet ()); for (Iterator it = all.assignmentIterator(); it.hasNext();) { Assignment assn = (Assignment) it.next(); double log = trp.lookupLogJoint(assn); double prob = trp.lookupJoint(assn); assertTrue(Maths.almostEquals(Math.exp(log), prob)); } logger.info("Test trpJoint passed."); } /** Tests that running TRP doesn't inadvertantly change potentials in the original graph. */ public void testTrpNonDestructivity() { FactorGraph model = createTriangle(); TRP trp = new TRP(new TRP.IterationTerminator(25)); BruteForceInferencer brute = new BruteForceInferencer(); Factor joint1 = brute.joint(model); trp.computeMarginals(model); Factor joint2 = brute.joint(model); assertTrue(joint1.almostEquals(joint2)); logger.info("Test trpNonDestructivity passed."); } public void testTrpReuse() { TRP trp1 = new TRP(new TRP.IterationTerminator(25)); for (int i = 0; i < models.length; i++) { trp1.computeMarginals(models[i]); } // Hard to do automatically right now... logger.info("Please ensure that all instantiations above run for 25 iterations."); // Ensure that all edges touched works... 
UndirectedModel mdl = models[0]; final Tree tree = trp1.new AlmostRandomTreeFactory().nextTree(mdl); TRP trp2 = new TRP(new TRP.TreeFactory() { public Tree nextTree(FactorGraph mdl) { return tree; } }); trp2.computeMarginals(mdl); logger.info("Ensure that the above instantiation ran for 1000 iterations with a warning."); } private static String[] treeStrs = new String[] { "<TREE>" + " <VAR NAME='V0'>" + " <FACTOR VARS='V0 V1'>" + " <VAR NAME='V1'/>" + " </FACTOR>" + " <FACTOR VARS='V0 V2'>" + " <VAR NAME='V2'/>" + " </FACTOR>" + " </VAR>"+ "</TREE>", "<TREE>" + " <VAR NAME='V1'>" + " <FACTOR VARS='V0 V1'>" + " <VAR NAME='V0'/>" + " </FACTOR>" + " <FACTOR VARS='V1 V2'>" + " <VAR NAME='V2'/>" + " </FACTOR>" + " </VAR>"+ "</TREE>", "<TREE>" + " <VAR NAME='V0'>" + " <FACTOR VARS='V0 V1'>" + " <VAR NAME='V1'>" + " <FACTOR VARS='V1 V2'>" + " <VAR NAME='V2'/>" + " </FACTOR>" + "</VAR>"+ " </FACTOR>" + " </VAR>" + "</TREE>", "<TREE>" + " <VAR NAME='V2'>" + " <FACTOR VARS='V2 V1'>" + " <VAR NAME='V1'/>" + " </FACTOR>" + " <FACTOR VARS='V0 V2'>" + " <VAR NAME='V0'/>" + " </FACTOR>" + " </VAR>"+ "</TREE>", }; public void testTrpTreeList () { FactorGraph model = createTriangle(); model.getVariable (0).setLabel ("V0"); model.getVariable (1).setLabel ("V1"); model.getVariable (2).setLabel ("V2"); List readers = new ArrayList (); for (int i = 0; i < treeStrs.length; i++) { readers.add (new StringReader (treeStrs[i])); } TRP trp = new TRP().setTerminator (new TRP.DefaultConvergenceTerminator()) .setFactory (TRP.TreeListFactory.makeFromReaders (model, readers)); trp.computeMarginals(model); Inferencer jt = new BruteForceInferencer (); jt.computeMarginals (model); compareMarginals ("", model, trp, jt); } // Verify that variable indices are consistent in undirectected // models. 
public void testUndirectedIndices() { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models[mdlIdx]; for (Iterator it = mdl.variablesIterator (); it.hasNext();) { Variable var1 = (Variable) it.next(); Variable var2 = mdl.get(mdl.getIndex(var1)); assertTrue("Mismatch in Variable index for " + var1 + " vs " + var2 + " in model " + mdlIdx + "\n" + mdl, var1 == var2); } } logger.info("Test undirectedIndices passed."); } // Tests that TRP and max-product propagation return the same // results when TRP runs for exactly one iteration. public void testTrpViterbiEquiv() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; TreeBP maxprod = TreeBP.createForMaxProduct (); TRP trp = TRP.createForMaxProduct () .setTerminator (new TRP.IterationTerminator (1)); maxprod.computeMarginals (mdl); trp.computeMarginals (mdl); // TRP should return same results as viterbi for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotBp = maxprod.lookupMarginal (var); Factor maxPotTrp = trp.lookupMarginal (var); maxPotBp.normalize (); maxPotTrp.normalize (); assertTrue ("TRP 1 iter maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp.dumpToString () + "\n Plain maxprod " + maxPotBp.dumpToString (), maxPotBp.almostEquals (maxPotTrp)); } } } public void testTrpOnTrees () { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; Inferencer bp = new TreeBP (); Inferencer trp = new TRP ().setTerminator (new TRP.IterationTerminator (1)); bp.computeMarginals (mdl); trp.computeMarginals (mdl); int[] outcomes = new int [mdl.numVariables ()]; Assignment assn = new Assignment (mdl, outcomes); assertEquals (bp.lookupLogJoint (assn), trp.lookupLogJoint (assn), 1e-5); Arrays.fill (outcomes, 1); assn = new Assignment (mdl, outcomes); assertEquals (bp.lookupLogJoint (assn), trp.lookupLogJoint (assn), 1e-5); // TRP should return same 
results as viterbi for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotBp = bp.lookupMarginal (var); Factor maxPotTrp = trp.lookupMarginal (var); maxPotBp.normalize (); maxPotTrp.normalize (); assertTrue ("TRP 1 iter bp propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp.dumpToString () + "\n Plain bp " + maxPotBp.dumpToString (), maxPotBp.almostEquals (maxPotTrp)); } } } // Tests that TRP and max-product propagation return the same // results when TRP is allowed to run to convergence. public void testTrpViterbiEquiv2() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; Inferencer maxprod = TreeBP.createForMaxProduct (); TRP trp = TRP.createForMaxProduct (); maxprod.computeMarginals (mdl); trp.computeMarginals (mdl); // TRP should return same results as viterbi for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotBp = maxprod.lookupMarginal (var); Factor maxPotTrp = trp.lookupMarginal (var); assertTrue ("TRP maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp, maxPotBp.almostEquals (maxPotTrp)); } } } public void testTreeViterbi() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; BruteForceInferencer brute = new BruteForceInferencer (); Inferencer maxprod = TreeBP.createForMaxProduct (); Factor joint = brute.joint (mdl); maxprod.computeMarginals (mdl); for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPot = maxprod.lookupMarginal (var); Factor trueMaxPot = joint.extractMax (var); maxPot.normalize (); trueMaxPot.normalize (); assertTrue ("Maximization failed! 
Normalized returns:\n" + maxPot + "\nTrue: " + trueMaxPot, maxPot.almostEquals (trueMaxPot)); } } logger.info("Test treeViterbi passed: " + trees.length + " models."); } public void testJtViterbi() { JunctionTreeInferencer jti = new JunctionTreeInferencer(); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models[mdlIdx]; BruteForceInferencer brute = new BruteForceInferencer (); JunctionTreeInferencer maxprod = JunctionTreeInferencer.createForMaxProduct (); JunctionTree jt = maxprod.buildJunctionTree (mdl); Factor joint = brute.joint (mdl); maxprod.computeMarginals (jt); for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotRaw = maxprod.lookupMarginal (var); Factor trueMaxPotRaw = joint.extractMax (var); Factor maxPot = maxPotRaw.duplicate().normalize (); Factor trueMaxPot = trueMaxPotRaw.duplicate().normalize (); assertTrue ("Maximization failed on model " + mdlIdx + " ! Normalized returns:\n" + maxPot.dumpToString () + "\nTrue: " + trueMaxPot.dumpToString (), maxPot.almostEquals (trueMaxPot, 0.01)); } } logger.info("Test jtViterbi passed."); } /* public void testMM() throws Exception { testQuery(); testTreeViterbi(); testTrpViterbiEquiv(); testTrpViterbiEquiv2(); testMaxMarginals(); } */ // xxx fails because of TRP termination // i.e., always succeeds if termination is IterationTermination (10) // but usually fails if termination is DefaultConvergenceTerminator (1e-12, 1000) // something about selection of random spanning trees??? 
/**
 * Checks, for every max-product inferencer returned by
 * {@code constructMaxProductInferencers()}, that the argmax of each
 * variable's max-marginal matches the argmax obtained by maximizing the
 * brute-force joint. TRP inferencers are seeded with 42 so runs are repeatable.
 */
public void testMaxMarginals() throws Exception {
  for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
    FactorGraph mdl = models[mdlIdx];
    BruteForceInferencer brute = new BruteForceInferencer();
    Factor joint = brute.joint(mdl);

    Inferencer[] algs = constructMaxProductInferencers ();
    for (int infIdx = 0; infIdx < algs.length; infIdx++) {
      Inferencer inf = algs[infIdx];
      if (inf instanceof TRP) {
        ((TRP) inf).setRandomSeed(42);
      }
      inf.computeMarginals(mdl);

      for (Iterator it = mdl.variablesIterator (); it.hasNext();) {
        Variable var = (Variable) it.next();
        Factor maxPot = inf.lookupMarginal(var);
        Factor trueMaxPot = joint.extractMax(var);
        if (maxPot.argmax() != trueMaxPot.argmax()) {
          logger.warning("Argmax not equal on model " + mdlIdx + " inferencer " + inf + " !\n Factors:\nReturned: " + maxPot + "\nTrue: " + trueMaxPot);
          System.err.println("Dump of model " + mdlIdx + " ***");
          mdl.dump ();
          // The previous assertTrue(argmax == argmax) here could only ever be false;
          // fail explicitly with a message identifying the mismatch instead.
          fail("Argmax not equal on model " + mdlIdx + " inferencer " + inf + " variable " + var);
        }
      }
    }
  }
  logger.info("Test maxMarginals passed.");
}

/**
 * Compares TreeBP marginals on each test tree against the precomputed
 * {@code treeMargs}, with tolerance 0.011. On a mismatch, prints the
 * diagnostic message, the model and the BP messages, then fails.
 */
public void testBeliefPropagation() {
  for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
    FactorGraph mdl = trees[mdlIdx];
    Inferencer prop = new TreeBP ();
    prop.computeMarginals(mdl);
    for (Iterator it = mdl.variablesIterator (); it.hasNext();) {
      Variable var = (Variable) it.next();
      Factor marg1 = treeMargs[mdlIdx][mdl.getIndex(var)];
      Factor marg2 = prop.lookupMarginal(var);
      // Only build the (expensive) diagnostic message when the check actually fails.
      if (!marg1.almostEquals(marg2, 0.011)) {
        String msg = "Test failed on graph " + mdlIdx + " vertex " + var + "\n" + "Model: " + mdl + "\nExpected: " + marg1.dumpToString () + "\nActual: " + marg2.dumpToString ();
        System.out.println (msg);
        System.out.println("*******************************************\nMODEL:\n");
        mdl.dump ();
        System.out.println("*******************************************\nMESSAGES:\n");
        ((AbstractBeliefPropagation) prop).dump();
        fail (msg);
      }
    }
  }
  logger.info("Test beliefPropagation passed.");
}

/**
 * Checks that TreeBP's joint probability matches brute force (to within 1e-15)
 * for every assignment of every test tree.
 */
public void testBpJoint () {
  for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
    FactorGraph mdl = trees[mdlIdx];
    Inferencer bp = new TreeBP ();
    BruteForceInferencer brute = new BruteForceInferencer ();
    brute.computeMarginals (mdl);
    bp.computeMarginals (mdl);
    for (AssignmentIterator it = mdl.assignmentIterator (); it.hasNext();) {
      Assignment assn = (Assignment) it.next ();
      assertEquals (brute.lookupJoint (assn), bp.lookupJoint (assn), 1e-15);
    }
  }
}

/**
 * Compares junction-tree marginals against brute force on a small directed model.
 * Eventually this should be folded into testMarginals, testJoint, etc.
 */
public void testDirectedJt () {
  DirectedModel bn = createDirectedModel ();
  BruteForceInferencer brute = new BruteForceInferencer ();
  brute.computeMarginals (bn);
  JunctionTreeInferencer jt = new JunctionTreeInferencer ();
  jt.computeMarginals (bn);
  compareMarginals ("Error comparing junction tree to brute on directed model!", bn, brute, jt);
}

/**
 * Builds a three-variable directed model with binary variables.
 * vars[0] and vars[1] each get a CPT drawn from a symmetric Dirichlet(1.0).
 * vars[2] gets a CPT over all three variables, with one Dirichlet draw per
 * configuration of vars[0] and vars[1].
 * The random source is seeded (13413), so the model is deterministic.
 */
private DirectedModel createDirectedModel () {
  int NUM_OUTCOMES = 2;
  cc.mallet.util.Randoms random = new cc.mallet.util.Randoms (13413);
  Dirichlet dirichlet = new Dirichlet (NUM_OUTCOMES, 1.0);
  double[] pA = dirichlet.randomVector (random);
  double[] pB = dirichlet.randomVector (random);

  TDoubleArrayList pC = new TDoubleArrayList (NUM_OUTCOMES * NUM_OUTCOMES * NUM_OUTCOMES);
  for (int i = 0; i < (NUM_OUTCOMES * NUM_OUTCOMES); i++) {
    pC.add (dirichlet.randomVector (random));
  }

  Variable[] vars = new Variable[] { new Variable (NUM_OUTCOMES), new Variable (NUM_OUTCOMES), new Variable (NUM_OUTCOMES) };
  DirectedModel mdl = new DirectedModel ();
  mdl.addFactor (new CPT (new TableFactor (vars[0], pA), vars[0]));
  mdl.addFactor (new CPT (new TableFactor (vars[1], pB), vars[1]));
  mdl.addFactor (new CPT (new TableFactor (vars, pC.toNativeArray ()), vars[2]));

  return mdl;
}

/**
 * Asserts that two inferencers agree, within 1e-5, on the marginal of every
 * variable in {@code fg}.
 *
 * @param msg  prefix for the failure message
 * @param fg   model whose variables are compared
 * @param inf1 first inferencer (must already have computed marginals on fg)
 * @param inf2 second inferencer (must already have computed marginals on fg)
 */
private void compareMarginals (String msg, FactorGraph fg, Inferencer inf1, Inferencer inf2) {
  for (int i = 0; i < fg.numVariables (); i++) {
    Variable var = fg.get (i);
    Factor ptl1 = inf1.lookupMarginal (var);
    Factor ptl2 = inf2.lookupMarginal (var);
    assertTrue (msg + "\n" + ptl1.dumpToString () + "\n" + ptl2.dumpToString (), ptl1.almostEquals (ptl2, 1e-5));
  }
}

/** Builds the shared test models, test trees and their reference marginals before each test. */
protected void setUp() {
  modelsList = createTestModels();
  createTestTrees();
  models = (UndirectedModel[]) modelsList.toArray (new UndirectedModel[]{});
  computeTestTreeMargs();
}

/** Checks that multiplying by a factor over no variables leaves the other factor unchanged. */
public void testMultiply() {
  TableFactor p1 = new TableFactor (new Variable[]{});
  System.out.println(p1);

  Variable[] vars = new Variable[]{
    new Variable(2),
    new Variable(2),
  };
  double[] probs = new double[]{1, 3, 5, 6};
  TableFactor p2 = new TableFactor (vars, probs);

  Factor p3 = p1.multiply(p2);
  assertTrue("Should be equal: " + p2 + "\n" + p3, p2.almostEquals(p3));
}

/* TODO: Not sure how to test this anymore.
// Test multiplication of potentials where variables are in
//  a different order
public void testMultiplication2 ()
{
  Variable[] vars = new Variable[] {
    new Variable (2),
    new Variable (2),
  };
  double[] probs1 = new double[] { 2, 4, 1, 6 };
  double[] probs2a = new double[] { 3, 7, 6, 5 };
  double[] probs2b = new double[] { 3, 6, 7, 5 };

  MultinomialPotential ptl1a = new MultinomialPotential (vars, probs1);
  MultinomialPotential ptl1b = new MultinomialPotential (vars, probs1);
  MultinomialPotential ptl2a = new MultinomialPotential (vars, probs2a);

  Variable[] vars2 = new Variable[] { vars[1], vars[0], };
  MultinomialPotential ptl2b = new MultinomialPotential (vars2, probs2b);

  ptl1a.multiplyBy (ptl2a);
  ptl1b.multiplyBy (ptl2b);

  assertTrue (ptl1a.almostEquals (ptl1b));
}
*/

/**
 * Checks that marginalizing in log space agrees with marginalizing the plain
 * table and then converting to log space, over ten random edge potentials.
 */
public void testLogMarginalize () {
  FactorGraph mdl = models [0];
  Iterator it = mdl.variablesIterator ();
  Variable v1 = (Variable) it.next();
  Variable v2 = (Variable) it.next();

  Random rand = new Random (3214123);
  for (int i = 0; i < 10; i++) {
    Factor ptl = randomEdgePotential (rand, v1, v2);

    Factor logmarg1 = new LogTableFactor ((AbstractTableFactor) ptl).marginalize(v1);
    Factor marglog1 = new LogTableFactor((AbstractTableFactor) ptl.marginalize(v1));
    assertTrue ("LogMarg failed! Correct: "+marglog1+" Log-marg: "+logmarg1, logmarg1.almostEquals (marglog1));

    Factor logmarg2 = new LogTableFactor ((AbstractTableFactor) ptl).marginalize(v2);
    Factor marglog2 = new LogTableFactor((AbstractTableFactor) ptl.marginalize(v2));
    assertTrue ("LogMarg failed! Correct: "+marglog2+" Log-marg: "+logmarg2, logmarg2.almostEquals (marglog2));
  }
}

/** Checks that normalizing a LogTableFactor agrees with normalizing the plain table. */
public void testLogNormalize () {
  FactorGraph mdl = models [0];
  Iterator it = mdl.variablesIterator ();
  Variable v1 = (Variable) it.next();
  Variable v2 = (Variable) it.next();

  Random rand = new Random (3214123);
  for (int i = 0; i < 10; i++) {
    Factor ptl = randomEdgePotential (rand, v1, v2);

    Factor norm1 = new LogTableFactor((AbstractTableFactor) ptl);
    Factor norm2 = ptl.duplicate();
    norm1.normalize();
    norm2.normalize();
    assertTrue ("LogNormalize failed! Correct: "+norm2+" Log-normed: "+norm1, norm1.almostEquals (norm2));
  }
}

/** Checks Maths.sumLogProb(log a, log b) against log(a + b) for random a, b. */
public void testSumLogProb () {
  java.util.Random rand = new java.util.Random (3214123);
  for (int i = 0; i < 10; i++) {
    double v1 = rand.nextDouble();
    double v2 = rand.nextDouble();
    double sum1 = Math.log (v1 + v2);
    double sum2 = Maths.sumLogProb (Math.log(v1), Math.log (v2));
    assertEquals (sum1, sum2, 0.00001);
  }
}

/**
 * Runs TreeBP on a chain that contains a factor with zero entries.
 * There is currently no live assertion, so this only checks that inference
 * completes without throwing.
 */
public void testInfiniteCost() {
  Variable[] vars = new Variable[3];
  for (int i = 0; i < vars.length; i++) {
    vars[i] = new Variable (2);
  }
  FactorGraph mdl = new FactorGraph (vars);

  mdl.addFactor (vars[0], vars[1], new double[] { 2, 6, 4, 8 });
  mdl.addFactor (vars[1], vars[2], new double[] { 1, 0, 0, 1 });

  mdl.dump ();

  Inferencer bp = new TreeBP ();
  bp.computeMarginals (mdl);

  // Below should be true, except the potentials have different ranges.
  //assertTrue (bp.lookupMarginal(vars[1]).almostEquals (bp.lookupMarginal(vars[2])));
}

/**
 * Runs the junction tree inferencer twice over all models: first with caches
 * cleared, then again. Checks that the second pass reproduces the first pass's
 * marginals, and logs both timings.
 */
public void testJtCaching() {
  // clear all caches
  for (int i = 0; i < models.length; i++) {
    FactorGraph model = models[i];
    model.setInferenceCache (JunctionTreeInferencer.class, null);
  }

  Factor[][] margs = new Factor[models.length][];

  long stime1 = System.currentTimeMillis ();
  for (int i = 0; i < models.length; i++) {
    FactorGraph model = models[i];
    JunctionTreeInferencer inf = new JunctionTreeInferencer();
    inf.computeMarginals(model);

    margs[i] = new Factor[model.numVariables ()];
    Iterator it = model.variablesIterator ();
    int j = -1;
    while (it.hasNext()) {
      Variable var = (Variable) it.next();
      j++;
      margs[i][j] = inf.lookupMarginal(var);
    }
  }
  long etime1 = System.currentTimeMillis ();
  long diff1 = etime1 - stime1;
  logger.info ("Pre-cache took "+diff1+" ms.");

  long stime2 = System.currentTimeMillis ();
  for (int i = 0; i < models.length; i++) {
    FactorGraph model = models[i];
    JunctionTreeInferencer inf = new JunctionTreeInferencer();
    inf.computeMarginals(model);

    Iterator it = model.variablesIterator ();
    int j = -1;
    while (it.hasNext()) {
      Variable var = (Variable) it.next();
      j++;
      assertTrue (margs[i][j].almostEquals (inf.lookupMarginal (var)));
    }
  }
  long etime2 = System.currentTimeMillis ();
  long diff2 = etime2 - stime2;
  logger.info ("Post-cache took "+diff2+" ms.");

  // assertTrue (diff2 < diff1);
}

/** Checks that findVariable returns the exact Variable for each label, and null for an unknown one. */
public void testFindVariable () {
  FactorGraph mdl = models [0];
  Iterator it = mdl.variablesIterator ();
  while (it.hasNext()) {
    Variable var = (Variable) it.next();
    // A distinct String instance, so the lookup cannot succeed merely by reference identity.
    String name = new String (var.getLabel());
    assertSame (var, mdl.findVariable (name));
  }

  assertNull (mdl.findVariable ("xsdfasdf"));
}

/**
 * Checks that lookupMarginal(VarSet) on a single-variable set falls back to
 * lookupMarginal(Variable), and that a three-variable set that is not a
 * factor throws UnsupportedOperationException.
 */
public void testDefaultLookupMarginal () {
  Inferencer inf = new TreeBP ();
  FactorGraph mdl = trees[JT_CHAIN_TEST_TREE];
  Variable var = mdl.get (0);

  inf.computeMarginals (mdl);

  // Previously: UnsupportedOperationException
  // Expected: default to lookupMarginal (Variable) for clique of size 1
  VarSet varSet = new HashVarSet (new Variable[] { var });
  Factor ptl1 = inf.lookupMarginal (varSet);
  Factor ptl2 = inf.lookupMarginal (var);
  assertTrue (ptl1.almostEquals (ptl2));

  Variable var2 = mdl.get (1);
  Variable var3 = mdl.get (2);
  VarSet c2 = new HashVarSet (new Variable[] { var, var2, var3 });
  try {
    inf.lookupMarginal (c2);
    fail ("Expected an UnsupportedOperationException with clique "+c2);
  } catch (UnsupportedOperationException expected) {
    // This is the outcome under test.
  }
}

/**
 * Runs LoopyBP on a model of independent single-variable factors.
 * Checks that each marginal equals its normalized factor and that each joint
 * equals the product of those marginals.
 * Eventually this should be moved to models[], but TRP currently chokes on
 * disconnected models.
 */
public void testDisconnectedModel () {
  Variable[] vars = new Variable [4];
  for (int i = 0; i < vars.length; i++) {
    vars [i] = new Variable (2);
  }
  FactorGraph mdl = new UndirectedModel (vars);

  Random r = new Random (67);
  Factor[] ptls = new Factor [vars.length];
  Factor[] normed = new Factor [vars.length];
  for (int i = 0; i < vars.length; i++) {
    ptls[i] = randomNodePotential (r, vars[i]);
    normed[i] = ptls[i].duplicate();
    normed[i].normalize();
    mdl.addFactor (ptls[i]);
  }

  mdl.dump ();

  Inferencer inf = new LoopyBP ();
  inf.computeMarginals (mdl);

  for (int i = 0; i < vars.length; i++) {
    Factor marg = inf.lookupMarginal (vars[i]);
    assertTrue ("Marginals not equal!\n True: "+normed[i]+"\n Returned "+marg, marg.almostEquals (normed[i]));
  }

  for (AssignmentIterator it = mdl.assignmentIterator (); it.hasNext();) {
    Assignment assn = (Assignment) it.next ();
    double trueProb = 1.0;
    for (int i = 0; i < vars.length; i++) {
      trueProb *= normed[i].value (assn);
    }
    assertEquals (trueProb, inf.lookupJoint (assn), 1e-5);
  }
}

/**
 * Timing helper, not run as a test (the name lacks the "test" prefix).
 * Logs how long 1000 rounds of marginalizing a random edge potential take,
 * for 2-outcome and for 45-outcome variables.
 */
public void timeMarginalization () {
  java.util.Random r = new java.util.Random (7732847);
  Variable[] vars = new Variable[] { new Variable (2), new Variable (2), };
  TableFactor ptl = randomEdgePotential (r, vars[0], vars[1]);

  long stime = System.currentTimeMillis ();
  for (int i = 0; i < 1000; i++) {
    ptl.marginalize (vars[0]);
    ptl.marginalize (vars[1]);
  }
  long etime = System.currentTimeMillis ();
  logger.info ("Marginalization (2-outcome) took "+(etime-stime)+" ms.");

  Variable[] vars45 = new Variable[] { new Variable (45), new Variable (45), };
  TableFactor ptl45 = randomEdgePotential (r, vars45[0], vars45[1]);

  stime = System.currentTimeMillis();
  for (int i = 0; i < 1000; i++) {
    ptl45.marginalize (vars45[0]);
    ptl45.marginalize (vars45[1]);
  }
  etime = System.currentTimeMillis();
  logger.info ("Marginalization (45-outcome) took "+(etime-stime)+" ms.");
}

/** Profiling helper: runs junction tree inference and looks up every marginal on all models. */
public void runJunctionTree () {
  for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
    FactorGraph model = models[mdlIdx];
    JunctionTreeInferencer inf = new JunctionTreeInferencer();
    inf.computeMarginals(model);

    Iterator it = model.variablesIterator ();
    while (it.hasNext()) {
      Variable var = (Variable) it.next();
      inf.lookupMarginal (var);
    }
  }
}

/** Checks that Assignment.setValue changes only the targeted variable's value. */
public void testDestructiveAssignment () {
  Variable[] vars = { new Variable(2), new Variable (2), };
  Assignment assn = new Assignment (vars, new int[] { 0, 1 });
  assertEquals (0, assn.get (vars[0]));
  assertEquals (1, assn.get (vars[1]));

  assn.setValue (vars[0], 1);
  assertEquals (1, assn.get (vars[0]));
  assertEquals (1, assn.get (vars[1]));
}

/** Checks that LoopyBP on a random 5x5 binary grid uses more than 8 iterations. */
public void testLoopyConvergence () {
  Random r = new Random (67);
  FactorGraph mdl = createRandomGrid (5, 5, 2, r);
  LoopyBP loopy = new LoopyBP ();
  loopy.computeMarginals (mdl);
  assertTrue (loopy.iterationsUsed() > 8);
}

/** Checks TRP on a one-variable graph: factor {1, 2} should normalize to {1/3, 2/3}. */
public void testSingletonGraph () {
  Variable v = new Variable (2);
  FactorGraph mdl = new FactorGraph (new Variable[] { v });
  mdl.addFactor (new TableFactor (v, new double[] { 1, 2 }));

  TRP trp = new TRP ();
  trp.computeMarginals (mdl);

  Factor ptl = trp.lookupMarginal (v);
  double[] dbl = ((AbstractTableFactor) ptl).toValueArray ();
  assertEquals (2, dbl.length);
  assertEquals (0.33333, dbl[0], 1e-4);
  assertEquals (0.66666, dbl[1], 1e-4);
}

/**
 * With caching enabled, checks that rerunning LoopyBP on a model it has seen
 * before gives the same marginal and takes exactly one iteration, even after
 * running on a different model in between.
 */
public void testLoopyCaching () {
  FactorGraph mdl1 = models[4];
  FactorGraph mdl2 = models[5];

  Variable var = mdl1.get (0);
  LoopyBP inferencer = new LoopyBP ();
  inferencer.setUseCaching (true);

  inferencer.computeMarginals (mdl1);
  Factor origPtl = inferencer.lookupMarginal (var);
  assertTrue (2 < inferencer.iterationsUsed ());

  // confuse the inferencer
  inferencer.computeMarginals (mdl2);

  // make sure we have cached, correct results
  inferencer.computeMarginals (mdl1);
  Factor sndPtl = inferencer.lookupMarginal (var);

  // The tolerance cannot be tighter than the convergence criterion, hence 1e-4.
  assertTrue ("Huh? Original potential:"+origPtl+"After: "+sndPtl, origPtl.almostEquals (sndPtl, 1e-4));
  assertEquals (1, inferencer.iterationsUsed ());
}

/**
 * Walks the junction tree breadth-first from its root. Checks that the number
 * of clusters reached equals the number of cluster potentials.
 */
public void testJunctionTreeConnectedFromRoot () {
  JunctionTreeInferencer jti = new JunctionTreeInferencer ();
  jti.computeMarginals (models[0]);
  jti.computeMarginals (models[1]);
  JunctionTree jt = jti.lookupJunctionTree ();

  List reached = new ArrayList ();
  LinkedList queue = new LinkedList ();
  queue.add (jt.getRoot ());
  while (!queue.isEmpty ()) {
    VarSet current = (VarSet) queue.removeFirst ();
    queue.addAll (jt.getChildren (current));
    reached.add (current);
  }

  assertEquals (jt.clusterPotentials ().size (), reached.size());
}

/**
 * Runs LoopyBP on a long uniform chain and logs timing. If the run exhausts
 * memory, prints the number of messages sent before rethrowing.
 */
public void testBpLargeModels () {
  Timing timing = new Timing ();
//    UndirectedModel mdl = RandomGraphs.createUniformChain (800);
  // NOTE(review): 8196 differs from the 8192 used in testTrpLargeModels; confirm whether intentional.
  FactorGraph mdl = RandomGraphs.createUniformChain (8196);
  timing.tick ("Model creation");

  AbstractBeliefPropagation inf = new LoopyBP ();
  try {
    inf.computeMarginals (mdl);
  } catch (OutOfMemoryError e) {
    System.out.println ("OUT OF MEMORY: Messages sent "+inf.getTotalMessagesSent ());
    throw e;
  }

  timing.tick ("Inference time (Random sched BP)");
}

/** Runs TRP on a long uniform chain and logs timing. */
public void testTrpLargeModels () {
  Timing timing = new Timing ();
//    UndirectedModel mdl = RandomGraphs.createUniformChain (800);
  FactorGraph mdl = RandomGraphs.createUniformChain (8192);
  timing.tick ("Model creation");

  Inferencer inf = new TRP ();
  inf.computeMarginals (mdl);

  timing.tick ("Inference time (TRP)");
}

/*
public void testBpDualEdgeFactor ()
{
  Variable[] vars = new Variable[] {
    new Variable (2),
    new Variable (2),
    new Variable (2),
    new Variable (2),
  };

  Random r = new Random ();
  Factor tbl1 = createEdgePtl (vars[0], vars[1], r);
  Factor tbl2a = createEdgePtl (vars[1], vars[2], r);
  Factor tbl2b = createEdgePtl (vars[1], vars[2], r);
  Factor tbl3 = createEdgePtl (vars[2], vars[3], r);

  FactorGraph fg = new FactorGraph (vars);
  fg.addFactor (tbl1);
  fg.addFactor (tbl2a);
  fg.addFactor (tbl2b);
  fg.addFactor (tbl3);

  Inferencer inf = new TRP ();
  inf.computeMarginals (fg);

  VarSet vs = tbl2a.varSet ();
  Factor marg1 = inf.lookupMarginal (vs);

  Factor prod = TableFactor.multiplyAll (fg.factors ());
  Factor marg2 = prod.marginalize (vs);
  marg2.normalize ();

  assertTrue ("Factors not equal! BP: "+marg1.dumpToString ()+"\n EXACT: "+marg2.dumpToString (), marg1.almostEquals (marg2));
}
*/

/** Builds a 2x2-entry TableFactor over (var1, var2) with values drawn from {@code r.nextDouble()}. */
private Factor createEdgePtl (Variable var1, Variable var2, Random r) {
  double[] dbls = new double [4];
  for (int i = 0; i < dbls.length; i++) {
    dbls[i] = r.nextDouble ();
  }
  return new TableFactor (new Variable[] { var1, var2 }, dbls);
}

// Model-file text for a 2x2 grid with continuous parameters alpha and u, read by testJtConstant.
private String gridStr =
  "VAR alpha u : CONTINUOUS\n" +
  "alpha ~ Uniform -1.0 1.0\n" +
  "u ~ Uniform -2.0 2.0\n" +
  "x00 ~ Unary u\n" +
  "x10 ~ Unary u\n" +
  "x01 ~ Unary u\n" +
  "x11 ~ Unary u\n" +
  "x00 x01 ~ Potts alpha\n" +
  "x00 x10 ~ Potts alpha\n" +
  "x01 x11 ~ Potts alpha\n" +
  "x10 x11 ~ Potts alpha\n";

/**
 * Reads {@code gridStr}, samples its continuous variables with a fixed seed,
 * slices the model on that sample, and runs junction tree inference.
 * There is no assertion; this checks only that no exception is thrown.
 */
public void testJtConstant () throws IOException {
  FactorGraph masterFg = new ModelReader ().readModel (new BufferedReader (new StringReader (gridStr)));
  JunctionTreeInferencer jt = new JunctionTreeInferencer ();
  Assignment assn = masterFg.sampleContinuousVars (new cc.mallet.util.Randoms (3214));
  FactorGraph fg = (FactorGraph) masterFg.slice (assn);
  jt.computeMarginals (fg);
}

/** Returns a suite containing every test method in this class. */
public static Test suite() {
  return new TestSuite(TestInference.class);
}

/**
 * Runs either the named test methods given as arguments or, when no arguments
 * are given, the whole suite, using the JUnit text runner.
 */
public static void main(String[] args) throws Exception {
  TestSuite theSuite;
  if (args.length > 0) {
    theSuite = new TestSuite();
    for (int i = 0; i < args.length; i++) {
      theSuite.addTest(new TestInference(args[i]));
    }
  } else {
    theSuite = (TestSuite) suite();
  }

  junit.textui.TestRunner.run(theSuite);
}

}