/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;
import junit.framework.*;
import junit.textui.TestRunner;
import java.util.Random;
import java.util.Iterator;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.io.IOException;
import cc.mallet.grmm.inference.*;
import cc.mallet.grmm.types.*;
/**
* Created: Mar 26, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: TestRandomGraphs.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class TestRandomGraphs extends TestCase {
public TestRandomGraphs (String name)
{
super (name);
}
public static Test suite ()
{
return new TestSuite (TestRandomGraphs.class);
}
public void testAttractiveGraphs () throws IOException
{
Random r = new Random (31421);
for (int rep = 0; rep < 5; rep++) {
FactorGraph mdl = RandomGraphs.randomAttractiveGrid (5, 0.5, r);
System.out.println ("************"); mdl.dump ();
TRP trp = TRP.createForMaxProduct ();
trp.computeMarginals (mdl);
Assignment assn = trp.bestAssignment ();
PrintWriter out = new PrintWriter (new FileWriter (new File ("attract."+rep+".dot")));
mdl.printAsDot (out, assn);
out.close ();
}
}
public void testRepulsiveGraphs () throws IOException
{
Random r = new Random (31421);
for (int rep = 0; rep < 5; rep++) {
FactorGraph mdl = RandomGraphs.randomRepulsiveGrid (5, 0.5, r);
TRP trp = TRP.createForMaxProduct ();
trp.computeMarginals (mdl);
Assignment assn = trp.bestAssignment ();
PrintWriter out = new PrintWriter (new FileWriter (new File ("repulse."+rep+".dot")));
mdl.printAsDot (out, assn);
out.close ();
}
}
public void testFrustratedGraphs () throws IOException
{
Random r = new Random (31421);
for (int rep = 0; rep < 5; rep++) {
FactorGraph mdl = RandomGraphs.randomFrustratedGrid (5, 0.5, r);
TRP trp = TRP.createForMaxProduct ();
trp.computeMarginals (mdl);
Assignment assn = trp.bestAssignment ();
PrintWriter out = new PrintWriter (new FileWriter (new File ("mixed."+rep+".dot")));
mdl.printAsDot (out, assn);
out.close ();
}
}
public void testFrustratedIsGrid () throws IOException
{
Random r = new Random (0);
for (int rep = 0; rep < 100; rep++) {
FactorGraph mdl = RandomGraphs.randomFrustratedGrid (10, 1.0, r);
// 100 variable factors + 180 edge factors
assertEquals (280, mdl.factors ().size ());
assertEquals (100, mdl.numVariables ());
int[] counts = new int [6];
for (int i = 0; i < mdl.numVariables (); i++) {
Variable var = mdl.get (i);
int degree = mdl.getDegree (var);
assertTrue ("Variable "+var+" has degree "+degree, (degree >= 3) && (degree <= 5));
counts[degree]++;
}
assertEquals (counts[0], 0);
assertEquals (counts[1], 0);
assertEquals (counts[2], 0);
assertEquals (counts[3], 4);
assertEquals (counts[4], 32);
assertEquals (counts[5], 64);
}
}
public void testUniformGrid ()
{
UndirectedGrid grid = (UndirectedGrid) RandomGraphs.createUniformGrid (3);
assertEquals (9, grid.numVariables ());
assertEquals (12, grid.factors ().size());
BruteForceInferencer inf = new BruteForceInferencer ();
TableFactor joint = (TableFactor) inf.joint (grid);
for (AssignmentIterator it = joint.assignmentIterator (); it.hasNext(); it.advance ()) {
assertEquals (-9 * Math.log (2), joint.logValue (it), 1e-3);
}
}
public void testUniformGridWithObservations ()
{
FactorGraph grid = RandomGraphs.createGridWithObs (
new RandomGraphs.UniformFactorGenerator (),
new RandomGraphs.UniformFactorGenerator (),
3);
assertEquals (18, grid.numVariables ());
assertEquals (12 + 9, grid.factors ().size());
Inferencer inf = new LoopyBP ();
inf.computeMarginals (grid);
for (Iterator it = grid.variablesIterator (); it.hasNext ();) {
Variable var = (Variable) it.next ();
Factor marg = inf.lookupMarginal (var);
for (AssignmentIterator assnIt = marg.assignmentIterator (); assnIt.hasNext();) {
assertEquals (-Math.log (2), marg.logValue (assnIt), 1e-3);
assnIt.advance ();
}
}
}
public static void main (String[] args) throws Throwable
{
TestSuite theSuite;
if (args.length > 0) {
theSuite = new TestSuite ();
for (int i = 0; i < args.length; i++) {
theSuite.addTest (new TestRandomGraphs (args[i]));
}
} else {
theSuite = (TestSuite) suite ();
}
TestRunner.run (theSuite);
}
}