/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import no.uib.cipr.matrix.*;
/**
 * Multivariate Gaussian factor over a set of continuous variables.
 * <p>
 * Most of this class is still a stub: {@link #value(Assignment)} always
 * returns 1.0, marginalization and max-extraction are unsupported, and the
 * factor can only be multiplied or divided by a {@link ConstantFactor} whose
 * value is (almost) 1. The fully implemented operation is
 * {@link #sample(Randoms)}.
 * $Id: NormalFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
 */
public class NormalFactor extends AbstractFactor {

  /** Mean vector; its length equals the dimension of {@link #variance}. */
  private Vector mean;

  /** Covariance matrix; checked at construction to be square, to match the mean, and to be positive definite. */
  private Matrix variance;

  /**
   * Lower-triangular Cholesky factor L of {@link #variance} (variance = L L^T),
   * computed lazily on the first call to {@link #sample(Randoms)}. Neither
   * field it depends on is reassigned after construction, so the cache never
   * goes stale. Transient because it is cheap to rebuild.
   */
  private transient double[][] choleskyFactor;

  /**
   * Creates a Gaussian factor with the given mean and covariance.
   *
   * @param vars     variables this factor is defined over
   * @param mean     mean vector
   * @param variance covariance matrix; must be square, have the same dimension
   *                 as {@code mean}, and be positive definite
   * @throws IllegalArgumentException if the dimensions disagree or the matrix
   *         is not positive definite
   */
  public NormalFactor (VarSet vars, Vector mean, Matrix variance)
  {
    super (vars);
    int dim = mean.size ();
    if (variance.numRows () != dim || variance.numColumns () != dim) {
      throw new IllegalArgumentException ("Covariance matrix is " + variance.numRows () + "x"
          + variance.numColumns () + " but mean has length " + dim);
    }
    if (!isPosDef (variance)) throw new IllegalArgumentException ("Matrix "+variance+" not positive definite.");
    this.mean = mean;
    this.variance = variance;
  }

  /**
   * Returns true iff every real eigenvalue of {@code variance} is strictly positive.
   * <p>
   * The eigenvalues returned by {@link EVD} are not guaranteed to be sorted, so
   * every one of them is checked. {@code !(v > 0)} also rejects NaN.
   * NOTE(review): imaginary parts are not inspected; this assumes callers pass
   * a symmetric covariance matrix, whose eigenvalues are all real.
   *
   * @throws RuntimeException wrapping {@link NotConvergedException} if the
   *         eigenvalue decomposition fails to converge
   */
  private static boolean isPosDef (Matrix variance)
  {
    try {
      double[] eigenvalues = EVD.factorize (variance).getRealEigenvalues ();
      for (double v : eigenvalues) {
        if (!(v > 0)) return false;
      }
      return true;
    } catch (NotConvergedException e) {
      throw new RuntimeException (e);
    }
  }

  /**
   * Computes the lower-triangular Cholesky factor L of {@code a}, so that a = L L^T.
   * Only the lower triangle (including the diagonal) of {@code a} is read.
   *
   * @throws IllegalStateException if a non-positive pivot is met, i.e. the
   *         matrix is not symmetric positive definite
   */
  private static double[][] choleskyLower (Matrix a)
  {
    int n = a.numRows ();
    double[][] l = new double[n][n];
    for (int i = 0; i < n; i++) {
      for (int j = 0; j <= i; j++) {
        double sum = a.get (i, j);
        for (int k = 0; k < j; k++) {
          sum -= l[i][k] * l[j][k];
        }
        if (i == j) {
          if (!(sum > 0)) {
            throw new IllegalStateException ("Covariance matrix is not symmetric positive definite");
          }
          l[i][i] = Math.sqrt (sum);
        } else {
          l[i][j] = sum / l[j][j];
        }
      }
    }
    return l;
  }

  /** Returns the cached Cholesky factor of the covariance, computing it on first use. */
  private double[][] getCholeskyFactor ()
  {
    if (choleskyFactor == null) {
      choleskyFactor = choleskyLower (variance);
    }
    return choleskyFactor;
  }

  /** Not supported for Gaussian factors yet. */
  protected Factor extractMaxInternal (VarSet varSet)
  {
    throw new UnsupportedOperationException ();
  }

  /**
   * Stub: always returns 1.0 regardless of the assignment. The Gaussian
   * density is not evaluated.
   */
  public double value (Assignment assn)
  {
    // stub
    return 1.0;
  }

  /** Not supported: a continuous factor has no discrete value table. */
  protected double lookupValueInternal (int i)
  {
    throw new UnsupportedOperationException ();
  }

  /** Not supported for Gaussian factors yet. */
  protected Factor marginalizeInternal (VarSet varsToKeep)
  {
    throw new UnsupportedOperationException ();
  }

  /** The factor is treated as already normalized, so this returns {@code this} unchanged. */
  public Factor normalize ()
  {
    return this;
  }

