/*
* File: CostSpeedupEnergyFunction.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.util.Pair;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
/**
 * This class trades memory usage (to store all of the costs) for compute time.
 * We've found that sometimes one of the slowest parts of belief propagation is
 * the multiple log operations, and this class eliminates a significant number
 * of those. Whether the memory overhead outweighs the benefit of the fewer
 * computations appears to depend on the size of the graph.
 *
 * Costs are fetched from the wrapped function the first time they are
 * requested and served from a local array afterwards. All other calls are
 * forwarded unchanged to the wrapped function. The cache arrays are sized once,
 * at construction time, from the wrapped function's nodes, edges, and possible
 * labels. This class performs no synchronization of its own.
 *
 * @author jdwendt
 * @param <LabelType> The labels that can be assigned to nodes
 * @param <NodeNameType> The type used to name nodes
 */
public class CostSpeedupEnergyFunction<LabelType, NodeNameType>
    implements NodeNameAwareEnergyFunction<LabelType, NodeNameType>
{

    /**
     * Marker stored in a cache slot whose cost has not been fetched yet. A
     * wrapped cost that happens to equal this value is still returned
     * correctly; it simply gets re-fetched on every call instead of cached.
     */
    private static final double UNCOMPUTED = Double.MAX_VALUE;

    /**
     * The energy function being wrapped by this
     */
    private final NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapped;

    /**
     * The local cache of pairwise costs. Indexed by edge id, then by
     * (index of the first endpoint's label) * (number of labels of the second
     * endpoint) + (index of the second endpoint's label).
     */
    private final double[][] pairwiseCosts;

    /**
     * The local cache of unary costs. Indexed by node id, then by the index of
     * the label within that node's possible labels.
     */
    private final double[][] unaryCosts;

    /**
     * Initializes this with the wrapped function and allocates the caches for
     * the unary and pairwise costs. Every slot starts out empty; the costs
     * themselves are only computed and stored as needed.
     *
     * @param wrapme The function to wrap with this one
     * @throws NullPointerException if wrapme is null
     */
    public CostSpeedupEnergyFunction(
        NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapme)
    {
        this.wrapped = Objects.requireNonNull(wrapme, "wrapme");

        int m = wrapped.numEdges();
        this.pairwiseCosts = new double[m][];
        for (int i = 0; i < m; ++i)
        {
            Pair<Integer, Integer> edge = wrapped.getEdge(i);
            int srcLabelsCnt = wrapped.getPossibleLabels(edge.getFirst()).size();
            int dstLabelsCnt = wrapped.getPossibleLabels(edge.getSecond()).size();
            // One slot per (source label, destination label) combination
            this.pairwiseCosts[i] = new double[srcLabelsCnt * dstLabelsCnt];
        }

        int n = wrapped.numNodes();
        this.unaryCosts = new double[n][];
        for (int i = 0; i < n; ++i)
        {
            this.unaryCosts[i] = new double[wrapped.getPossibleLabels(i).size()];
        }

        clearAll(this.pairwiseCosts);
        clearAll(this.unaryCosts);
    }

    /**
     * Clears the pre-computed costs that were stored to keep from calling log
     * and potential each time. Necessary for if we want to re-use the function
     * without re-initializing memory.
     */
    public void clearStoredCosts()
    {
        // Iterate over the caches themselves rather than re-querying the
        // wrapped function for its sizes, so every allocated slot is reset and
        // no index can fall outside the arrays built by the constructor.
        clearAll(this.pairwiseCosts);
        clearAll(this.unaryCosts);
    }

    /**
     * Helper that marks every slot of every row of the input cache as not yet
     * computed.
     *
     * @param cache The cache to reset
     */
    private static void clearAll(double[][] cache)
    {
        for (double[] row : cache)
        {
            Arrays.fill(row, UNCOMPUTED);
        }
    }

    /**
     * Forwards to the wrapped function. Stored costs are not touched by this
     * call.
     *
     * @see NodeNameAwareEnergyFunction#setLabel(java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public void setLabel(NodeNameType node,
        LabelType label)
    {
        wrapped.setLabel(node, label);
    }

    /**
     * Forwards to the wrapped function.
     *
     * @see NodeNameAwareEnergyFunction#getBeliefs(java.lang.Object,
     * EnergyFunctionSolver)
     */
    @Override
    public Map<LabelType, Double> getBeliefs(NodeNameType node,
        EnergyFunctionSolver<LabelType> bp)
    {
        return wrapped.getBeliefs(node, bp);
    }

    /**
     * Forwards to the wrapped function.
     *
     * @see EnergyFunction#getPossibleLabels(int)
     */
    @Override
    public Collection<LabelType> getPossibleLabels(int nodeId)
    {
        return wrapped.getPossibleLabels(nodeId);
    }

    /**
     * Forwards to the wrapped function.
     *
     * @see EnergyFunction#numEdges()
     */
    @Override
    public int numEdges()
    {
        return wrapped.numEdges();
    }

    /**
     * Forwards to the wrapped function.
     *
     * @see EnergyFunction#numNodes()
     */
    @Override
    public int numNodes()
    {
        return wrapped.numNodes();
    }

    /**
     * Forwards to the wrapped function.
     *
     * @see EnergyFunction#getEdge(int)
     */
    @Override
    public Pair<Integer, Integer> getEdge(int i)
    {
        return wrapped.getEdge(i);
    }

    /**
     * Forwards to the wrapped function; potentials are not cached.
     *
     * @see EnergyFunction#getUnaryPotential(int, java.lang.Object)
     */
    @Override
    public double getUnaryPotential(int i,
        LabelType label)
    {
        return wrapped.getUnaryPotential(i, label);
    }

    /**
     * Forwards to the wrapped function; potentials are not cached.
     *
     * @see EnergyFunction#getPairwisePotential(int, java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public double getPairwisePotential(int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        return wrapped.getPairwisePotential(edgeId, ilabel, jlabel);
    }

    /**
     * Helper that takes a collection and a value which should be from that
     * collection and returns the index of that value from that collection, in
     * the collection's iteration order. The comparison is null-safe.
     *
     * @param <T> The type for the values in the collection
     * @param label The value whose index is needed
     * @param labels The collection
     * @return The index of the value in the collection
     * @throws IllegalArgumentException if the value is not in the collection
     */
    private static <T> int indexOf(T label,
        Collection<T> labels)
    {
        int idx = 0;
        for (T candidate : labels)
        {
            if (Objects.equals(candidate, label))
            {
                return idx;
            }
            ++idx;
        }
        throw new IllegalArgumentException("Unable to find input label ("
            + label + ") in input");
    }

    /**
     * Returns the unary cost from the local cache, asking the wrapped function
     * for it only when the cache slot is still empty.
     *
     * @throws IllegalArgumentException if label is not among the possible
     * labels of node i
     * @see EnergyFunction#getUnaryCost(int, java.lang.Object)
     */
    @Override
    public double getUnaryCost(int i,
        LabelType label)
    {
        int idx = indexOf(label, wrapped.getPossibleLabels(i));
        double[] nodeCosts = unaryCosts[i];
        if (nodeCosts[idx] == UNCOMPUTED)
        {
            nodeCosts[idx] = wrapped.getUnaryCost(i, label);
        }
        return nodeCosts[idx];
    }

    /**
     * Returns the pairwise cost from the local cache, asking the wrapped
     * function for it only when the cache slot is still empty.
     *
     * @throws IllegalArgumentException if either label is not among the
     * possible labels of its endpoint
     * @see EnergyFunction#getPairwiseCost(int, java.lang.Object,
     * java.lang.Object)
     */
    @Override
    public double getPairwiseCost(int edgeId,
        LabelType ilabel,
        LabelType jlabel)
    {
        Pair<Integer, Integer> endpoints = wrapped.getEdge(edgeId);
        Collection<LabelType> ilabels = wrapped.getPossibleLabels(
            endpoints.getFirst());
        Collection<LabelType> jlabels = wrapped.getPossibleLabels(
            endpoints.getSecond());
        // Flatten the (ilabel, jlabel) pair into a single row-major offset
        int idx = indexOf(ilabel, ilabels) * jlabels.size()
            + indexOf(jlabel, jlabels);
        double[] edgeCosts = pairwiseCosts[edgeId];
        if (edgeCosts[idx] == UNCOMPUTED)
        {
            edgeCosts[idx] = wrapped.getPairwiseCost(edgeId, ilabel, jlabel);
        }
        return edgeCosts[idx];
    }

}