/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.lattice;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import joshua.corpus.vocab.SymbolTable;
/**
 * A lattice representation of a directed graph whose arcs carry
 * labels and costs.
 * <p>
 * The implementation assumes that node ids are {@code 0 .. size()-1}
 * and that each node's id equals its position in the (topologically
 * ordered) node list.
 *
 * @author Lane Schwartz
 * @since 2008-07-08
 * @version $LastChangedDate: 2009-08-28 11:02:40 -0500 (Fri, 28 Aug 2009) $
 *
 * @param <Value> Type of label associated with an arc.
 */
public class Lattice<Value> implements Iterable<Node<Value>> {

	/**
	 * True if there is more than one path through the lattice.
	 */
	private final boolean latticeHasAmbiguity;

	/**
	 * Costs of the best path between each pair of nodes in the
	 * lattice, indexed as {@code costs[from][to]} by node id.
	 */
	private final double[][] costs;

	/**
	 * List of all nodes in the lattice. Nodes are assumed to
	 * be in topological order.
	 */
	private final List<Node<Value>> nodes;

	/** Logger for this class. */
	private static final Logger logger =
		Logger.getLogger(Lattice.class.getName());

	/**
	 * Regular expression for an arc weight: an optionally negative
	 * integer or decimal number with an optional exponent, for example
	 * {@code 1}, {@code -0.5}, {@code .25} or {@code 1e-3}. Every string
	 * it accepts is valid input for {@link Double#parseDouble(String)}.
	 */
	private static final String WEIGHT_REGEX =
		"-?(?:\\d+(?:\\.\\d*)?|\\.\\d+)(?:[eE][-+]?\\d+)?";

	/**
	 * Matches one node of the string lattice format.
	 * Group 2 is the node's arc list (each arc followed by a comma);
	 * group 3 is the rest of the input after this node.
	 */
	private static final Pattern NODE_PATTERN =
		Pattern.compile("(.+?)\\((\\(.+?\\),)\\)(.*)");

	/**
	 * Matches the first arc of an arc list. Groups: 1 = label,
	 * 2 = weight, 3 = destination offset relative to the current node,
	 * 4 = the remaining arcs.
	 */
	private static final Pattern ARC_PATTERN =
		Pattern.compile("\\('(.+?)',(" + WEIGHT_REGEX + "),(\\d+)\\),(.*)");

	/**
	 * Converts the textual label of a parsed arc into the label type
	 * of the lattice being built.
	 */
	private interface LabelConverter<V> {
		V convert(String label);
	}

	/**
	 * Constructs a new lattice from an existing list of
	 * (connected) nodes.
	 * <p>
	 * The list of nodes must already be in topological order.
	 * If the list is not in topological order, the behavior
	 * of the lattice is not defined.
	 * <p>
	 * Because ambiguity is not known here, the lattice is
	 * conservatively reported as having more than one path.
	 *
	 * @param nodes A list of nodes which must be in topological
	 *              order.
	 */
	public Lattice(List<Node<Value>> nodes) {
		this(nodes, true);
	}

	/**
	 * Constructs a new lattice from an existing list of
	 * (connected) nodes.
	 * <p>
	 * The list of nodes must already be in topological order.
	 * If the list is not in topological order, the behavior
	 * of the lattice is not defined.
	 *
	 * @param nodes       A list of nodes which must be in topological
	 *                    order. The list is stored, not copied.
	 * @param isAmbiguous Value to report from {@link #hasMoreThanOnePath()}.
	 */
	public Lattice(List<Node<Value>> nodes, boolean isAmbiguous) {
		this.nodes = nodes;
		this.costs = calculateAllPairsShortestPath(nodes);
		this.latticeHasAmbiguity = isAmbiguous;
	}

	/**
	 * Constructs a linear-chain lattice. Node {@code i} is connected
	 * to node {@code i+1} by a single arc of cost 0 labeled with
	 * {@code linearChain[i]}, giving {@code linearChain.length + 1}
	 * nodes and exactly one path.
	 *
	 * @param linearChain Labels of the consecutive arcs.
	 */
	public Lattice(Value[] linearChain) {
		this.latticeHasAmbiguity = false;
		this.nodes = new ArrayList<Node<Value>>();

		Node<Value> previous = new Node<Value>(0);
		nodes.add(previous);

		int i = 1;
		for (Value value : linearChain) {
			Node<Value> current = new Node<Value>(i);
			previous.addArc(current, 0.0f, value);
			nodes.add(current);
			previous = current;
			i++;
		}

		this.costs = calculateAllPairsShortestPath(nodes);
	}

	/**
	 * Tells whether this lattice has more than one path.
	 * <p>
	 * The value is the flag given at construction time. It is always
	 * {@code true} for {@link #Lattice(List)} and always {@code false}
	 * for {@link #Lattice(Object[])}.
	 *
	 * @return true if the lattice is considered ambiguous.
	 */
	public final boolean hasMoreThanOnePath() {
		return latticeHasAmbiguity;
	}

