/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.test; import junit.framework.TestCase; import junit.framework.Test; import junit.framework.TestSuite; import java.io.BufferedReader; import java.io.StringReader; import java.io.IOException; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Arrays; import cc.mallet.grmm.inference.RandomGraphs; import cc.mallet.grmm.types.*; import cc.mallet.grmm.util.ModelReader; import cc.mallet.types.MatrixOps; import cc.mallet.util.Randoms; import cc.mallet.util.Timing; /** * Created: Mar 17, 2005 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: TestFactorGraph.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $ */ public class TestFactorGraph extends TestCase { private Variable[] vars; private TableFactor tbl1; private TableFactor tbl2; private TableFactor tbl3; private LogTableFactor ltbl1; private LogTableFactor ltbl2; public TestFactorGraph (String name) { super (name); } protected void setUp () throws Exception { vars = new Variable[] { new Variable (2), new Variable (2), new Variable (2), new Variable (2), }; tbl1 = new TableFactor (new Variable[] { vars[0], vars[1] }, new double[] { 0.8, 0.1, 0.1, 0.8 }); tbl2 = new TableFactor (new Variable[] { vars[1], vars[2] }, new double[] { 0.2, 0.7, 0.8, 0.2 }); tbl3 = new TableFactor (new Variable[] { vars[2], vars[3] }, new double[] { 0.2, 0.4, 0.6, 0.4 }); ltbl1 = LogTableFactor.makeFromValues (new Variable[] { vars[0], vars[1] }, new double[] { 0.8, 0.1, 0.1, 0.8 }); ltbl2 = LogTableFactor.makeFromValues (new Variable[] { vars[1], vars[2] }, new double[] { 0.2, 0.7, 0.8, 0.2 }); } public void testMultiplyBy () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); assertEquals (2, fg.factors ().size()); assertTrue (fg.factors ().contains (tbl1)); assertTrue (fg.factors ().contains (tbl2)); assertEquals (3, fg.numVariables ()); assertTrue (fg.variablesSet ().contains (vars[0])); assertTrue (fg.variablesSet ().contains (vars[1])); assertTrue (fg.variablesSet ().contains (vars[2])); } public void testNumVariables () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); assertEquals (3, fg.numVariables ()); } public void testMultiply () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); FactorGraph fg2 = (FactorGraph) fg.multiply (tbl3); assertEquals (2, fg.factors ().size()); assertEquals (3, fg2.factors ().size()); assertTrue (!fg.factors ().contains (tbl3)); assertTrue (fg2.factors ().contains (tbl3)); } public void testValue () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int[] { 0, 1, 0 }); assertEquals (0.08, fg.value (assn), 1e-5); } public void testMarginalize () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); Factor marg = fg.marginalize (vars[1]); Factor expected = new TableFactor (vars[1], new double[] { 0.81, 0.9 }); assertTrue (expected.almostEquals (marg)); } public void testSum () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); assertEquals (1.71, fg.sum (), 1e-5); } public void testNormalize () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); fg.normalize (); assertEquals (1.0, fg.sum(), 1e-5); } public void testLogNormalize () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (ltbl1); fg.multiplyBy (ltbl2); fg.normalize (); assertEquals (1.0, fg.sum(), 1e-5); } public void testEmbeddedFactorGraph () { FactorGraph embeddedFg = new FactorGraph (); embeddedFg.multiplyBy (tbl1); embeddedFg.multiplyBy (tbl2); FactorGraph fg = new FactorGraph (); fg.multiplyBy (embeddedFg); fg.multiplyBy (tbl3); assertEquals (4, fg.varSet ().size ()); assertEquals (2, fg.factors ().size ()); Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [4]); assertEquals (0.032, fg.value (assn), 1e-5); AbstractTableFactor tbl = fg.asTable (); assertEquals (4, tbl.varSet ().size ()); assertEquals (0.032, tbl.value (assn), 1e-5); } public void testAsTable () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); AbstractTableFactor actual = fg.asTable (); AbstractTableFactor expected = (AbstractTableFactor) tbl1.multiply (tbl2); assertTrue (expected.almostEquals (actual)); } public void testTableTimesFg () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); Factor product = tbl3.multiply (fg); assertTrue (product instanceof AbstractTableFactor); assertEquals (4, product.varSet ().size ()); Assignment assn = new Assignment (product.varSet ().toVariableArray (), new int [4]); assertEquals (0.032, product.value (assn), 1e-5); } public void testLogTableTimesFg () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); Factor product = ltbl1.multiply (fg); assertTrue (product instanceof AbstractTableFactor); assertEquals (3, product.varSet ().size ()); Assignment assn = new Assignment (product.varSet ().toVariableArray (), new int [3]); assertEquals (0.128, product.value (assn), 1e-5); } public void testRemove () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); assertEquals (2, fg.getDegree (vars[1])); fg.divideBy (tbl1); assertEquals (2, fg.varSet ().size ()); Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [2]); assertEquals (0.2, fg.value (assn), 1e-5); int nvs = 0; for (Iterator it = fg.varSetIterator (); it.hasNext(); it.next ()) { nvs++; } assertEquals (1, nvs); assertEquals (1, fg.getDegree (vars[1])); assertTrue (fg.get (0) != fg.get (1)); assertEquals (vars[1], fg.get (0)); assertEquals (vars[2], fg.get (1)); } public void testRedundantDomains () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); fg.multiplyBy (ltbl1); assertEquals (3, fg.varSet ().size ()); assertEquals ("Wrong factors in FG, was "+fg.dumpToString (), 3, fg.factors ().size ()); Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [3]); assertEquals (0.128, fg.value (assn), 1e-5); } private static String uniformMdlstr = "VAR sigma u1 u2 : continuous\n" + "VAR x1 x2 : 2\n" + "sigma ~ Uniform -0.5 0.5\n" + "u1 ~ Uniform -0.5 0.5\n" + "u2 ~ Uniform -0.5 0.5\n" + "x1 x2 ~ BinaryPair sigma\n" + "x1 ~ Unary u1\n" + "x2 ~ Unary u2\n"; public void testContinousSample () throws IOException { ModelReader reader = new ModelReader (); FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr))); Randoms r = new Randoms (324143); Assignment allAssn = new Assignment (); for (int i = 0; i < 10000; i++) { Assignment row = fg.sample (r); allAssn.addRow (row); } Variable x1 = fg.findVariable ("x1"); Assignment assn1 = (Assignment) allAssn.marginalize (x1); int[] col = assn1.getColumnInt (x1); double mean = MatrixOps.sum (col) / ((double)col.length); assertEquals (0.5, mean, 0.025); } private static String uniformMdlstr2 = "VAR sigma u1 u2 : continuous\n" + "VAR x1 x2 : 2\n" + "sigma ~ Normal 0.0 0.2\n" + "u1 ~ Normal 0.0 0.2\n" + "u2 ~ Normal 0.0 0.2\n" + "x1 x2 ~ BinaryPair sigma\n" + "x1 ~ Unary u1\n" + "x2 ~ Unary u2\n"; public void testContinousSample2 () throws IOException { ModelReader reader = new ModelReader (); FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr2))); Randoms r = new Randoms (324143); Assignment allAssn = new Assignment (); for (int i = 0; i < 10000; i++) { Assignment row = fg.sample (r); allAssn.addRow (row); } Variable x1 = fg.findVariable ("x2"); Assignment assn1 = (Assignment) allAssn.marginalize (x1); int[] col = assn1.getColumnInt (x1); double mean = MatrixOps.sum (col) / ((double)col.length); assertEquals (0.5, mean, 0.01); Variable x2 = fg.findVariable ("x2"); Assignment assn2 = (Assignment) allAssn.marginalize (x2); int[] col2 = assn2.getColumnInt (x2); double mean2 = MatrixOps.sum (col2) / ((double)col2.length); assertEquals (0.5, mean2, 0.025); } public void testAllFactorsOf () throws IOException { ModelReader reader = new ModelReader (); FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr2))); Variable var = new Variable (2); var.setLabel ("v0"); List lst = fg.allFactorsOf (var); assertEquals (0, lst.size ()); } public void testAllFactorsOf2 () throws IOException { Variable x1 = new Variable (2); Variable x2 = new Variable (2); FactorGraph fg = new FactorGraph (); fg.addFactor (new TableFactor (x1)); fg.addFactor (new TableFactor (x2)); fg.addFactor (new TableFactor (new Variable[] { x1, x2 })); List lst = fg.allFactorsOf (x1); assertEquals (1, lst.size ()); for (Iterator it = lst.iterator (); it.hasNext ();) { Factor f = (Factor) it.next (); assertEquals (1, f.varSet().size()); assertTrue (f.varSet ().contains (x1)); } HashVarSet vs = new HashVarSet (new Variable[]{x1, x2}); List lst2 = fg.allFactorsOf (vs); assertEquals (1, lst2.size ()); Factor f = (Factor) lst2.get (0); assertTrue (f.varSet ().equals (vs)); } public void testAsTable2 () { Factor f1 = new TableFactor (vars[0], new double[] { 0.6, 0.4 }); Factor f2 = new ConstantFactor (2.0); FactorGraph fg = new FactorGraph (new Factor[] { f1, f2 }); AbstractTableFactor tbl = fg.asTable (); assertTrue (Arrays.equals(new double[] { 0.6 * 2.0, 0.4 * 2.0 }, tbl.toValueArray ())); } public void testClear () { FactorGraph fg = new FactorGraph (); fg.multiplyBy (tbl1); fg.multiplyBy (tbl2); assertEquals (3, fg.numVariables ()); assertEquals (2, fg.factors ().size ()); fg.clear (); assertEquals (0, fg.numVariables ()); assertEquals (0, fg.factors ().size ()); for (int vi = 0; vi < tbl1.varSet ().size (); vi++) { assertTrue (!fg.containsVar (tbl1.getVariable (vi))); } for (int vi = 0; vi < tbl2.varSet ().size (); vi++) { assertTrue (!fg.containsVar (tbl2.getVariable (vi))); } } public void testCacheExpanding () { FactorGraph baseFg = RandomGraphs.randomFrustratedGrid (25, 1.0, new java.util.Random (3324879)); Assignment assn = new Assignment (baseFg, new int[baseFg.numVariables ()]); double val = baseFg.logValue (assn); Timing timing = new Timing (); int numReps = 100; for (int rep = 0; rep < numReps; rep++) { FactorGraph fg = new FactorGraph (baseFg.numVariables ()); for (int fi = 0; fi < baseFg.factors().size(); fi++) { fg.multiplyBy (baseFg.getFactor (fi)); } assertEquals (val, fg.logValue (assn), 1e-5); } long time1 = timing.elapsedTime (); timing.tick ("No-expansion time"); for (int rep = 0; rep < numReps; rep++) { FactorGraph fg = new FactorGraph (); for (int fi = 0; fi < baseFg.factors().size(); fi++) { fg.multiplyBy (baseFg.getFactor (fi)); } assertEquals (val, fg.logValue (assn), 1e-5); } long time2 = timing.elapsedTime (); timing.tick ("With-expansion time"); assertTrue (time1 < time2); } public static Test suite () { return new TestSuite (TestFactorGraph.class); } public static void main (String[] args) throws Throwable { TestSuite theSuite; if (args.length > 0) { theSuite = new TestSuite (); for (int i = 0; i < args.length; i++) { theSuite.addTest (new TestFactorGraph (args[i])); } } else { theSuite = (TestSuite) TestFactorGraph.suite (); } junit.textui.TestRunner.run (theSuite); } }