/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.pr;
import cc.mallet.fst.CRF.State;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Sequence;
/**
 * A {@link Transducer.TransitionIterator} over the outgoing transitions of a
 * single CRF state that reads transition weights from a precomputed array of
 * dot products rather than computing them itself.
 *
 * <p>Transitions whose cached weight equals
 * {@link Transducer#IMPOSSIBLE_WEIGHT} are skipped during iteration.
 *
 * @author Gregory Druck
 */
public class CachedDotTransitionIterator extends Transducer.TransitionIterator {
  private static final long serialVersionUID = 1;

  /** State whose outgoing transitions are being iterated. */
  State source;

  /** Transition index most recently handed out by {@link #nextState()}. */
  int index;

  /** Transition index that the next call to {@link #nextState()} will hand out. */
  int nextIndex;

  /** Cached weight of each outgoing transition of {@link #source}. */
  protected double[] weights;

  /** Input object reported by {@link #getInput()}. */
  Object input;

  /**
   * Creates an iterator whose input is the element of {@code inputSeq} at
   * {@code inputPosition}.
   *
   * @param source state whose outgoing transitions are iterated
   * @param inputSeq input sequence
   * @param inputPosition position in {@code inputSeq} of the input element
   * @param output passed through to the delegate constructor unchanged
   * @param dots cached values, indexed by destination state index
   */
  public CachedDotTransitionIterator(State source, Sequence inputSeq,
      int inputPosition, String output, double[] dots) {
    this(source, inputSeq.get(inputPosition), output, dots);
  }

  /**
   * Creates an iterator for an already extracted input object.
   *
   * @param source state whose outgoing transitions are iterated
   * @param fv input object, returned later by {@link #getInput()}
   * @param output not read by this constructor. NOTE(review): confirm that
   *        transitions are intentionally not restricted by output.
   * @param dots cached values; the weight of transition {@code i} is taken
   *        from {@code dots} at the index of the i-th destination state
   */
  protected CachedDotTransitionIterator(State source, Object fv, String output,
      double[] dots) {
    this.source = source;
    this.input = fv;

    // Copy the per-destination-state values into per-transition order.
    final int transitionCount = source.numDestinations();
    final double[] perTransition = new double[transitionCount];
    for (int t = 0; t < transitionCount; t++) {
      perTransition[t] = dots[source.getDestinationState(t).getIndex()];
    }
    this.weights = perTransition;

    // Start at the first transition that is not impossible.
    this.nextIndex = skipImpossibleFrom(0);
  }

  /**
   * Scans forward from {@code start} past transitions whose weight equals
   * {@link Transducer#IMPOSSIBLE_WEIGHT}.
   *
   * @param start first transition index to examine
   * @return the first index at or after {@code start} whose weight is not
   *         the impossible weight, or the index at which scanning ran past
   *         the last transition; {@code start} itself if it is already out
   *         of range
   */
  private int skipImpossibleFrom(int start) {
    final int limit = source.numDestinations();
    int candidate = start;
    while (candidate < limit) {
      if (weights[candidate] != Transducer.IMPOSSIBLE_WEIGHT) {
        break;
      }
      candidate++;
    }
    return candidate;
  }

  /** Returns whether a transition remains to be handed out. */
  public boolean hasNext() {
    return nextIndex < source.numDestinations();
  }

  /**
   * Advances to the next transition that is not impossible and returns its
   * destination state. Asserts that such a transition remains.
   *
   * @return destination state of the transition advanced to
   */
  public Transducer.State nextState() {
    assert (nextIndex < source.numDestinations());
    index = nextIndex;
    nextIndex = skipImpossibleFrom(index + 1);
    return source.getDestinationState(index);
  }

  // The accessors below are declared final purely as an efficiency attempt;
  // some of them may need to become overridable later.

  /** Returns the index of the current transition. */
  public final int getIndex() {
    return index;
  }

  /** Returns the input object this iterator was constructed with. */
  public final Object getInput() {
    return input;
  }

  /** Returns the source state's label name for the current transition. */
  public final Object getOutput() {
    return source.getLabelName(index);
  }

  /** Returns the cached weight of the current transition. */
  public final double getWeight() {
    return weights[index];
  }

  /** Returns the state whose transitions are being iterated. */
  public final Transducer.State getSourceState() {
    return source;
  }

  /** Returns the destination state of the current transition. */
  public final Transducer.State getDestinationState() {
    return source.getDestinationState(index);
  }
}