/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://mallet.cs.umass.edu/ This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference; import gnu.trove.TIntObjectIterator; import java.io.PrintWriter; import java.io.OutputStreamWriter; import cc.mallet.grmm.types.Factor; import cc.mallet.grmm.types.FactorGraph; import cc.mallet.grmm.types.LogTableFactor; import cc.mallet.grmm.types.Variable; import cc.mallet.grmm.util.MIntInt2ObjectMap; /** * Efficiently manages a array of messages in a factor graph from * variables to factors and vice versa. * * Created: Feb 1, 2006 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: MessageArray.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $ */ public class MessageArray { private FactorGraph fg; private MIntInt2ObjectMap messages; // messages from factor --> variable private int numV; private int numF; private boolean inLogSpace; public MessageArray (FactorGraph fg) { this.fg = fg; numV = fg.numVariables (); numF = fg.factors ().size(); messages = new MIntInt2ObjectMap (numV + numV); inLogSpace = (fg.getFactor (0) instanceof LogTableFactor); } public boolean isInLogSpace () { return inLogSpace; } public Factor get (Object from, Object to) { if (from instanceof Factor && to instanceof Variable) { return get ((Factor) from, (Variable) to); } else if (from instanceof Variable && to instanceof Factor) { return get ((Variable) from, (Factor) to); } else { throw new IllegalArgumentException (); } } public Factor get (Variable from, Factor to) { int fromIdx = getIndex (from); int toIdx = getIndex (to); return get (toIdx, fromIdx); } public Factor get (Factor from, Variable to) { int fromIdx = getIndex (from); int toIdx = getIndex (to); return get (toIdx, fromIdx); } Factor get (int toIdx, int fromIdx) { return (Factor) messages.get (toIdx, fromIdx); } public void put (Factor from, Variable to, Factor msg) { int fromIdx = getIndex (from); int toIdx = getIndex (to); messages.put (toIdx, fromIdx, msg); } public void put (Variable from, Factor to, Factor msg) { int fromIdx = getIndex (from); int toIdx = getIndex (to); messages.put (toIdx, fromIdx, msg); } // more dangerous, but for efficiency public void put (int fromIdx, int toIdx, Factor msg) { messages.put (toIdx, fromIdx, msg); } public Iterator iterator () { return new Iterator (); } public ToMsgsIterator toMessagesIterator (int toIdx) { return new ToMsgsIterator (messages, toIdx); } public MessageArray duplicate () { MessageArray dup = new MessageArray (fg); dup.messages = deepCopy (messages); return dup; } public MIntInt2ObjectMap deepCopy (MIntInt2ObjectMap msgs) { MIntInt2ObjectMap copy = new MIntInt2ObjectMap (numV + numF); int[] keys1 = msgs.keys1 (); for (int i = 0; i < keys1.length; i++) { int k1 = keys1[i]; ToMsgsIterator msgIt = new ToMsgsIterator (msgs, k1); while (msgIt.hasNext ()) { Factor msg = msgIt.next (); int from = msgIt.currentFromIdx (); copy.put (k1, from, msg.duplicate ()); } } return copy; } public int getIndex (Factor from) { return -(fg.getIndex (from) + 1); } public int getIndex (Variable to) { return fg.getIndex (to); } public Object idx2obj (int idx) { if (idx >= 0) { return fg.get (idx); } else { return fg.getFactor (-idx - 1); } } public void dump () { dump (new PrintWriter (new OutputStreamWriter (System.out), true)); } public void dump (PrintWriter out) { for (MessageArray.Iterator it = iterator (); it.hasNext ();) { Factor msg = (Factor) it.next (); Object from = it.from (); Object to = it.to (); out.println ("Message from " + from + " to " + to); out.println (msg.dumpToString ()); } } public final class Iterator implements java.util.Iterator { int idx1 = 0; int idx2 = -1; int[] keys1; int[] keys2; public Iterator () { keys1 = messages.keys1 (); if (keys1.length > 0) { keys2 = messages.keys2 (keys1[idx1]); } else { keys2 = new int [0]; } } private void increment () { idx2++; if (idx2 >= keys2.length) { idx2 = 0; idx1++; keys2 = messages.keys2 (keys1[idx1]); } } public boolean hasNext () { return (idx1+1 < keys1.length) || (idx2+1 < keys2.length); } public Object next () { increment (); return messages.get (keys1[idx1], keys2[idx2]); } public void remove () { throw new UnsupportedOperationException (); } public Object from () { return idx2obj (keys2[idx2]); } public Object to () { return idx2obj (keys1[idx1]); } } final public static class ToMsgsIterator { private TIntObjectIterator subIt; private int toIdx = -1; private ToMsgsIterator (MIntInt2ObjectMap msgs, int toIdx) { this.toIdx = toIdx; subIt = msgs.curry (toIdx); } public boolean hasNext () { return subIt.hasNext (); } public Factor next () { subIt.advance (); return currentMessage (); } int currentFromIdx () { return subIt.key (); } public Factor currentMessage () { return (Factor) subIt.value (); } public int currentToIdx () { return toIdx; } } }