package edu.berkeley.nlp.syntax;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Factory;
/**
 * Computes the closure of a set of scored unary rules over states of type V.
 *
 * Rules are registered with {@link #add(Object, Object, double)} as directed,
 * scored parent-to-child edges. {@link #solve()} then derives, for every pair
 * of states linked by a chain of rules, the chain with the highest combined
 * score together with the sequence of states it passes through. Scores along
 * a chain are combined by addition or by multiplication, depending on the flag
 * given to the constructor.
 *
 * Assumes the type V is hashable, i.e. implements {@code equals} and
 * {@code hashCode} consistently; states are used as keys of hash maps.
 *
 * @author adampauls
 *
 * @param <V>
 *            the state (node) type
 */
public class UnaryClosureComputer<V>
{
    /**
     * A scored unary rule from a parent state to a child state.
     *
     * Equality and hash code depend only on parent and child, never on the
     * score. Edges are used as keys of hash-based collections inside
     * {@link UnaryClosureComputer}, so calling {@link #setParent(Object)} or
     * {@link #setChild(Object)} on an edge that is already stored in such a
     * collection changes its hash code and makes it unfindable there.
     *
     * @param <V>
     *            the state (node) type
     */
    public static class Edge<V>
    {
        private V parent;

        private V child;

        private double score;

        private Edge(V parent, V child)
        {
            this.parent = parent;
            this.child = child;
        }

        public V getParent()
        {
            return parent;
        }

        public void setParent(V parent)
        {
            this.parent = parent;
        }

        public V getChild()
        {
            return child;
        }

        public void setChild(V child)
        {
            this.child = child;
        }

        public double getScore()
        {
            return score;
        }

        public void setScore(double d)
        {
            score = d;
        }

        @Override
        public int hashCode()
        {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((child == null) ? 0 : child.hashCode());
            result = prime * result + ((parent == null) ? 0 : parent.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            final Edge<?> other = (Edge<?>) obj;
            if (child == null)
            {
                if (other.child != null) return false;
            }
            else if (!child.equals(other.child)) return false;
            if (parent == null)
            {
                if (other.parent != null) return false;
            }
            else if (!parent.equals(other.parent)) return false;
            return true;
        }

        /**
         * @return a readable "parent -> child [score]" rendering of this edge
         */
        @Override
        public String toString()
        {
            return parent + " -> " + child + " [" + score + "]";
        }
    }

    /** Closed rules grouped by their child state; filled by {@link #solve()}. */
    Map<V, List<Edge<V>>> closedUnaryRulesByChild = new HashMap<V, List<Edge<V>>>();

    /** Closed rules grouped by their parent state; filled by {@link #solve()}. */
    Map<V, List<Edge<V>>> closedUnaryRulesByParent = new HashMap<V, List<Edge<V>>>();

    /** For each closed rule, the states visited from parent to child; filled by {@link #solve()}. */
    Map<Edge<V>, List<V>> pathMap = new HashMap<Edge<V>, List<V>>();

    /** The rules registered through {@link #add(Object, Object, double)}. */
    Set<Edge<V>> unaryRules = new HashSet<Edge<V>>();

    /** When true, scores along a chain are added; otherwise they are multiplied. */
    private boolean sumInsteadOfMultiply;

    /**
     * @param sumInsteadOfMultiply
     *            true to combine the scores of chained rules by addition,
     *            false to combine them by multiplication
     */
    public UnaryClosureComputer(boolean sumInsteadOfMultiply)
    {
        this.sumInsteadOfMultiply = sumInsteadOfMultiply;
    }

    /**
     * Returns the live map from each child state to the closed rules that end
     * in that child. The map is empty until {@link #solve()} has been called.
     *
     * @return the internal child-indexed map (not a copy)
     */
    public Map<V, List<Edge<V>>> getAllClosedRulesByChildren()
    {
        return closedUnaryRulesByChild;
    }

    /**
     * Looks up the closed rules whose child is the given state, via
     * {@code CollectionUtils.getValueList}.
     *
     * @param child
     *            the child state
     * @return the closed rules ending in {@code child}
     */
    public List<Edge<V>> getClosedUnaryRulesByChild(V child)
    {
        // NOTE(review): the result for an unknown state is whatever CollectionUtils.getValueList yields - confirm.
        return CollectionUtils.getValueList(closedUnaryRulesByChild, child);
    }

    /**
     * Looks up the closed rules whose parent is the given state, via
     * {@code CollectionUtils.getValueList}.
     *
     * @param parent
     *            the parent state
     * @return the closed rules starting at {@code parent}
     */
    public List<Edge<V>> getClosedUnaryRulesByParent(V parent)
    {
        // NOTE(review): the result for an unknown state is whatever CollectionUtils.getValueList yields - confirm.
        return CollectionUtils.getValueList(closedUnaryRulesByParent, parent);
    }

    /**
     * Returns the state sequence recorded for a closed rule.
     *
     * @param unaryRule
     *            the rule to look up (matched by parent and child)
     * @return the states from parent to child, or null if no path was recorded
     *         for the rule
     */
    public List<V> getPath(Edge<?> unaryRule)
    {
        return pathMap.get(unaryRule);
    }