	/**
	 * Convenience method to get a lattice from an int[].
	 *
	 * This method is useful because Java's generics won't allow
	 * a primitive array to be passed as a generic array.
	 *
	 * @param linearChain Arc labels of a linear chain.
	 * @return Lattice representation of the linear chain.
	 */
	public static Lattice<Integer> createLattice(int[] linearChain) {
		Integer[] integerSentence = new Integer[linearChain.length];
		for (int i = 0; i < linearChain.length; i++) {
			integerSentence[i] = linearChain[i];
		}
		return new Lattice<Integer>(integerSentence);
	}

	/**
	 * Constructs a lattice from a string representation, mapping each
	 * arc label to its integer id via {@code symbolTable}.
	 * <p>
	 * The expected format is a parenthesized list of nodes, each node a
	 * parenthesized list of {@code ('label',weight,offset),} arcs, where
	 * {@code offset} is the distance from the current node to the
	 * destination node. Example:
	 * {@code ((('a',0.5,1),('b',0.5,2),),(('c',1.0,1),),)}
	 *
	 * @param data        String representation of a lattice.
	 * @param symbolTable Symbol table used to map arc labels to ids.
	 * @return A lattice that corresponds to the given string. It reports
	 *         more than one path iff some node has more than one arc.
	 * @throws IllegalArgumentException if the arcs of a node cannot be
	 *         parsed completely.
	 */
	public static Lattice<Integer> createFromString(String data, final SymbolTable symbolTable) {
		return parse(data, new LabelConverter<Integer>() {
			public Integer convert(String label) {
				return symbolTable.getID(label);
			}
		});
	}

	/**
	 * Constructs a lattice from a given string representation.
	 * <p>
	 * The format is the same as for
	 * {@link #createFromString(String, SymbolTable)}; arc labels are
	 * kept as strings.
	 *
	 * @param data String representation of a lattice.
	 * @return A lattice that corresponds to the given string. It reports
	 *         more than one path iff some node has more than one arc.
	 * @throws IllegalArgumentException if the arcs of a node cannot be
	 *         parsed completely.
	 */
	public static Lattice<String> createFromString(String data) {
		return parse(data, new LabelConverter<String>() {
			public String convert(String label) {
				return label;
			}
		});
	}

	/**
	 * Shared parser behind both {@code createFromString} overloads.
	 *
	 * @param data      String representation of a lattice.
	 * @param converter Maps each textual arc label to the lattice label type.
	 * @return The parsed lattice, with nodes sorted by id.
	 */
	private static <V> Lattice<V> parse(String data, LabelConverter<V> converter) {
		Map<Integer, Node<V>> nodes = new HashMap<Integer, Node<V>>();
		boolean latticeIsAmbiguous = false;
		int nodeID = -1;

		Matcher nodeMatcher = NODE_PATTERN.matcher(data);
		while (nodeMatcher.matches()) {
			String remainingArcs = nodeMatcher.group(2);
			String remainingData = nodeMatcher.group(3);

			nodeID++;
			Node<V> currentNode = getOrCreateNode(nodes, nodeID);

			if (logger.isLoggable(Level.FINE)) logger.fine("Node " + nodeID + ":");

			int numArcs = 0;
			Matcher arcMatcher = ARC_PATTERN.matcher(remainingArcs);
			while (arcMatcher.matches()) {
				numArcs++;
				String arcLabel = arcMatcher.group(1);
				double arcWeight = Double.parseDouble(arcMatcher.group(2));
				// Destinations are encoded as offsets from the current node.
				int destinationNodeID = nodeID + Integer.parseInt(arcMatcher.group(3));
				Node<V> destinationNode = getOrCreateNode(nodes, destinationNodeID);

				remainingArcs = arcMatcher.group(4);

				if (logger.isLoggable(Level.FINE)) logger.fine("\t" + arcLabel + " " + arcWeight + " " + destinationNodeID);

				currentNode.addArc(destinationNode, arcWeight, converter.convert(arcLabel));

				arcMatcher = ARC_PATTERN.matcher(remainingArcs);
			}

			// Every arc of the node must be consumed; otherwise the input
			// is malformed and arcs would be silently lost.
			if (remainingArcs.length() > 0) {
				throw new IllegalArgumentException("Parse error! Unable to parse arcs of node "
						+ nodeID + ": " + remainingArcs);
			}

			if (numArcs > 1) latticeIsAmbiguous = true;

			nodeMatcher = NODE_PATTERN.matcher(remainingData);
		}

		List<Node<V>> nodeList = new ArrayList<Node<V>>(nodes.values());
		Collections.sort(nodeList, new NodeIdentifierComparator());

		if (logger.isLoggable(Level.FINE)) logger.fine(nodeList.toString());

		return new Lattice<V>(nodeList, latticeIsAmbiguous);
	}

