/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import cc.mallet.grmm.types.*;
/**
* Approximate inferencer for graphical models using sampling.
* A general inference engine that takes any Sampler engine, and performs
* approximate inference using its samples.
* Created: Mar 28, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: SamplingInferencer.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class SamplingInferencer extends AbstractInferencer {
private int N;
private Sampler sampler;
// Could save only sufficient statistics to save on memory
transient Assignment samples;
public SamplingInferencer (Sampler sampler, int n)
{
this.sampler = sampler;
N = n;
}
public void computeMarginals (FactorGraph mdl)
{
samples = sampler.sample (mdl, N);
}
public Factor lookupMarginal (Variable var)
{
return lookupMarginal (new HashVarSet (new Variable[] { var }));
}
// don't try this for large cliques
public Factor lookupMarginal (VarSet varSet)
{
Factor mrgl = samples.marginalize (varSet);
AbstractTableFactor tbl = mrgl.asTable ();
tbl.normalize ();
return tbl;
}
// Serialization garbage
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
{
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeInt (N);
out.writeObject (sampler);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.readInt (); // read version
N = in.readInt ();
sampler = (Sampler) in.readObject ();
}
public String toString ()
{
return "(SamplingInferencer: "+sampler+" )";
}
}