/**
*
* Copyright 1999-2012 Carnegie Mellon University.
* Portions Copyright 2002 Sun Microsystems, Inc.
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* All Rights Reserved. Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/
package edu.cmu.sphinx.fst;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.HashMap;
import java.util.HashSet;
import edu.cmu.sphinx.fst.semiring.Semiring;
/**
* A mutable finite state transducer implementation.
*
* Holds an ArrayList of {@link edu.cmu.sphinx.fst.State} objects allowing
* additions/deletions.
*
* @author John Salatas
*/
public class Fst {
// fst states
private ArrayList<State> states = null;
// initial state
protected State start;
// input symbols map
protected String[] isyms;
// output symbols map
protected String[] osyms;
// semiring
protected Semiring semiring;
/**
* Default Constructor
*/
public Fst() {
states = new ArrayList<State>();
}
/**
* Constructor specifying the initial capacity of the states ArrayList (this
* is an optimization used in various operations)
*
* @param numStates
* the initial capacity
*/
public Fst(int numStates) {
if (numStates > 0) {
states = new ArrayList<State>(numStates);
}
}
/**
* Constructor specifying the fst's semiring
*
* @param s
* the fst's semiring
*/
public Fst(Semiring s) {
this();
this.semiring = s;
}
/**
* Get the initial states
* @return the initial state
*/
public State getStart() {
return start;
}
/**
* Get the semiring
* @return
* used semiring
*/
public Semiring getSemiring() {
return semiring;
}
/**
* Set the Semiring
*
* @param semiring
* the semiring to set
*/
public void setSemiring(Semiring semiring) {
this.semiring = semiring;
}
/**
* Set the initial state
*
* @param start
* the initial state
*/
public void setStart(State start) {
this.start = start;
}
/**
* Get the number of states in the fst
* @return number of states
*/
public int getNumStates() {
return this.states.size();
}
public State getState(int index) {
return states.get(index);
}
/**
* Adds a state to the fst
*
* @param state
* the state to be added
*/
public void addState(State state) {
this.states.add(state);
state.id = states.size() - 1;
}
/**
* Get the input symbols' array
* @return array of input symbols
*/
public String[] getIsyms() {
return isyms;
}
/**
* Set the input symbols
*
* @param isyms
* the isyms to set
*/
public void setIsyms(String[] isyms) {
this.isyms = isyms;
}
/**
* Get the output symbols' array
* @return array fo output symbols
*/
public String[] getOsyms() {
return osyms;
}
/**
* Set the output symbols
*
* @param osyms
* the osyms to set
*/
public void setOsyms(String[] osyms) {
this.osyms = osyms;
}
/**
* Serializes a symbol map to an ObjectOutputStream
*
* @param out
* the ObjectOutputStream. It should be already be initialized by
* the caller.
* @param map
* the symbol map to serialize
* @throws IOException
*/
private void writeStringMap(ObjectOutputStream out, String[] map)
throws IOException {
out.writeInt(map.length);
for (int i = 0; i < map.length; i++) {
out.writeObject(map[i]);
}
}
/**
* Serializes the current Fst instance to an ObjectOutputStream
*
* @param out
* the ObjectOutputStream. It should be already be initialized by
* the caller.
* @throws IOException
*/
private void writeFst(ObjectOutputStream out) throws IOException {
writeStringMap(out, isyms);
writeStringMap(out, osyms);
out.writeInt(states.indexOf(start));
out.writeObject(semiring);
out.writeInt(states.size());
HashMap<State, Integer> stateMap = new HashMap<State, Integer>(
states.size(), 1.f);
for (int i = 0; i < states.size(); i++) {
State s = states.get(i);
out.writeInt(s.getNumArcs());
out.writeFloat(s.getFinalWeight());
out.writeInt(s.getId());
stateMap.put(s, i);
}
int numStates = states.size();
for (int i = 0; i < numStates; i++) {
State s = states.get(i);
int numArcs = s.getNumArcs();
for (int j = 0; j < numArcs; j++) {
Arc a = s.getArc(j);
out.writeInt(a.getIlabel());
out.writeInt(a.getOlabel());
out.writeFloat(a.getWeight());
out.writeInt(stateMap.get(a.getNextState()));
}
}
}
/**
* Saves binary model to disk
*
* @param filename
* the binary model filename
* @throws IOException if IO went wrong
*/
public void saveModel(String filename) throws IOException {
FileOutputStream fos = new FileOutputStream(filename);
BufferedOutputStream bos = new BufferedOutputStream(fos);
ObjectOutputStream oos = new ObjectOutputStream(bos);
writeFst(oos);
oos.flush();
oos.close();
bos.close();
fos.close();
}
/**
* Deserializes a symbol map from an ObjectInputStream
*
* @param in
* the ObjectInputStream. It should be already be initialized by
* the caller.
* @return the deserialized symbol map
* @throws IOException if IO went wrong
* @throws ClassNotFoundException if serialization went wrong
*/
protected static String[] readStringMap(ObjectInputStream in)
throws IOException, ClassNotFoundException {
int mapSize = in.readInt();
String[] map = new String[mapSize];
for (int i = 0; i < mapSize; i++) {
String sym = (String) in.readObject();
map[i] = sym;
}
return map;
}
/**
* Deserializes an Fst from an ObjectInputStream
*
* @param in
* the ObjectInputStream. It should be already be initialized by
* the caller.
* @return Created FST
* @throws IOException
* @throws ClassNotFoundException
*/
private static Fst readFst(ObjectInputStream in) throws IOException,
ClassNotFoundException {
String[] is = readStringMap(in);
String[] os = readStringMap(in);
int startid = in.readInt();
Semiring semiring = (Semiring) in.readObject();
int numStates = in.readInt();
Fst res = new Fst(numStates);
res.isyms = is;
res.osyms = os;
res.semiring = semiring;
for (int i = 0; i < numStates; i++) {
int numArcs = in.readInt();
State s = new State(numArcs + 1);
float f = in.readFloat();
if (f == res.semiring.zero()) {
f = res.semiring.zero();
} else if (f == res.semiring.one()) {
f = res.semiring.one();
}
s.setFinalWeight(f);
s.id = in.readInt();
res.states.add(s);
}
res.setStart(res.states.get(startid));
numStates = res.getNumStates();
for (int i = 0; i < numStates; i++) {
State s1 = res.getState(i);
for (int j = 0; j < s1.initialNumArcs - 1; j++) {
Arc a = new Arc();
a.setIlabel(in.readInt());
a.setOlabel(in.readInt());
a.setWeight(in.readFloat());
a.setNextState(res.states.get(in.readInt()));
s1.addArc(a);
}
}
return res;
}
/**
* Deserializes an Fst from disk
*
* @param filename
* the binary model filename
* @return deserialized FST
* @throws IOException io IO went wrong
* @throws ClassNotFoundException if serialization went wrong
*/
public static Fst loadModel(String filename) throws IOException,
ClassNotFoundException {
long starttime = Calendar.getInstance().getTimeInMillis();
Fst obj;
FileInputStream fis = null;
BufferedInputStream bis = null;
ObjectInputStream ois = null;
fis = new FileInputStream(filename);
bis = new BufferedInputStream(fis);
ois = new ObjectInputStream(bis);
obj = readFst(ois);
ois.close();
bis.close();
fis.close();
System.err.println("Load Time: "
+ (Calendar.getInstance().getTimeInMillis() - starttime)
/ 1000.);
return obj;
}
/*
* (non-Javadoc)
*
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Fst other = (Fst) obj;
if (!Arrays.equals(isyms, other.isyms))
return false;
if (!Arrays.equals(osyms, other.osyms))
return false;
if (start == null) {
if (other.start != null)
return false;
} else if (!start.equals(other.start))
return false;
if (states == null) {
if (other.states != null)
return false;
} else if (!states.equals(other.states))
return false;
if (semiring == null) {
if (other.semiring != null)
return false;
} else if (!semiring.equals(other.semiring))
return false;
return true;
}
@Override
public int hashCode() {
return 31 * (Arrays.hashCode(isyms) +
31 * (Arrays.hashCode(osyms) +
31 * ((start == null ? 0 : start.hashCode()) +
31 * ((states == null ? 0 : states.hashCode()) +
31 * ((semiring == null ? 0 : semiring.hashCode()))))));
}
/*
* (non-Javadoc)
*
* @see java.lang.Object#toString()
*/
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("Fst(start=" + start + ", isyms=" + Arrays.toString(isyms) + ", osyms="
+ Arrays.toString(osyms) + ", semiring=" + semiring + ")\n");
int numStates = states.size();
for (int i = 0; i < numStates; i++) {
State s = states.get(i);
sb.append(" " + s + "\n");
int numArcs = s.getNumArcs();
for (int j = 0; j < numArcs; j++) {
Arc a = s.getArc(j);
sb.append(" " + a + "\n");
}
}
return sb.toString();
}
/**
* Deletes a state
*
* @param state
* the state to delete
*/
public void deleteState(State state) {
if (state == start) {
System.err.println("Cannot delete start state.");
return;
}
states.remove(state);
for (State s1 : states) {
ArrayList<Arc> newArcs = new ArrayList<Arc>();
for (int j = 0; j < s1.getNumArcs(); j++) {
Arc a = s1.getArc(j);
if (!a.getNextState().equals(state)) {
newArcs.add(a);
}
}
s1.setArcs(newArcs);
}
}
/**
* Remaps the states' ids.
*
* States' ids are renumbered starting from 0 up to @see
* {@link edu.cmu.sphinx.fst.Fst#getNumStates()}
*/
public void remapStateIds() {
int numStates = states.size();
for (int i = 0; i < numStates; i++) {
states.get(i).id = i;
}
}
public void deleteStates(HashSet<State> toDelete) {
if (toDelete.contains(start)) {
System.err.println("Cannot delete start state.");
return;
}
ArrayList<State> newStates = new ArrayList<State>();
for (State s1 : states) {
if (!toDelete.contains(s1)) {
newStates.add(s1);
ArrayList<Arc> newArcs = new ArrayList<Arc>();
for (int j = 0; j < s1.getNumArcs(); j++) {
Arc a = s1.getArc(j);
if (!toDelete.contains(a.getNextState())) {
newArcs.add(a);
}
}
s1.setArcs(newArcs);
}
}
states = newStates;
remapStateIds();
}
}