	/**
	 * Returns the node with the given id from {@code nodes}, creating
	 * and registering a new node if none exists yet.
	 */
	private static <V> Node<V> getOrCreateNode(Map<Integer, Node<V>> nodes, int id) {
		Node<V> node = nodes.get(id);
		if (node == null) {
			node = new Node<V>(id);
			nodes.put(id, node);
		}
		return node;
	}

	/**
	 * Gets the cost of the shortest path between two nodes.
	 * <p>
	 * Arc costs are currently not used: every arc counts as 1.0, so
	 * this is the minimal number of arcs between the two nodes.
	 * Unreachable pairs, including a node paired with itself (absent a
	 * self-loop), have cost {@link Double#POSITIVE_INFINITY}.
	 *
	 * @param from ID of the starting node.
	 * @param to ID of the ending node.
	 * @return The cost of the shortest path between the two
	 *         nodes.
	 */
	public double getShortestPath(int from, int to) {
		return costs[from][to];
	}

	/**
	 * Gets the node with a specified integer identifier.
	 *
	 * @param index Integer identifier for a node.
	 * @return The node with the specified integer identifier
	 * @throws IndexOutOfBoundsException if no such node exists.
	 */
	public Node<Value> getNode(int index) {
		return nodes.get(index);
	}

	/**
	 * Returns an iterator over the nodes in this lattice.
	 *
	 * @return An iterator over the nodes in this lattice.
	 */
	public Iterator<Node<Value>> iterator() {
		return nodes.iterator();
	}

	/**
	 * Returns the number of nodes in this lattice.
	 *
	 * @return The number of nodes in this lattice.
	 */
	public int size() {
		return nodes.size();
	}

	/**
	 * Calculate the all-pairs shortest path for all pairs of
	 * nodes.
	 * <p>
	 * Note: This method assumes no backward arcs. If there are
	 * backward arcs, the returned shortest path costs for that
	 * node may not be accurate. It also assumes node ids equal
	 * their positions in {@code nodes}.
	 *
	 * @param nodes A list of nodes which must be in topological
	 *              order.
	 * @return The all-pairs shortest path for all pairs of nodes,
	 *         indexed as {@code [from][to]}.
	 */
	private static <V> double[][] calculateAllPairsShortestPath(List<Node<V>> nodes) {
		int size = nodes.size();
		double[][] costs = new double[size][size];

		// Initialize pairwise costs to be infinite for
		// each pair of nodes
		for (int from = 0; from < size; from++) {
			for (int to = 0; to < size; to++) {
				costs[from][to] = Double.POSITIVE_INFINITY;
			}
		}

		// Record the cost of every direct arc. This is slightly different
		// than it was defined in Dyer et al 2008: minimally, arc.cost should
		// be weighted by the feature weight assigned, so for now every arc
		// deliberately counts as 1.0 and arc.cost is ignored.
		for (Node<V> head : nodes) {
			for (Arc<V> arc : head.outgoingArcs) {
				int from = head.id;
				int to = arc.tail.id;
				double cost = 1.0;
				costs[from][to] = Math.min(costs[from][to], cost);
			}
		}

		// Relax paths through intermediate nodes. Since arcs only go forward,
		// the best path from i to j ends with an arc k -> j for some
		// i < k < j. For a fixed i, j increases, so costs[i][k] (k < j) is
		// already final, while row k (k > i) has not been relaxed yet and
		// still holds just the direct arc cost k -> j.
		for (int i = 0; i < size - 2; i++) {
			for (int j = i + 2; j < size; j++) {
				for (int k = i + 1; k < j; k++) {
					costs[i][j] = Math.min(costs[i][j], costs[i][k] + costs[k][j]);
				}
			}
		}

		return costs;
	}

	@Override
	public String toString() {
		StringBuilder s = new StringBuilder();

		for (Node<Value> start : this) {
			for (Arc<Value> arc : start.getOutgoingArcs()) {
				s.append(arc.toString());
				s.append('\n');
			}
		}

		return s.toString();
	}

	/**
	 * Small demonstration: builds a four-node lattice and prints the
	 * shortest-path cost from node 0 to node 3.
	 */
	public static void main(String[] args) {

		List<Node<String>> nodes = new ArrayList<Node<String>>();
		for (int i = 0; i < 4; i++) {
			nodes.add(new Node<String>(i));
		}

		nodes.get(0).addArc(nodes.get(1), 1.0, "x");
		nodes.get(1).addArc(nodes.get(2), 1.0, "y");
		nodes.get(0).addArc(nodes.get(2), 1.5, "a");
		nodes.get(2).addArc(nodes.get(3), 3.0, "b");
		nodes.get(2).addArc(nodes.get(3), 5.0, "c");

		Lattice<String> graph = new Lattice<String>(nodes);

		System.out.println("Shortest path from 0 to 3: " + graph.getShortestPath(0, 3));
	}
}