/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ /** @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */ package cc.mallet.fst; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.logging.Logger; import cc.mallet.types.Alphabet; import cc.mallet.types.Multinomial; import cc.mallet.types.Sequence; import cc.mallet.util.MalletLogger; public class FeatureTransducer extends Transducer { private static Logger logger = MalletLogger.getLogger(FeatureTransducer.class.getName()); // These next two dictionaries may be the same Alphabet inputAlphabet; Alphabet outputAlphabet; ArrayList<State> states = new ArrayList<State> (); ArrayList<State> initialStates = new ArrayList<State> (); HashMap<String,State> name2state = new HashMap<String,State> (); Multinomial.Estimator initialStateCounts; Multinomial.Estimator finalStateCounts; boolean trainable = false; public FeatureTransducer (Alphabet inputAlphabet, Alphabet outputAlphabet) { this.inputAlphabet = inputAlphabet; this.outputAlphabet = outputAlphabet; // xxx When should these be frozen? } public FeatureTransducer (Alphabet dictionary) { this (dictionary, dictionary); } public FeatureTransducer () { this (new Alphabet ()); } public Alphabet getInputAlphabet () { return inputAlphabet; } public Alphabet getOutputAlphabet () { return outputAlphabet; } public void addState (String name, double initialWeight, double finalWeight, int[] inputs, int[] outputs, double[] weights, String[] destinationNames) { if (name2state.get(name) != null) throw new IllegalArgumentException ("State with name `"+name+"' already exists."); State s = new State (name, states.size(), initialWeight, finalWeight, inputs, outputs, weights, destinationNames, this); states.add (s); if (initialWeight < IMPOSSIBLE_WEIGHT) initialStates.add (s); name2state.put (name, s); setTrainable (false); } public void addState (String name, double initialWeight, double finalWeight, Object[] inputs, Object[] outputs, double[] weights, String[] destinationNames) { this.addState (name, initialWeight, finalWeight, inputAlphabet.lookupIndices (inputs, true), outputAlphabet.lookupIndices (outputs, true), weights, destinationNames); } public int numStates () { return states.size(); } public Transducer.State getState (int index) { return states.get(index); } public Iterator<State> initialStateIterator () { return initialStates.iterator (); } public boolean isTrainable () { return trainable; } public void setTrainable (boolean f) { trainable = f; if (f) { // This wipes away any previous counts we had. // It also potentially allocates an esimator of a new size if // the number of states has increased. initialStateCounts = new Multinomial.LaplaceEstimator (states.size()); finalStateCounts = new Multinomial.LaplaceEstimator (states.size()); } else { initialStateCounts = null; finalStateCounts = null; } for (int i = 0; i < numStates(); i++) ((State)getState(i)).setTrainable(f); } public void reset () { if (trainable) { initialStateCounts.reset (); finalStateCounts.reset (); for (int i = 0; i < numStates(); i++) ((State)getState(i)).reset (); } } public void estimate () { if (initialStateCounts == null || finalStateCounts == null) throw new IllegalStateException ("This transducer not currently trainable."); Multinomial initialStateDistribution = initialStateCounts.estimate (); Multinomial finalStateDistribution = finalStateCounts.estimate (); for (int i = 0; i < states.size(); i++) { State s = states.get (i); s.initialWeight = initialStateDistribution.logProbability (i); s.finalWeight = finalStateDistribution.logProbability (i); s.estimate (); } } // Note that this is a non-static inner class, so we have access to all of // FeatureTransducer's instance variables. public class State extends Transducer.State { String name; int index; double initialWeight, finalWeight; Transition[] transitions; gnu.trove.TIntObjectHashMap input2transitions; Multinomial.Estimator transitionCounts; FeatureTransducer transducer; // Note that you cannot add transitions to a state once it is created. protected State (String name, int index, double initialWeight, double finalWeight, int[] inputs, int[] outputs, double[] weights, String[] destinationNames, FeatureTransducer transducer) { assert (inputs.length == outputs.length && inputs.length == weights.length && inputs.length == destinationNames.length); this.transducer = transducer; this.name = name; this.index = index; this.initialWeight = initialWeight; this.finalWeight = finalWeight; this.transitions = new Transition[inputs.length]; this.input2transitions = new gnu.trove.TIntObjectHashMap (); transitionCounts = null; for (int i = 0; i < inputs.length; i++) { // This constructor places the transtion into this.input2transitions transitions[i] = new Transition (inputs[i], outputs[i], weights[i], this, destinationNames[i]); transitions[i].index = i; } } public Transducer getTransducer () { return transducer; } public double getInitialWeight () { return initialWeight; } public double getFinalWeight () { return finalWeight; } public void setInitialWeight (double v) { initialWeight = v; } public void setFinalWeight (double v) { finalWeight = v; } private void setTrainable (boolean f) { if (f) transitionCounts = new Multinomial.LaplaceEstimator (transitions.length); else transitionCounts = null; } // Temporarily here for debugging public Multinomial.Estimator getTransitionEstimator() { return transitionCounts; } private void reset () { if (transitionCounts != null) transitionCounts.reset(); } public int getIndex () { return index; } public Transducer.TransitionIterator transitionIterator (Sequence input, int inputPosition, Sequence output, int outputPosition) { if (inputPosition < 0 || outputPosition < 0 || output != null) throw new UnsupportedOperationException ("Not yet implemented."); if (input == null) return transitionIterator (); return transitionIterator (input, inputPosition); } public Transducer.TransitionIterator transitionIterator (Sequence inputSequence, int inputPosition) { int inputIndex = inputAlphabet.lookupIndex (inputSequence.get(inputPosition), false); if (inputIndex == -1) throw new IllegalArgumentException ("Input not in dictionary."); return transitionIterator (inputIndex); } public Transducer.TransitionIterator transitionIterator (Object o) { int inputIndex = inputAlphabet.lookupIndex (o, false); if (inputIndex == -1) throw new IllegalArgumentException ("Input not in dictionary."); return transitionIterator (inputIndex); } public Transducer.TransitionIterator transitionIterator (int input) { return new TransitionIterator (this, input); } public Transducer.TransitionIterator transitionIterator () { return new TransitionIterator (this); } public String getName () { return name; } public void incrementInitialCount (double count) { if (initialStateCounts == null) throw new IllegalStateException ("Transducer is not currently trainable."); initialStateCounts.increment (index, count); } public void incrementFinalCount (double count) { if (finalStateCounts == null) throw new IllegalStateException ("Transducer is not currently trainable."); finalStateCounts.increment (index, count); } private void estimate () { if (transitionCounts == null) throw new IllegalStateException ("Transducer is not currently trainable."); Multinomial transitionDistribution = transitionCounts.estimate (); for (int i = 0; i < transitions.length; i++) transitions[i].weight = transitionDistribution.logProbability (i); } private static final long serialVersionUID = 1; } @SuppressWarnings("serial") protected class TransitionIterator extends Transducer.TransitionIterator { // If "index" is >= -1 we are going through all FeatureState.transitions[] by index. // If "index" is -2, we are following the chain of FeatureTransition.nextWithSameInput, // and "transition" is already initialized to the first transition. // If "index" is -3, we are following the chain of FeatureTransition.nextWithSameInput, // and the next transition should be found by following the chain. int index; Transition transition; State source; int input; // Iterate through all transitions, independent of input public TransitionIterator (State source) { //System.out.println ("FeatureTransitionIterator over all"); this.source = source; this.input = -1; this.index = -1; this.transition = null; } public TransitionIterator (State source, int input) { //System.out.println ("SymbolTransitionIterator over "+input); this.source = source; this.input = input; this.index = -2; this.transition = (Transition) source.input2transitions.get (input); } public boolean hasNext () { if (index >= -1) { //System.out.println ("hasNext index " + index); return (index < source.transitions.length-1); } return (index == -2 ? transition != null : transition.nextWithSameInput != null); }; public Transducer.State nextState () { if (index >= -1) transition = source.transitions[++index]; else if (index == -2) index = -3; else transition = transition.nextWithSameInput; return transition.getDestinationState(); } public int getIndex () { return index; } public Object getInput () { return inputAlphabet.lookupObject(transition.input); } public Object getOutput () { return outputAlphabet.lookupObject(transition.output); } public double getWeight () { return transition.weight; } public Transducer.State getSourceState () { return source; } public Transducer.State getDestinationState () { return transition.getDestinationState (); } public void incrementCount (double count) { logger.info ("FeatureTransducer incrementCount "+count); source.transitionCounts.increment (transition.index, count); } } // Note: this class has a natural ordering that is inconsistent with equals. protected class Transition { int input, output; double weight; int index; String destinationName; State destination = null; Transition nextWithSameInput; public Transition (int input, int output, double weight, State sourceState, String destinationName) { this.input = input; this.output = output; this.weight = weight; this.nextWithSameInput = (Transition) sourceState.input2transitions.get (input); sourceState.input2transitions.put (input, this); // this.index is set by the caller of this constructor this.destinationName = destinationName; } public State getDestinationState () { if (destination == null) { destination = name2state.get (destinationName); assert (destination != null); } return destination; } } private static final long serialVersionUID = 1; }