    /**
     * Renders every closed rule followed by its path, one rule per line.
     */
    @Override
    public String toString()
    {
        StringBuilder sb = new StringBuilder();
        for (V parent : closedUnaryRulesByParent.keySet())
        {
            for (Edge<V> unaryRule : getClosedUnaryRulesByParent(parent))
            {
                List<V> path = getPath(unaryRule);
                sb.append(unaryRule);
                sb.append(" ");
                sb.append(path);
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Registers a unary rule. Rules are kept in a set keyed by parent and
     * child, so adding a rule whose parent and child equal those of an
     * earlier rule leaves the earlier rule, and its score, in place.
     *
     * @param parent
     *            the parent state
     * @param child
     *            the child state
     * @param score
     *            the score of the rule
     */
    public void add(V parent, V child, double score)
    {
        final Edge<V> edge = newEdge(parent, child);
        edge.setScore(score);
        unaryRules.add(edge);
    }

    /**
     * Computes the closure of all registered rules and stores the result in
     * the by-child, by-parent and path maps. Results of an earlier call are
     * discarded first, so calling this method repeatedly does not accumulate
     * duplicate entries.
     */
    public void solve()
    {
        closedUnaryRulesByChild.clear();
        closedUnaryRulesByParent.clear();
        pathMap.clear();
        Map<Edge<V>, List<V>> closureMap = computeUnaryClosure(unaryRules);
        for (Map.Entry<Edge<V>, List<V>> entry : closureMap.entrySet())
        {
            addUnary(entry.getKey(), entry.getValue());
        }
    }

    /** Records one closed rule and its path in the result maps. */
    private void addUnary(Edge<V> unaryRule, List<V> path)
    {
        CollectionUtils.addToValueList(closedUnaryRulesByChild, unaryRule.getChild(), unaryRule);
        CollectionUtils.addToValueList(closedUnaryRulesByParent, unaryRule.getParent(), unaryRule);
        pathMap.put(unaryRule, path);
    }

    /**
     * Builds the closure of the given rules.
     *
     * The direct rules are relaxed first. Then, for each state in turn, every
     * rule ending in that state is chained with every rule starting from it,
     * and the chained rule is kept if it scores at least as well as what is
     * known so far. Finally a self-loop with score 0.0 is relaxed for every
     * state.
     *
     * The score of every resulting edge, including the edges passed in, is
     * overwritten with its closed score.
     *
     * @param rules
     *            the direct rules
     * @return a map from each closed rule to the states on its best path
     */
    private Map<Edge<V>, List<V>> computeUnaryClosure(Collection<Edge<V>> rules)
    {
        Map<Edge<V>, V> intermediateStates = new HashMap<Edge<V>, V>();
        Counter<Edge<V>> pathCosts = new Counter<Edge<V>>();
        Map<V, List<Edge<V>>> byChild = new HashMap<V, List<Edge<V>>>();
        Map<V, List<Edge<V>>> byParent = new HashMap<V, List<Edge<V>>>();
        Set<V> states = new HashSet<V>();

        for (Edge<V> rule : rules)
        {
            relax(pathCosts, intermediateStates, byChild, byParent, rule, null, rule.getScore());
            states.add(rule.getParent());
            states.add(rule.getChild());
        }

        for (V intermediateState : states)
        {
            List<Edge<V>> incomingRules = byChild.get(intermediateState);
            List<Edge<V>> outgoingRules = byParent.get(intermediateState);
            if (incomingRules == null || outgoingRules == null) continue;
            // relax() rejects chained rules whose parent or child equals the intermediate state,
            // so the two lists being iterated here are never appended to during the loops.
            for (Edge<V> incomingRule : incomingRules)
            {
                for (Edge<V> outgoingRule : outgoingRules)
                {
                    Edge<V> chained = newEdge(incomingRule.getParent(), outgoingRule.getChild());
                    double newScore = combinePathCosts(pathCosts, incomingRule, outgoingRule);
                    relax(pathCosts, intermediateStates, byChild, byParent, chained, intermediateState, newScore);
                }
            }
        }

        for (V state : states)
        {
            // NOTE(review): 0.0 is the identity for sum mode; confirm it is also intended for multiply mode.
            Edge<V> selfLoopRule = newEdge(state, state);
            relax(pathCosts, intermediateStates, byChild, byParent, selfLoopRule, null, 0.0);
        }

        Map<Edge<V>, List<V>> closureMap = new HashMap<Edge<V>, List<V>>();
        for (Edge<V> closedRule : pathCosts.keySet())
        {
            closedRule.setScore(pathCosts.getCount(closedRule));
            closureMap.put(closedRule, extractPath(closedRule, intermediateStates));
        }
        return closureMap;
    }

    /**
     * Combines the current scores of two rules that are chained through a
     * shared state.
     *
     * @param pathCosts
     *            the best score known so far for each rule
     * @param incomingRule
     *            the rule ending in the shared state
     * @param outgoingRule
     *            the rule starting from the shared state
     * @return the sum of both scores in sum mode, otherwise their product
     */
    private double combinePathCosts(Counter<Edge<V>> pathCosts, Edge<V> incomingRule, Edge<V> outgoingRule)
    {
        final double incoming = pathCosts.getCount(incomingRule);
        final double outgoing = pathCosts.getCount(outgoingRule);
        return this.sumInsteadOfMultiply ? (incoming + outgoing) : (incoming * outgoing);
    }

    /**
     * Reconstructs the state sequence of a closed rule by recursively
     * expanding the recorded intermediate states.
     *
     * @param unaryRule
     *            the rule whose path is wanted
     * @param intermediateStates
     *            for each rule, the state it was chained through, or null
     * @return the states from parent to child inclusive; a self-loop without
     *         an intermediate state yields a single-element path
     */
    private List<V> extractPath(Edge<V> unaryRule, Map<Edge<V>, V> intermediateStates)
    {
        List<V> path = new ArrayList<V>();
        path.add(unaryRule.getParent());
        V intermediateState = intermediateStates.get(unaryRule);
        if (intermediateState != null)
        {
            // Interior of parent -> intermediate; its end points are added separately.
            List<V> parentPath = extractPath(newEdge(unaryRule.getParent(), intermediateState), intermediateStates);
            for (int i = 1; i < parentPath.size() - 1; i++)
            {
                path.add(parentPath.get(i));
            }
            path.add(intermediateState);
            // Interior of intermediate -> child.
            List<V> childPath = extractPath(newEdge(intermediateState, unaryRule.getChild()), intermediateStates);
            for (int i = 1; i < childPath.size() - 1; i++)
            {
                path.add(childPath.get(i));
            }
        }
        // States are compared with equals(), consistent with Edge.equals() and relax().
        if (path.size() == 1 && sameState(unaryRule.getParent(), unaryRule.getChild())) return path;
        path.add(unaryRule.getChild());
        return path;
    }

    /**
     * Offers a candidate score for a rule and stores it unless the rule
     * already has a strictly higher score. A candidate that ties the stored
     * score replaces it.
     *
     * Candidates whose intermediate state equals their own parent or child
     * are ignored.
     *
     * @param pathCosts
     *            the best score known so far for each rule
     * @param intermediateStates
     *            for each rule, the state it was chained through, or null
     * @param closedUnaryRulesByChild
     *            working index of rules by child, extended for new rules
     * @param closedUnaryRulesByParent
     *            working index of rules by parent, extended for new rules
     * @param unaryRule
     *            the candidate rule
     * @param intermediateState
     *            the state the candidate was chained through, or null for a
     *            direct rule
     * @param newScore
     *            the candidate score
     */
    private void relax(Counter<Edge<V>> pathCosts, Map<Edge<V>, V> intermediateStates, Map<V, List<Edge<V>>> closedUnaryRulesByChild,
        Map<V, List<Edge<V>>> closedUnaryRulesByParent, Edge<V> unaryRule, V intermediateState, double newScore)
    {
        if (intermediateState != null && (intermediateState.equals(unaryRule.getParent()) || intermediateState.equals(unaryRule.getChild()))) return;
        boolean isNewRule = !pathCosts.containsKey(unaryRule);
        double oldScore = (isNewRule ? Double.NEGATIVE_INFINITY : pathCosts.getCount(unaryRule));
        if (oldScore > newScore) return;
        if (isNewRule)
        {
            CollectionUtils.addToValueList(closedUnaryRulesByChild, unaryRule.getChild(), unaryRule);
            CollectionUtils.addToValueList(closedUnaryRulesByParent, unaryRule.getParent(), unaryRule);
        }
        pathCosts.setCount(unaryRule, newScore);
        intermediateStates.put(unaryRule, intermediateState);
    }

    /**
     * Returns the closed score between two states. Only meaningful after
     * {@link #solve()} has been called.
     *
     * @param parent
     *            the parent state
     * @param child
     *            the child state
     * @return 0.0 if parent and child are the same state (by {@code equals});
     *         {@link Double#POSITIVE_INFINITY} if no closed rule connects
     *         them; otherwise the score of the closed rule
     */
    public double getProb(V parent, V child)
    {
        // NOTE(review): a missing rule yields POSITIVE_INFINITY although relax() keeps the higher score - confirm.
        if (sameState(parent, child)) return 0.0;
        final List<Edge<V>> byParent = closedUnaryRulesByParent.get(parent);
        if (byParent == null) return Double.POSITIVE_INFINITY;
        int childIndex = byParent.indexOf(newEdge(parent, child));
        if (childIndex < 0) return Double.POSITIVE_INFINITY;
        final Edge<V> unaryRule = byParent.get(childIndex);
        return unaryRule.getScore();
    }

    /** Creates an edge with a default score of 0.0, used both for new rules and as a lookup key. */
    private Edge<V> newEdge(V parent, V child)
    {
        return new Edge<V>(parent, child);
    }

    /** Null-safe equality of two states. */
    private static boolean sameState(Object a, Object b)
    {
        return a == b || (a != null && a.equals(b));
    }
}