package edu.isi.karma.modeling.research.graph.roek.nlpged.algorithm; import java.util.List; import edu.isi.karma.modeling.research.graph.konstantinosnedas.HungarianAlgorithm; import edu.isi.karma.modeling.research.graph.roek.nlpged.graph.Edge; import edu.isi.karma.modeling.research.graph.roek.nlpged.graph.Graph; import edu.isi.karma.modeling.research.graph.roek.nlpged.graph.Node; public class GraphEditDistance { private double[][] costMatrix; protected final double SUBSTITUTE_COST; protected final double INSERT_COST; protected final double DELETE_COST; private Graph g1, g2; // private Map<String, Double> posEditWeights, deprelEditWeights; public GraphEditDistance(Graph g1, Graph g2, double subCost, double insCost, double delCost) { this.SUBSTITUTE_COST = subCost; this.INSERT_COST = insCost; this.DELETE_COST = delCost; this.g1 = g1; this.g2 = g2; // this.posEditWeights = posEditWeights; // this.deprelEditWeights = deprelEditWeights; this.costMatrix = createCostMatrix(); } public static void main(String[] args) { Graph g1 = new Graph(); Graph g2 = new Graph(); Node m1 = new Node("", "1"); Node m2 = new Node("", "2"); Edge e1 = new Edge("", m1, m2, "e1"); Node n1 = new Node("", "1"); Node n2 = new Node("", "2"); Edge l1 = new Edge("", n1, n2, "e2"); g1.addNode(m1); g1.addNode(m2); g1.addEdge(e1); g2.addNode(n1); g2.addNode(n2); g2.addEdge(l1); System.out.println(new GraphEditDistance(g1, g2).getDistance()); System.out.println(new GraphEditDistance(g1, g2).getNormalizedDistance()); } public GraphEditDistance(Graph g1, Graph g2) { this(g1, g2, 1, 1, 1); } public double getNormalizedDistance() { /** * Retrieves the approximated graph edit distance between the two graphs g1 & g2. * The distance is normalized on graph length */ double graphLength = (g1.getSize()+g2.getSize())/2; return getDistance() / graphLength; } public double getDistance() { /** * Retrieves the approximated graph edit distance between the two graphs g1 & g2. */ int[][] assignment = HungarianAlgorithm.hgAlgorithm(this.costMatrix, "min"); double sum = 0; for (int i=0; i<assignment.length; i++){ sum = (sum + costMatrix[assignment[i][0]][assignment[i][1]]); } return sum; } public double[][] getCostMatrix() { if(costMatrix==null) { this.costMatrix = createCostMatrix(); } return costMatrix; } public double[][] createCostMatrix() { /** * Creates the cost matrix used as input to Munkres algorithm. * The matrix consists of 4 sectors: upper left, upper right, bottom left, bottom right. * Upper left represents the cost of all N x M node substitutions. * Upper right node deletions * Bottom left node insertions. * Bottom right represents delete -> delete operations which should have any cost, and is filled with zeros. */ int n = g1.getNodes().size(); int m = g2.getNodes().size(); double[][] costMatrix = new double[n+m][n+m]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { costMatrix[i][j] = getSubstituteCost(g1.getNode(i), g2.getNode(j)); } } for (int i = 0; i < m; i++) { for (int j = 0; j < m; j++) { costMatrix[i+n][j] = getInsertCost(i, j); } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { costMatrix[j][i+m] = getDeleteCost(i, j); } } return costMatrix; } private double getInsertCost(int i, int j) { if(i == j) { return getPosWeight(g2.getNode(j)) * INSERT_COST; } return Double.MAX_VALUE; } private double getDeleteCost(int i, int j) { if(i == j) { return getPosWeight(g1.getNode(i)) * DELETE_COST; } return Double.MAX_VALUE; } public double getSubstituteCost(Node node1, Node node2) { double diff = (getRelabelCost(node1, node2) + getEdgeDiff(node1, node2)) / 2; return diff * SUBSTITUTE_COST; } public double getRelabelCost(Node node1, Node node2) { double diff = 0; if(!node1.equals(node2)) { diff = getPosWeight(node1, node2); } return diff ; } public double getPosWeight(Node node) { // return getPosWeight(node.getAttributes().get(0)); return 1.0; } public double getPosWeight(Node node1, Node node2) { // return getPosWeight(node1.getAttributes().get(0)+","+node2.getAttributes().get(0)); return 1.0; } public double getPosWeight(String key) { // Double posWeight = posEditWeights.get(key); // if(posWeight == null) { // return 1; // } // return posWeight; return 1.0; } public double getEdgeDiff(Node node1, Node node2) { List<Edge> edges1 = g1.getEdges(node1); List<Edge> edges2 = g2.getEdges(node2); if(edges1.size() == 0 || edges2.size() == 0) { return getWeightSum(edges1) + getWeightSum(edges2); } int n = edges1.size(); int m = edges2.size(); double[][] edgeCostMatrix = new double[n+m][m+n]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { edgeCostMatrix[i][j] = getEdgeEditCost(edges1.get(i), edges2.get(j)); } } for (int i = 0; i < m; i++) { for (int j = 0; j < m; j++) { edgeCostMatrix[i+n][j] = getEdgeInsertCost(i, j, edges2.get(j)); } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { edgeCostMatrix[j][i+m] = getEdgeDeleteCost(i, j, edges1.get(i)); } } int[][] assignment = HungarianAlgorithm.hgAlgorithm(edgeCostMatrix, "min"); double sum = 0; for (int i=0; i<assignment.length; i++){ sum += edgeCostMatrix[assignment[i][0]][assignment[i][1]]; } return sum / ((n+m)); } public double getDeprelWeight(Edge edge) { // Double weight = deprelEditWeights.get(edge.getLabel()); // if(weight == null) { // return 1; // } // return weight; return 1.0; } public double getWeightSum(List<Edge> edges) { double sum = 0; for (Edge edge : edges) { sum += getDeprelWeight(edge); } return sum; } public double getEdgeInsertCost(int i, int j, Edge edge2) { if(i==j) { return getDeprelWeight(edge2) * INSERT_COST; } return Double.MAX_VALUE; } public double getEdgeDeleteCost(int i, int j, Edge edge1) { if(i==j) { return getDeprelWeight(edge1) * DELETE_COST; } return Double.MAX_VALUE; } public double getEdgeEditCost(Edge edge1, Edge edge2) { if(edge1.equals(edge2) && edge1.getFrom().getLabel().equals(edge2.getFrom().getLabel()) && edge1.getTo().getLabel().equals(edge2.getTo().getLabel())) { return 0; } return 1; } public void printMatrix() { System.out.println("-------------"); System.out.println("Cost matrix: "); for (int i = 0; i < costMatrix.length; i++) { for (int j = 0; j < costMatrix.length; j++) { if(costMatrix[i][j] == Double.MAX_VALUE) { System.out.print("inf\t"); }else{ System.out.print(String.format("%.2f", costMatrix[i][j])+"\t"); } } System.out.println(); } } }