/*
 * File:                EdgeMergingEnergyFunction.java
 * Authors:             Jeremy D. Wendt
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright 2016, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Our implementation of belief propagation requires that there be at most one
 * edge between any pair of nodes. This class wraps any graph (with any number
 * of edges between pairs of nodes) and merges any edges between any two nodes
 * into one BP-visible edge while preserving the correct results. This wrapper
 * makes the edges undirected (changing the direction stored in this class from
 * that stored in the parent class), so don't use with
 * SumProductPairwiseBayesNet.
 *
 * @author jdwendt
 * @param <LabelType> The type of the labels that can be assigned to nodes
 * @param <NodeNameType> The type used to identify nodes by name
 */
public class EdgeMergingEnergyFunction<LabelType, NodeNameType>
    implements NodeNameAwareEnergyFunction<LabelType, NodeNameType>
{

    /**
     * A helper class for storing the (possibly) multiple parallel edges
     * between two nodes in the input graph. Two instances are equal whenever
     * they connect the same pair of node indices, regardless of the direction
     * the edge was originally specified in.
     *
     * @author jdwendt
     *
     */
    private static class UniqueEdge
    {

        /**
         * The edge's lower vertex index (note: This will change the direction
         * of an edge to store the lower index here)
         */
        private final int idxi;

        /**
         * The edge's higher vertex index (note: This will change the direction
         * of an edge to store the higher index here)
         */
        private final int idxj;

        /**
         * Initializes this edge with appropriate values
         *
         * @param src The edge's source vertex index
         * @param dst The edge's destination vertex index
         */
        public UniqueEdge(int src, int dst)
        {
            // Note that as BP runs messages in both directions across an edge,
            // this can ignore node order on the edge
            if (src < dst)
            {
                this.idxi = src;
                this.idxj = dst;
            }
            else
            {
                this.idxi = dst;
                this.idxj = src;
            }
        }

        /**
         * This method returns true if the input src/dst indices were flipped
         * for storage in this class. This is necessary information for belief
         * propagation as flipping the indices requires flipping the potential
         * matrix values. This class can't store that internally as it would
         * affect the equals and hash methods and those need to find edges equal
         * if they have the same indices regardless of whether they were
         * flipped.
         *
         * A self-loop (src equal to dst) is never considered flipped, because
         * its stored order is identical to its original order.
         *
         * NOTE: This method will return false if you hand it a src and dst that
         * don't match the internal indices at all. While it may be more robust
         * to throw an exception, this is a private, internal class that should
         * only be used herein in specific ways.
         *
         * @param src The original source index for the edge
         * @param dst The original destination index for the edge
         * @return true if the edge indices were flipped internal to this class,
         * else false.
         */
        public boolean wasFlipped(int src, int dst)
        {
            // Without the src != dst test a self-loop would match both index
            // comparisons and be reported as flipped, which would swap the
            // labels handed to the wrapped pairwise potential.
            return (src != dst) && (src == this.idxj) && (dst == this.idxi);
        }

        @Override
        public boolean equals(Object o)
        {
            if (!(o instanceof UniqueEdge))
            {
                return false;
            }
            UniqueEdge e = (UniqueEdge) o;
            return (idxi == e.idxi) && (idxj == e.idxj);
        }

        @Override
        public int hashCode()
        {
            // Kept numerically identical to the historical formula: the hash
            // determines HashMap iteration order, which in turn determines the
            // externally visible edge numbering.
            int ret = 7;
            ret += ret * 5 + idxi;
            ret += ret * 5 + idxj;
            return ret;
        }

    }

    /**
     * Internal helper class that stores values for graph edges of the original
     * graph. These values are used when computing things like the pairwise
     * potential. Instances are immutable.
     */
    private static class EdgeIndex
    {

        /**
         * The index of this edge in the underlying graph
         */
        private final int internalIdx;

        /**
         * True if the edge was flipped to match default edge-node ordering
         */
        private final boolean wasFlipped;

        /**
         * Stores the specified values in this data type object
         *
         * @param internalIdx The index of this edge in the underlying graph
         * @param wasFlipped True if the edge was flipped to match default
         * edge-node ordering
         */
        public EdgeIndex(int internalIdx, boolean wasFlipped)
        {
            this.internalIdx = internalIdx;
            this.wasFlipped = wasFlipped;
        }

    }

    /**
     * Private helper class that stores the map from canonically ordered edges
     * to the (possibly multiple) edges. That is, this is a one-to-many map
     * where the outside world sees each edge only once, but the graph being
     * wrapped may have multiple edges between any pair of nodes.
     */
    private static class UniqueEdgeMap
    {

        /**
         * This is the editable version of the data. Once initialized fully,
         * call convert to switch to the faster-to-run-through-and-look-up list.
         * This is null once convert has been called.
         */
        private Map<UniqueEdge, List<EdgeIndex>> editableEdgeMap;

        /**
         * This is the faster, read-only version of the edges. This is null
         * until convert has been called.
         */
        private List<Pair<UniqueEdge, List<EdgeIndex>>> listEdgeMap;

        /**
         * Initialize an empty edge map
         */
        public UniqueEdgeMap()
        {
            editableEdgeMap = new HashMap<>();
            listEdgeMap = null;
        }

        /**
         * Add edges to the map. This may or may not add new visible edges
         * (depending on if the input edge is a repeat of an already externally
         * visible edge). This method should not be called after convert is
         * called.
         *
         * @param src The source node of the edge
         * @param dst The destination node of the edge
         * @param wrappedIdx The edge's index in the wrapped graph
         * @throws IllegalStateException if this method called after convert
         * called
         */
        public void addEdge(int src, int dst, int wrappedIdx)
        {
            if (editableEdgeMap == null)
            {
                throw new IllegalStateException(
                    "Can't addEdges once converted to list");
            }
            UniqueEdge e = new UniqueEdge(src, dst);
            // Single lookup: create the bucket for this node pair on first
            // sight, then append the wrapped edge to it.
            editableEdgeMap.computeIfAbsent(e, key -> new ArrayList<>())
                .add(new EdgeIndex(wrappedIdx, e.wasFlipped(src, dst)));
        }

        /**
         * Converts from the editable edge map to the dense-memory edge list.
         * Calling this again after the first conversion has no effect.
         */
        public void convert()
        {
            if (editableEdgeMap == null)
            {
                // Already converted; nothing left to do
                return;
            }
            listEdgeMap = new ArrayList<>(editableEdgeMap.size());
            for (Map.Entry<UniqueEdge, List<EdgeIndex>> e
                : editableEdgeMap.entrySet())
            {
                listEdgeMap.add(new DefaultPair<>(e.getKey(), e.getValue()));
            }
            editableEdgeMap.clear();
            editableEdgeMap = null;
        }

        /**
         * Returns the number of externally visible edges stored herein
         *
         * @return the number of externally visible edges stored herein
         */
        public int size()
        {
            if (editableEdgeMap != null)
            {
                return editableEdgeMap.size();
            }
            else
            {
                return listEdgeMap.size();
            }
        }

        /**
         * Returns the externally visible edge stored at index i (may contain
         * multiple wrapped-graph edges).
         *
         * @param i The edge index
         * @return the externally visible edge stored at index i
         * @throws IllegalStateException if convert has not yet been called
         */
        public Pair<UniqueEdge, List<EdgeIndex>> getEdge(int i)
        {
            if (listEdgeMap == null)
            {
                throw new IllegalStateException(
                    "Edge map not finalized before this call");
            }
            return listEdgeMap.get(i);
        }

        /**
         * Returns the source/dest pair for the edge stored at index i
         *
         * @param i The edge index
         * @return the source/dest pair for the edge stored at index i (lower
         * node index first)
         */
        public Pair<Integer, Integer> getEdgePair(int i)
        {
            UniqueEdge e = getEdge(i).getFirst();
            return new DefaultPair<>(e.idxi, e.idxj);
        }

    }

    /**
     * The energy function being wrapped by this
     */
    private final NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapped;

    /**
     * The externally visible edges generated by removing repeated edges in
     * wrapped
     */
    private final UniqueEdgeMap edgeMap;

    /**
     * Initializes this edge merging energy function. All edges of the wrapped
     * function are read once here and grouped by the (unordered) pair of nodes
     * they connect.
     *
     * @param wrapMe The energy function to wrap
     */
    public EdgeMergingEnergyFunction(
        NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapMe)
    {
        this.wrapped = wrapMe;
        this.edgeMap = new UniqueEdgeMap();
        for (int i = 0; i < wrapMe.numEdges(); ++i)
        {
            Pair<Integer, Integer> edge = wrapMe.getEdge(i);
            edgeMap.addEdge(edge.getFirst(), edge.getSecond(), i);
        }
        edgeMap.convert();
    }

    /**
     * @see NodeNameAwareEnergyFunction#setLabel(java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public void setLabel(NodeNameType node,
        LabelType label)
    {
        wrapped.setLabel(node, label);
    }

    /**
     * @see NodeNameAwareEnergyFunction#getBeliefs(java.lang.Object,
     * gov.sandia.cognition.graph.inference.EnergyFunctionSolver)
     */
    @Override
    public Map<LabelType, Double> getBeliefs(NodeNameType node,
        EnergyFunctionSolver<LabelType> bp)
    {
        return wrapped.getBeliefs(node, bp);
    }

    /**
     * @see NodeNameAwareEnergyFunction#getPossibleLabels(int)
     */
    @Override
    public Collection<LabelType> getPossibleLabels(int nodeId)
    {
        return wrapped.getPossibleLabels(nodeId);
    }

    /**
     * Returns the number of merged (externally visible) edges, which may be
     * smaller than the number of edges in the wrapped function.
     *
     * @see NodeNameAwareEnergyFunction#numEdges()
     */
    @Override
    public int numEdges()
    {
        return edgeMap.size();
    }

    /**
     * @see NodeNameAwareEnergyFunction#numNodes()
     */
    @Override
    public int numNodes()
    {
        return wrapped.numNodes();
    }

    /**
     * Returns the node pair of merged edge i, lower node index first.
     *
     * @see NodeNameAwareEnergyFunction#getEdge(int)
     */
    @Override
    public Pair<Integer, Integer> getEdge(int i)
    {
        return edgeMap.getEdgePair(i);
    }

    /**
     * @see NodeNameAwareEnergyFunction#getUnaryPotential(int, java.lang.Object)
     */
    @Override
    public double getUnaryPotential(int i,
        LabelType label)
    {
        return wrapped.getUnaryPotential(i, label);
    }

    /**
     * Returns the product of the wrapped pairwise potentials of every
     * wrapped-graph edge merged into the specified externally visible edge.
     * For wrapped edges whose direction was reversed when merged, the two
     * labels are swapped before querying the wrapped function.
     *
     * @see NodeNameAwareEnergyFunction#getPairwisePotential(int,
     * java.lang.Object, java.lang.Object)
     */
    @Override
    public double getPairwisePotential(int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        Pair<UniqueEdge, List<EdgeIndex>> edge = edgeMap.getEdge(edgeId);
        double ret = 1.0;
        for (EdgeIndex ei : edge.getSecond())
        {
            if (ei.wasFlipped)
            {
                // The wrapped edge runs j -> i, so its potential matrix must
                // be read transposed
                ret *= wrapped.getPairwisePotential(ei.internalIdx, jlabel,
                    ilabel);
            }
            else
            {
                ret *= wrapped.getPairwisePotential(ei.internalIdx, ilabel,
                    jlabel);
            }
        }

        return ret;
    }

    /**
     * @see NodeNameAwareEnergyFunction#getUnaryCost(int, java.lang.Object)
     */
    @Override
    public double getUnaryCost(int i,
        LabelType label)
    {
        return wrapped.getUnaryCost(i, label);
    }

    /**
     * Returns the negative natural log of the merged pairwise potential for
     * the specified externally visible edge.
     *
     * @see NodeNameAwareEnergyFunction#getPairwiseCost(int, java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public double getPairwiseCost(int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        return -Math.log(getPairwisePotential(edgeId, ilabel, jlabel));
    }

}