/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import java.io.ObjectInputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
/**
* $Id: BetaFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class BetaFactor extends AbstractFactor {
transient private Variable var;
transient private double min;
transient private double max;
transient private double alpha;
transient private double beta;
transient private double beta12;
public BetaFactor (Variable var, double alpha, double beta)
{
this (var, alpha, beta, 0, 1);
}
public BetaFactor (Variable var, double alpha, double beta, double min, double max)
{
super (new HashVarSet (new Variable[] { var }));
if (!var.isContinuous ()) throw new IllegalArgumentException ();
if (min >= max) throw new IllegalArgumentException ();
this.var = var;
this.min = min;
this.max = max;
this.alpha = alpha;
this.beta = beta;
setBeta12 ();
}
private void setBeta12 ()
{
beta12 = 1 / Maths.beta (alpha, beta);
}
protected Factor extractMaxInternal (VarSet varSet)
{
throw new UnsupportedOperationException ();
}
public double value (Assignment assn)
{
double pct = valueToPct (assn.getDouble (var));
if ((0 < pct) && (pct < 1)) {
return beta12 * Math.pow (pct, (alpha - 1.0)) * Math.pow ((1-pct), (beta -1.0));
} else {
return 0;
}
}
private double valueToPct (double val)
{
return (val - min) / (max - min);
}
private double pctToValue (double pct)
{
return (pct * (max - min)) + min;
}
protected double lookupValueInternal (int i)
{
throw new UnsupportedOperationException ();
}
protected Factor marginalizeInternal (VarSet varsToKeep)
{
if (varsToKeep.contains (var)) {
return duplicate ();
} else {
return new ConstantFactor (1.0);
}
}
public Factor normalize ()
{
return this;
}
public Assignment sample (Randoms r)
{
double pct = r.nextBeta (alpha, beta);
double val = pctToValue (pct);
return new Assignment (var, val);
}
public boolean almostEquals (Factor p, double epsilon)
{
return equals (p);
}
public Factor duplicate ()
{
return new BetaFactor (var, alpha, beta, min, max);
}
public boolean isNaN ()
{
return Double.isNaN(alpha) || Double.isNaN(beta) || Double.isNaN (min) || Double.isNaN (max)
|| alpha <= 0 || beta <= 0;
}
public String dumpToString ()
{
return toString ();
}
public void multiplyBy (Factor f)
{
if (f instanceof ConstantFactor) {
double val = f.value (new Assignment());
// NormalFactor must be normalized right now...
if (Maths.almostEquals (val, 1.0)) {
return; // ok, it's an identity factor
}
}
throw new UnsupportedOperationException ("Can't multiply BetaFactor by "+f);
}
public void divideBy (Factor f)
{
if (f instanceof ConstantFactor) {
double val = f.value (new Assignment());
// NormalFactor must be normalized right now...
if (Maths.almostEquals (val, 1.0)) {
return; // ok, it's an identity factor
}
}
throw new UnsupportedOperationException ("Can't divide BetaFactor by "+f);
}
public String toString ()
{
return "[BetaFactor("+alpha +", "+beta +") "+var+" scale=("+min+" ... " +max+") ]";
}
public Factor slice (Assignment assn)
{
if (assn.containsVar (var)) {
return new ConstantFactor (value (assn));
} else return duplicate ();
}
// serialization nonsense
private static final long serialVersionUID = 1L;
private static final int SERIAL_VERSION = 1;
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
in.readInt (); // serial version
var = (Variable) in.readObject ();
alpha = in.readDouble ();
beta = in.readDouble ();
min = in.readDouble ();
max = in.readDouble ();
}
private void writeObject (ObjectOutputStream out) throws IOException, ClassNotFoundException
{
out.defaultWriteObject ();
out.writeInt (SERIAL_VERSION);
out.writeObject (var);
out.writeDouble (alpha);
out.writeDouble (beta);
out.writeDouble (min);
out.writeDouble (max);
setBeta12 ();
}
}