/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;
import cc.mallet.grmm.types.*;
import junit.framework.*;
/**
 * JUnit tests for {@link PottsTableFactor}.
 * <p>
 * Each test works on a factor built over two binary variables plus one
 * continuous variable, and checks slicing and <code>sumGradLog</code>
 * with that continuous variable fixed at 1.0.
 *
 * $Id: TestPottsFactor.java,v 1.1 2007/10/22 21:37:41 mccallum Exp $
 */
public class TestPottsFactor extends TestCase {

  /** Value the continuous variable is fixed at in every test. */
  private static final double ALPHA_VALUE = 1.0;

  /** Gradient that both <code>sumGradLog</code> tests expect. */
  private static final double EXPECTED_GRAD = -0.4;

  /** Absolute tolerance for the gradient comparisons. */
  private static final double GRAD_TOLERANCE = 1e-5;

  /** Factor under test; rebuilt by {@link #setUp()} before each test. */
  private PottsTableFactor potts;

  /** The continuous variable handed to the factor's constructor. */
  private Variable alphaVar;

  /** The two binary variables the factor is defined over. */
  private VarSet pairVars;

  /**
   * Creates a test case that will run the test method with the given name.
   *
   * @param name name of the test method to run
   */
  public TestPottsFactor (String name)
  {
    super (name);
  }

  /**
   * Builds a suite from every test method declared in this class.
   *
   * @return a <code>TestSuite</code>
   */
  public static TestSuite suite ()
  {
    return new TestSuite (TestPottsFactor.class);
  }

  /**
   * Creates the fixture: the continuous variable first, then the two binary
   * variables, then the factor over them.
   */
  protected void setUp () throws Exception
  {
    alphaVar = new Variable (Variable.CONTINUOUS);
    Variable first = new Variable (2);
    Variable second = new Variable (2);
    Variable[] pair = new Variable[] { first, second };
    pairVars = new HashVarSet (pair);
    potts = new PottsTableFactor (pairVars, alphaVar);
  }

  /**
   * Slicing with the continuous variable fixed at 1.0 must give an
   * {@link AbstractTableFactor} over the same two variables whose entries
   * are approximately { 1, e^-1, e^-1, 1 }.
   */
  public void testSlice ()
  {
    Factor atAlpha = potts.slice (alphaFixed ());

    assertTrue (atAlpha instanceof AbstractTableFactor);
    assertTrue (atAlpha.varSet ().equals (pairVars));

    double offDiagonal = Math.exp (-1);
    double[] wanted = new double[] { 1.0, offDiagonal, offDiagonal, 1.0 };
    TableFactor expected = new TableFactor (pairVars, wanted);
    assertTrue (atAlpha.almostEquals (expected));
  }

  /**
   * <code>sumGradLog</code> against a table over the two binary variables
   * with entries { 0.4, 0.1, 0.3, 0.2 } is expected to be -0.4.
   */
  public void testSumGradLog ()
  {
    Assignment fixedAlpha = alphaFixed ();
    Factor q = pairTable ();

    double grad = potts.sumGradLog (q, alphaVar, fixedAlpha);
    assertEquals (EXPECTED_GRAD, grad, GRAD_TOLERANCE);
  }

  /**
   * Same as {@link #testSumGradLog()}, except that the table is first
   * multiplied by a factor over a fresh binary variable with entries
   * { 0.7, 0.3 }; the expected gradient is still -0.4.
   */
  public void testSumGradLog2 ()
  {
    Assignment fixedAlpha = alphaFixed ();
    Factor overPair = pairTable ();
    Variable extra = new Variable (2);
    Factor overExtra = new TableFactor (extra, new double[] { 0.7, 0.3 });
    Factor q = overPair.multiply (overExtra);

    double grad = potts.sumGradLog (q, alphaVar, fixedAlpha);
    assertEquals (EXPECTED_GRAD, grad, GRAD_TOLERANCE);
  }

  /** Returns a new assignment that sets the continuous variable to 1.0. */
  private Assignment alphaFixed ()
  {
    return new Assignment (alphaVar, ALPHA_VALUE);
  }

  /** Returns a new table over the two binary variables used as <code>q</code>. */
  private Factor pairTable ()
  {
    double[] weights = new double[] { 0.4, 0.1, 0.3, 0.2 };
    return new TableFactor (pairVars, weights);
  }

  /**
   * Runs tests with the JUnit text runner. With no arguments the full
   * {@link #suite()} is run; otherwise each argument is taken as the name
   * of one test method to run.
   *
   * @param args optional test method names
   */
  public static void main (String[] args)
  {
    // With no arguments the loop below adds nothing and the full suite runs as-is.
    TestSuite selected = (args.length > 0) ? new TestSuite () : suite ();
    for (int argIdx = 0; argIdx < args.length; argIdx++) {
      selected.addTest (new TestPottsFactor (args[argIdx]));
    }
    junit.textui.TestRunner.run (selected);
  }
}