package edu.stanford.nlp.semparse.open.util;
import java.util.*;
import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
/**
 * Computes the size of a maximum matching in a bipartite graph.
 *
 * <p>Left-side ("from") and right-side ("to") objects are mapped to integer node ids and the
 * problem is solved as a unit-capacity max-flow problem (Ford-Fulkerson): a source node has an
 * edge to every left node, every right node has an edge to a sink node, and each
 * {@link #addEdge} call adds a left-to-right edge. Every augmenting path found in the residual
 * graph increases the matching by one.
 *
 * <p>Left and right objects are kept in separate maps, so the same object may appear on both
 * sides. Objects are looked up with {@code equals}/{@code hashCode}. Instances use unsynchronized
 * {@link HashMap}s and are not thread-safe.
 */
public class BipartiteMatcher {
  /** Node id of the source; left nodes get ids 2, 3, 4, ... */
  private static final int SOURCE = 1;
  /** Node id of the sink; right nodes get ids -2, -3, -4, ... */
  private static final int SINK = -1;

  /** Left-side object -> positive node id. */
  private final Map<Object, Integer> fromMap;
  /** Right-side object -> negative node id. */
  private final Map<Object, Integer> toMap;
  /** Adjacency lists of the flow network (SOURCE -> left -> right -> SINK). */
  private final Map<Integer, List<Integer>> edges;

  /** Creates a matcher with an empty graph. */
  public BipartiteMatcher() {
    this.fromMap = new HashMap<>();
    this.toMap = new HashMap<>();
    this.edges = new HashMap<>();
  }

  /**
   * Creates a matcher whose left side is the indices of {@code targetEntities} and whose right
   * side is the indices of {@code predictedEntities}. An edge (i, j) is added whenever
   * {@code targetEntities.get(i).match(predictedEntities.get(j))} returns true.
   */
  public BipartiteMatcher(List<TargetEntity> targetEntities, List<String> predictedEntities) {
    this();
    for (int i = 0; i < targetEntities.size(); i++) {
      TargetEntity targetEntity = targetEntities.get(i);
      for (int j = 0; j < predictedEntities.size(); j++) {
        if (targetEntity.match(predictedEntities.get(j))) {
          this.addEdge(i, j);
        }
      }
    }
  }

  /**
   * Adds an edge between left object {@code fromObj} and right object {@code toObj}, registering
   * either object (and its source/sink edge) the first time it is seen.
   */
  public void addEdge(Object fromObj, Object toObj) {
    Integer from = fromMap.get(fromObj);
    if (from == null) {
      from = 2 + fromMap.size();
      fromMap.put(fromObj, from);
      successors(edges, SOURCE).add(from);
    }
    Integer to = toMap.get(toObj);
    if (to == null) {
      to = -2 - toMap.size();
      toMap.put(toObj, to);
      successors(edges, to).add(SINK);
    }
    successors(edges, from).add(to);
  }

  /**
   * Returns the number of edges in a maximum matching of the graph built so far.
   *
   * <p>The computation works on a private copy of the graph, so the matcher is not modified:
   * calling this method repeatedly returns the same value, and edges may still be added
   * afterwards. An empty graph yields 0.
   */
  public int findMaximumMatch() {
    // Residual graph: starts as a copy of the network; augmentations reverse edges in it.
    Map<Integer, List<Integer>> residual = new HashMap<>();
    for (Map.Entry<Integer, List<Integer>> entry : edges.entrySet()) {
      residual.put(entry.getKey(), new ArrayList<>(entry.getValue()));
    }
    int count = 0;
    List<Integer> path;
    while (!(path = findAugmentingPath(residual)).isEmpty()) {
      count++;
      // Push one unit of flow along the path: remove each forward edge, add its reverse.
      for (int i = 0; i + 1 < path.size(); i++) {
        int u = path.get(i);
        int v = path.get(i + 1);
        residual.get(u).remove(Integer.valueOf(v)); // remove by value, not by index
        successors(residual, v).add(u);
      }
    }
    return count;
  }

  /**
   * Returns a SOURCE-to-SINK path (as a node list starting with SOURCE) in {@code residual}, or an
   * empty list if SINK is unreachable. Iterative search, so long paths cannot overflow the stack.
   */
  private static List<Integer> findAugmentingPath(Map<Integer, List<Integer>> residual) {
    // Each discovered node -> the node it was reached from; doubles as the visited set.
    Map<Integer, Integer> parent = new HashMap<>();
    Deque<Integer> stack = new ArrayDeque<>();
    parent.put(SOURCE, SOURCE);
    stack.push(SOURCE);
    while (!stack.isEmpty()) {
      int node = stack.pop();
      if (node == SINK) {
        List<Integer> path = new ArrayList<>();
        for (int cur = SINK; cur != SOURCE; cur = parent.get(cur)) {
          path.add(cur);
        }
        path.add(SOURCE);
        Collections.reverse(path);
        return path;
      }
      // Nodes without an adjacency entry (e.g. SOURCE in an empty graph) have no successors.
      for (int next : residual.getOrDefault(node, Collections.emptyList())) {
        if (!parent.containsKey(next)) {
          parent.put(next, node);
          stack.push(next);
        }
      }
    }
    return Collections.emptyList();
  }

  /** Returns the (mutable) successor list of {@code node} in {@code graph}, creating it if absent. */
  private static List<Integer> successors(Map<Integer, List<Integer>> graph, int node) {
    return graph.computeIfAbsent(node, k -> new ArrayList<>());
  }

  /** Small manual check; the graph below has a perfect matching, so it prints 5. */
  public static void main(String[] args) {
    BipartiteMatcher bm = new BipartiteMatcher();
    bm.addEdge("A", 1); bm.addEdge("A", 2); bm.addEdge("A", 4);
    bm.addEdge("B", 1); bm.addEdge("C", 2); bm.addEdge("C", 1);
    bm.addEdge("D", 4); bm.addEdge("D", 5); bm.addEdge("E", 3);
    System.out.println(bm.findMaximumMatch());
  }
}