/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;
import junit.framework.TestCase;
import junit.framework.Test;
import junit.framework.TestSuite;
import java.io.IOException;
import cc.mallet.grmm.types.*;
import cc.mallet.types.tests.TestSerializable;
/**
* Created: Aug 11, 2004
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: TestAssignment.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class TestAssignment extends TestCase {
private Variable[] vars;
/**
* Constructs a test case with the given name.
*/
public TestAssignment (String name)
{
super (name);
}
protected void setUp () throws Exception
{
vars = new Variable[] {
new Variable (2),
new Variable (2),
};
}
public void testSimple ()
{
Assignment assn = new Assignment (vars, new int[] { 1, 0 });
assertEquals (1, assn.get (vars [0]));
assertEquals (0, assn.get (vars [1]));
assertEquals (new Integer (0), assn.getObject (vars[1]));
}
public void testScale ()
{
Assignment assn = new Assignment (vars, new int[] { 1, 0 });
assn.addRow (vars, new int[] { 1, 0 });
assn.addRow (vars, new int[] { 1, 1 });
Assignment assn2 = new Assignment (vars, new int[] { 1, 0 });
assn.normalize ();
assertEquals (0.666666, assn.value (assn2), 1e-5);
}
public void testScaleMarginalize ()
{
Assignment assn = new Assignment (vars, new int[] { 1, 0 });
assn.addRow (vars, new int[] { 1, 0 });
assn.addRow (vars, new int[] { 1, 1 });
assn.normalize ();
Factor mrg = assn.marginalize (vars[1]);
Assignment assn2 = new Assignment (vars[1], 0);
assertEquals (0.666666, mrg.value (assn2), 1e-5);
}
public void testSerialization () throws IOException, ClassNotFoundException
{
Assignment assn = new Assignment (vars, new int[] { 1, 0 });
Assignment assn2 = (Assignment) TestSerializable.cloneViaSerialization (assn);
assertEquals (2, assn2.numVariables ());
assertEquals (1, assn2.numRows ());
assertEquals (1, assn.get (vars [0]));
assertEquals (0, assn.get (vars [1]));
}
public void testMarginalize ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
Assignment assn2 = (Assignment) assn.marginalize (vars[0]);
assertEquals (2, assn2.numRows ());
assertEquals (1, assn2.size ());
assertEquals (vars[0], assn2.getVariable (0));
assertEquals (1, assn.get (0, vars[0]));
assertEquals (1, assn.get (1, vars[0]));
}
public void testMarginalizeOut ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
Assignment assn2 = (Assignment) assn.marginalizeOut (vars[1]);
assertEquals (2, assn2.numRows ());
assertEquals (1, assn2.size ());
assertEquals (vars[0], assn2.getVariable (0));
assertEquals (1, assn.get (0, vars[0]));
assertEquals (1, assn.get (1, vars[0]));
}
public void testUnion ()
{
Assignment assn1 = new Assignment ();
assn1.addRow (new Variable[] { vars[0] }, new int[] { 1 });
Assignment assn2 = new Assignment ();
assn2.addRow (new Variable[] { vars[1] }, new int[] { 0 });
Assignment assn3 = Assignment.union (assn1, assn2);
assertEquals (1, assn3.numRows ());
assertEquals (2, assn3.numVariables ());
assertEquals (1, assn3.get (0, vars[0]));
assertEquals (0, assn3.get (0, vars[1]));
}
public void testMultiRow ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
assertEquals (2, assn.numRows ());
assertEquals (1, assn.get (0, vars[1]));
assertEquals (0, assn.get (1, vars[1]));
try {
assn.get (vars[1]);
fail ();
} catch (IllegalArgumentException e) {}
}
public void testSetRow ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
assertEquals (1, assn.get (0, vars[0]));
assn.setRow (0, new int[] { 0, 0 });
assertEquals (2, assn.numRows ());
assertEquals (0, assn.get (0, vars[0]));
assertEquals (0, assn.get (0, vars[1]));
assertEquals (1, assn.get (1, vars[0]));
assertEquals (0, assn.get (1, vars[1]));
}
public void testSetRowFromAssn ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
assertEquals (1, assn.get (0, vars[0]));
Assignment assn2 = new Assignment ();
assn2.addRow (vars, new int[] { 0, 0 });
assn.setRow (0, assn2);
assertEquals (2, assn.numRows ());
assertEquals (0, assn.get (0, vars[0]));
assertEquals (0, assn.get (0, vars[1]));
assertEquals (1, assn.get (1, vars[0]));
assertEquals (0, assn.get (1, vars[1]));
}
public void testSetValue ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.setValue (vars[0], 0);
assertEquals (1, assn.numRows ());
assertEquals (0, assn.get (0, vars[0]));
assertEquals (1, assn.get (0, vars[1]));
}
public void testSetValueDup ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
Assignment dup = (Assignment) assn.duplicate ();
dup.setValue (vars[0], 0);
assertEquals (1, dup.numRows ());
assertEquals (0, dup.get (0, vars[0]));
assertEquals (1, dup.get (0, vars[1]));
}
public void testSetValueExpand ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 0, 0 });
Variable v3 = new Variable (2);
assn.setValue (v3, 1);
assertEquals (3, assn.size ());
assertEquals (0, assn.get (vars[0]));
assertEquals (0, assn.get (vars[1]));
assertEquals (1, assn.get (v3));
}
public void testAsTable ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
assn.addRow (vars, new int[] { 1, 0 });
AbstractTableFactor tbl = assn.asTable ();
TableFactor exp = new TableFactor (vars, new double[] { 0, 0, 2, 1 });
assertTrue (exp.almostEquals (tbl));
}
public void testAddRowMixed ()
{
Assignment assn = new Assignment ();
assn.addRow (vars, new int[] { 1, 1 });
assn.addRow (vars, new int[] { 1, 0 });
Assignment assn2 = new Assignment ();
assn2.addRow (new Variable[] { vars[1], vars[0] }, new int[] { 0, 1 });
assn.addRow (assn2);
AbstractTableFactor tbl = assn.asTable ();
TableFactor exp = new TableFactor (vars, new double[] { 0, 0, 2, 1 });
assertTrue (exp.almostEquals (tbl));
}
public static Test suite()
{
return new TestSuite (TestAssignment.class);
}
public static void main(String[] args) throws Exception
{
TestSuite theSuite;
if (args.length > 0) {
theSuite = new TestSuite();
for (int i = 0; i < args.length; i++) {
theSuite.addTest(new TestAssignment (args[i]));
}
} else {
theSuite = (TestSuite) TestAssignment.suite ();
}
junit.textui.TestRunner.run(theSuite);
}
}