/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import cc.mallet.util.Randoms;
/**
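* A unary factor over a single binary discrete variable <tt>var</tt>, parameterized by
* two continuous variables <tt>theta1</tt> and <tt>theta2</tt>. For fixed parameter values
* the factor is a two-entry table over <tt>var</tt>: the entry for outcome 0 is the value
* of <tt>theta1</tt>, and the entry for outcome 1 is the value of <tt>theta2</tt>.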
* $Id: BinaryUnaryFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class BinaryUnaryFactor extends AbstractFactor implements ParameterizedFactor {
private Variable theta1;
private Variable theta2;
private Variable var; // The binary variable
public BinaryUnaryFactor (Variable var, Variable theta1, Variable theta2)
{
super (BinaryUnaryFactor.combineVariables (theta1, theta2, var));
this.theta1 = theta1;
this.theta2 = theta2;
this.var = var;
if (var.getNumOutcomes () != 2) {
throw new IllegalArgumentException ("Discrete variable "+var+" in BoltzmannUnary must be binary.");
}
if (!theta1.isContinuous ()) {
throw new IllegalArgumentException ("Parameter "+theta1+" in BinaryUnary must be continuous.");
}
if (!theta2.isContinuous ()) {
throw new IllegalArgumentException ("Parameter "+theta2+" in BinaryUnary must be continuous.");
}
}
private static VarSet combineVariables (Variable theta1, Variable theta2, Variable var)
{
VarSet ret = new HashVarSet ();
ret.add (theta1);
ret.add (theta2);
ret.add (var);
return ret;
}
protected Factor extractMaxInternal (VarSet varSet)
{
throw new UnsupportedOperationException ();
}
protected double lookupValueInternal (int i)
{
throw new UnsupportedOperationException ();
}
protected Factor marginalizeInternal (VarSet varsToKeep)
{
throw new UnsupportedOperationException ();
}
/* Inefficient, but this will seldom be called. */
public double value (AssignmentIterator it)
{
Assignment assn = it.assignment();
Factor tbl = sliceForAlpha (assn);
return tbl.value (assn);
}
private Factor sliceForAlpha (Assignment assn)
{
double th1 = assn.getDouble (theta1);
double th2 = assn.getDouble (theta2);
double[] vals = new double[] { th1, th2 };
return new TableFactor (var, vals);
}
public Factor normalize ()
{
throw new UnsupportedOperationException ();
}
public Assignment sample (Randoms r)
{
throw new UnsupportedOperationException ();
}
public double logValue (AssignmentIterator it)
{
return Math.log (value (it));
}
public Factor slice (Assignment assn)
{
Factor alphSlice = sliceForAlpha (assn);
// recursively slice, in case assn includes some of the xs
return alphSlice.slice (assn);
}
public String dumpToString ()
{
StringBuffer buf = new StringBuffer ();
buf.append ("[BinaryUnary : var=");
buf.append (var);
buf.append (" theta1=");
buf.append (theta1);
buf.append (" theta2=");
buf.append (theta2);
buf.append (" ]");
return buf.toString ();
}
public double sumGradLog (Factor q, Variable param, Assignment paramAssn)
{
Factor q_xs = q.marginalize (var);
Assignment assn;
if (param == theta1) {
assn = new Assignment (var, 0);
} else if (param == theta2) {
assn = new Assignment (var, 1);
} else {
throw new IllegalArgumentException ("Attempt to take gradient of "+this+" wrt "+param+
"but factor does not depend on that variable.");
}
return q_xs.value (assn);
}
public Factor duplicate ()
{
return new BinaryUnaryFactor (var, theta1, theta2);
}
public boolean almostEquals (Factor p, double epsilon)
{
return equals (p);
}
public boolean isNaN ()
{
return false;
}
public boolean equals (Object o)
{
if (this == o) return true;
if (o == null || getClass () != o.getClass ()) return false;
final BinaryUnaryFactor that = (BinaryUnaryFactor) o;
if (theta1 != null ? !theta1.equals (that.theta1) : that.theta1 != null) return false;
if (theta2 != null ? !theta2.equals (that.theta2) : that.theta2 != null) return false;
if (var != null ? !var.equals (that.var) : that.var != null) return false;
return true;
}
public int hashCode ()
{
int result;
result = (theta1 != null ? theta1.hashCode () : 0);
result = 29 * result + (theta2 != null ? theta2.hashCode () : 0);
result = 29 * result + (var != null ? var.hashCode () : 0);
return result;
}
}