/*
* File: GraphWrappingEnergyFunction.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.graph.DirectedNodeEdgeGraph;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
 * This class is a simple wrapper that turns any input graph into a BP-capable
 * graph. It requires the caller to provide a handler that supplies the
 * pairwise and unary potentials for edges and nodes and the possible labels
 * for nodes. This class is intended to be the top layer in layered energy
 * function implementations.
 *
 * @author jdwendt
 * @param <LabelType> The type for labels assigned to nodes
 * @param <NodeNameType> The type for names of nodes
 */
public class GraphWrappingEnergyFunction<LabelType, NodeNameType>
    implements NodeNameAwareEnergyFunction<LabelType, NodeNameType>
{

    /**
     * The graph itself (a shallow reference to the graph passed to the
     * constructor).
     */
    private final DirectedNodeEdgeGraph<NodeNameType> graph;

    /**
     * The handler for providing problem-specific values.
     */
    private final PotentialHandler<LabelType, NodeNameType> handler;

    /**
     * Maps node ids to the label assigned to them via setLabel. Nodes without
     * an assigned label are absent from the map.
     */
    private final Map<Integer, LabelType> labeledNodes;

    /**
     * Edge endpoint ids, cached once at construction time. This speeds up
     * requests for edges; as the edges are likely iterated over many times,
     * it's worth the memory cost.
     */
    private final List<Pair<Integer, Integer>> edges;

    /**
     * This interface defines the problem-specific methods this class requires
     * as input.
     *
     * @param <LabelType> The type for labels assigned to nodes
     * @param <NodeNameType> The type for names of nodes
     */
    public interface PotentialHandler<LabelType, NodeNameType>
    {

        /**
         * Provide the pairwise potential for the specified edge.
         *
         * @param graph The graph the edge is from
         * @param edgeId The id of the edge whose potential is wanted
         * @param ilabel The label being considered for the first node on the
         * edge
         * @param jlabel The label being considered for the second node on the
         * edge
         * @return the pairwise potential
         */
        double getPairwisePotential(
            DirectedNodeEdgeGraph<NodeNameType> graph,
            int edgeId,
            LabelType ilabel,
            LabelType jlabel);

        /**
         * Provide the unary potential for the specified node.
         *
         * @param graph The graph the node is from
         * @param i The id of the node whose potential is wanted
         * @param label The label being considered for the node
         * @param assignedLabel the label assigned to the node (null if none
         * assigned via setLabel)
         * @return the unary potential
         */
        double getUnaryPotential(
            DirectedNodeEdgeGraph<NodeNameType> graph,
            int i,
            LabelType label,
            LabelType assignedLabel);

        /**
         * Get the possible labels for the specified node.
         *
         * @param graph The graph the node is from
         * @param nodeId The id of the node in the graph
         * @return the possible labels for that node
         */
        List<LabelType> getPossibleLabels(
            DirectedNodeEdgeGraph<NodeNameType> graph,
            int nodeId);

    }

    /**
     * Creates a new instance of this class containing the input graph (shallow
     * copy -- don't change it after this!) and the handler for problem-specific
     * details (also only a shallow copy). The edge endpoints are read from the
     * graph once, here, and cached.
     *
     * @param graph The graph this contains
     * @param handler The handler for problem-specific methods
     * @throws NullPointerException if graph or handler is null
     */
    public GraphWrappingEnergyFunction(
        DirectedNodeEdgeGraph<NodeNameType> graph,
        PotentialHandler<LabelType, NodeNameType> handler)
    {
        this.graph = Objects.requireNonNull(graph, "graph");
        // Fail fast: a null handler would otherwise only surface later, deep
        // inside a potential or label lookup.
        this.handler = Objects.requireNonNull(handler, "handler");
        this.labeledNodes = new HashMap<>();

        final int numEdges = graph.getNumEdges();
        this.edges = new ArrayList<>(numEdges);
        for (int i = 0; i < numEdges; ++i)
        {
            this.edges.add(graph.getEdgeEndpointIds(i));
        }
    }

    /**
     * Converts a potential into a cost (negative log potential).
     *
     * @param potential The potential to convert
     * @return Double.MAX_VALUE if the potential is exactly zero, otherwise
     * -Math.log(potential)
     */
    private static double potentialToCost(
        double potential)
    {
        // Math.log(0) is -Infinity, so -Math.log(0) would be +Infinity; the
        // largest finite double is returned for zero potentials instead.
        if (potential == 0)
        {
            return Double.MAX_VALUE;
        }
        return -Math.log(potential);
    }

    /**
     * Delegates to the handler's getPossibleLabels for this graph.
     *
     * @see EnergyFunction#getPossibleLabels(int)
     */
    @Override
    public Collection<LabelType> getPossibleLabels(
        int nodeId)
    {
        return handler.getPossibleLabels(graph, nodeId);
    }

    /**
     * Returns the endpoint ids cached at construction for the given edge.
     *
     * @see EnergyFunction#getEdge(int)
     */
    @Override
    public Pair<Integer, Integer> getEdge(
        int i)
    {
        return edges.get(i);
    }

    /**
     * Delegates to the handler, passing the label assigned to node i via
     * setLabel (or null if none has been assigned).
     *
     * @see EnergyFunction#getUnaryPotential(int, java.lang.Object)
     */
    @Override
    public double getUnaryPotential(
        int i,
        LabelType label)
    {
        return handler.getUnaryPotential(graph, i, label, labeledNodes.get(i));
    }

    /**
     * Returns -log of the unary potential, or Double.MAX_VALUE when the
     * potential is exactly zero.
     *
     * @see EnergyFunction#getUnaryCost(int, java.lang.Object)
     */
    @Override
    public double getUnaryCost(
        int i,
        LabelType label)
    {
        return potentialToCost(getUnaryPotential(i, label));
    }

    /**
     * Delegates to the handler's getPairwisePotential for this graph.
     *
     * @see EnergyFunction#getPairwisePotential(int, java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public double getPairwisePotential(
        int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        return handler.getPairwisePotential(graph, edgeId, ilabel, jlabel);
    }

    /**
     * Returns -log of the pairwise potential, or Double.MAX_VALUE when the
     * potential is exactly zero.
     *
     * @see EnergyFunction#getPairwiseCost(int, java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public double getPairwiseCost(
        int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        return potentialToCost(getPairwisePotential(edgeId, ilabel, jlabel));
    }

    /**
     * Returns the number of edges currently reported by the wrapped graph.
     *
     * @see EnergyFunction#numEdges()
     */
    @Override
    public int numEdges()
    {
        return graph.getNumEdges();
    }

    /**
     * Returns the number of nodes currently reported by the wrapped graph.
     *
     * @see EnergyFunction#numNodes()
     */
    @Override
    public int numNodes()
    {
        return graph.getNumNodes();
    }

    /**
     * Assigns a label to the named node. The assigned label is later passed to
     * the handler's getUnaryPotential for that node.
     *
     * @throws IllegalArgumentException if the label is not among the node's
     * possible labels
     * @see NodeNameAwareEnergyFunction#setLabel(java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public void setLabel(
        NodeNameType node,
        LabelType label)
    {
        final int nodeId = graph.getNodeId(node);
        if (!handler.getPossibleLabels(graph, nodeId).contains(label))
        {
            throw new IllegalArgumentException("Input label (" + label
                + ") can't be assigned to node " + node);
        }
        labeledNodes.put(nodeId, label);
    }

    /**
     * Clears labels previously set. Critical if you want to re-use the same
     * energy function on multiple different runs.
     */
    public void clearLabels()
    {
        labeledNodes.clear();
    }

    /**
     * Returns a map from each of the node's possible labels to the solver's
     * belief for that label. The label at position k in the handler's label
     * list is paired with bp.getBelief(nodeId, k).
     *
     * @see NodeNameAwareEnergyFunction#getBeliefs(java.lang.Object,
     * gov.sandia.cognition.graph.inference.EnergyFunctionSolver)
     */
    @Override
    public Map<LabelType, Double> getBeliefs(
        NodeNameType node,
        EnergyFunctionSolver<LabelType> bp)
    {
        final int nodeId = graph.getNodeId(node);
        final List<LabelType> labels = handler.getPossibleLabels(graph, nodeId);
        final Map<LabelType, Double> ret = new HashMap<>();
        // Iterate rather than index so non-random-access lists stay O(n).
        int labelIndex = 0;
        for (LabelType label : labels)
        {
            ret.put(label, bp.getBelief(nodeId, labelIndex));
            ++labelIndex;
        }
        return ret;
    }

}