  /**
   * Draws one sample from N(mean, variance).
   * <p>
   * Draws a standard-normal vector z and returns mean + L z, where L is the
   * lower Cholesky factor of the covariance. Multiplying by L (a matrix square
   * root of the covariance) rather than by the covariance itself gives the
   * sample covariance L L^T = variance.
   *
   * @param r source of standard-normal deviates
   * @return an assignment of the sampled values to this factor's variables
   */
  public Assignment sample (Randoms r)
  {
    double[][] chol = getCholeskyFactor ();
    int n = chol.length;

    // generate from standard normal
    double[] z = new double [n];
    for (int k = 0; k < n; k++) {
      z[k] = r.nextGaussian ();
    }

    // transform: x = mean + L z, where L is lower triangular
    double[] x = new double [n];
    for (int i = 0; i < n; i++) {
      double acc = mean.get (i);
      for (int j = 0; j <= i; j++) {
        acc += chol[i][j] * z[j];
      }
      x[i] = acc;
    }
    return new Assignment (vars.toVariableArray (), x);
  }

  /**
   * Returns true if {@code p} is this factor, or if {@code p} is a NormalFactor
   * over equal variables whose mean and covariance entries each differ from
   * this factor's by at most {@code epsilon}.
   */
  public boolean almostEquals (Factor p, double epsilon)
  {
    if (equals (p)) return true;
    if (!(p instanceof NormalFactor)) return false;
    NormalFactor other = (NormalFactor) p;
    if (!vars.equals (other.vars)) return false;

    int n = mean.size ();
    if (other.mean.size () != n
        || other.variance.numRows () != variance.numRows ()
        || other.variance.numColumns () != variance.numColumns ()) {
      return false;
    }
    for (int i = 0; i < n; i++) {
      if (!withinEpsilon (mean.get (i), other.mean.get (i), epsilon)) return false;
    }
    for (int i = 0; i < variance.numRows (); i++) {
      for (int j = 0; j < variance.numColumns (); j++) {
        if (!withinEpsilon (variance.get (i, j), other.variance.get (i, j), epsilon)) return false;
      }
    }
    return true;
  }

  /** True iff |a - b| <= epsilon; false if either value is NaN. */
  private static boolean withinEpsilon (double a, double b, double epsilon)
  {
    return Math.abs (a - b) <= epsilon;
  }

  /**
   * Returns an independent copy of this factor. The mean and covariance are
   * copied, so later mutation of either factor's parameters does not affect
   * the other.
   */
  public Factor duplicate ()
  {
    return new NormalFactor (vars, mean.copy (), variance.copy ());
  }

  /** Returns true if any entry of the mean or the covariance is NaN. */
  public boolean isNaN ()
  {
    for (int i = 0; i < mean.size (); i++) {
      if (Double.isNaN (mean.get (i))) return true;
    }
    for (int i = 0; i < variance.numRows (); i++) {
      for (int j = 0; j < variance.numColumns (); j++) {
        if (Double.isNaN (variance.get (i, j))) return true;
      }
    }
    return false;
  }

  /** Same as {@link #toString()}. */
  public String dumpToString ()
  {
    return toString ();
  }

  public String toString ()
  {
    return "[NormalFactor "+vars+" "+mean+" ... " +variance+" ]";
  }

  /**
   * Slices this factor by an assignment. Only the case where the assignment
   * covers every variable is supported. Because {@link #value(Assignment)} is a
   * stub, the resulting constant is currently 1.0.
   *
   * @throws UnsupportedOperationException for partial assignments
   */
  // todo
  public Factor slice (Assignment assn)
  {
    if (assn.varSet ().containsAll (vars)) {
      // special case
      return new ConstantFactor (value (assn));
    } else {
      throw new UnsupportedOperationException ();
    }
  }

  /**
   * Returns true iff {@code f} is a ConstantFactor whose value is almost 1.0,
   * i.e. multiplying or dividing by it is the identity. The value is only
   * evaluated once {@code f} is known to be a ConstantFactor.
   */
  private static boolean isUnitConstant (Factor f)
  {
    return (f instanceof ConstantFactor) && Maths.almostEquals (f.value (new Assignment ()), 1.0);
  }

  /**
   * Multiplies in place. Only a unit constant factor is accepted, and it leaves
   * this factor unchanged, because a NormalFactor must stay normalized for now.
   *
   * @throws UnsupportedOperationException for any other factor
   */
  public void multiplyBy (Factor f)
  {
    if (isUnitConstant (f)) {
      return; // ok, it's an identity factor
    }
    throw new UnsupportedOperationException ("Can't multiply NormalFactor by "+f);
  }

  /**
   * Divides in place. Only a unit constant factor is accepted, and it leaves
   * this factor unchanged, because a NormalFactor must stay normalized for now.
   *
   * @throws UnsupportedOperationException for any other factor
   */
  public void divideBy (Factor f)
  {
    if (isUnitConstant (f)) {
      return; // ok, it's an identity factor
    }
    throw new UnsupportedOperationException ("Can't divide NormalFactor by "+f);
  }
}