/*
* File: CostSpeedupEnergyFunctionTest.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.graph.DenseMemoryGraph;
import gov.sandia.cognition.graph.DirectedNodeEdgeGraph;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
import org.junit.Test;
/**
 * Tests for {@link CostSpeedupEnergyFunction}: a small chain-graph energy
 * function is wrapped in a {@code CostSpeedupEnergyFunction}, solved with
 * {@link SumProductBeliefPropagation}, and the solver's output is checked via
 * {@code InferenceHelper.testExactResults}.
 *
 * @author jdwendt
 */
public class CostSpeedupEnergyFunctionTest
{

    /**
     * Pairwise potential table returned for edge id 0, indexed as
     * [ilabel][jlabel].
     */
    private static final double[][] EDGE_0_POTENTIALS =
    {
        { 1, 2 },
        { 4, 1 }
    };

    /**
     * Pairwise potential table returned for edge id 1, indexed as
     * [ilabel][jlabel].
     */
    private static final double[][] EDGE_1_POTENTIALS =
    {
        { 1, 3 },
        { 3, 1 }
    };

    /**
     * Expected values handed to {@code InferenceHelper.testExactResults}.
     * NOTE(review): six entries for three nodes with two labels each —
     * presumably per-node label beliefs in node order; confirm against
     * InferenceHelper.
     */
    private static final double[] EXPECTED_BELIEFS =
    {
        0.375, 0.625, 0.625, 0.375, 0.4375, 0.5625
    };

    /**
     * Builds the potential handler for the two-edge test chain.
     *
     * <p>Every node offers labels {0, 1}. Every unary potential is 1.0.
     * Pairwise potentials come from {@link #EDGE_0_POTENTIALS} or
     * {@link #EDGE_1_POTENTIALS}, depending on the edge id. Any other edge id
     * is rejected with a {@code RuntimeException}.
     *
     * @return a new handler instance
     */
    private static GraphWrappingEnergyFunction.PotentialHandler<Integer, Integer> createChainHandler()
    {
        return new GraphWrappingEnergyFunction.PotentialHandler<Integer, Integer>()
        {
            @Override
            public double getPairwisePotential(
                DirectedNodeEdgeGraph<Integer> graph,
                int edgeId,
                Integer ilabel,
                Integer jlabel)
            {
                switch (edgeId)
                {
                    case 0:
                        return EDGE_0_POTENTIALS[ilabel][jlabel];
                    case 1:
                        return EDGE_1_POTENTIALS[ilabel][jlabel];
                    default:
                        throw new RuntimeException("Input edge id: " + edgeId
                            + " does not exist in the graph");
                }
            }

            @Override
            public double getUnaryPotential(
                DirectedNodeEdgeGraph<Integer> graph,
                int i,
                Integer label,
                Integer assignedLabel)
            {
                // Uniform unary term: only the pairwise tables distinguish labels.
                return 1.0;
            }

            @Override
            public List<Integer> getPossibleLabels(
                DirectedNodeEdgeGraph<Integer> graph,
                int nodeId)
            {
                List<Integer> labels = new ArrayList<>(2);
                for (int label = 0; label < 2; ++label)
                {
                    labels.add(label);
                }
                return labels;
            }
        };
    }

    /**
     * Solves a three-node chain through the cost-speedup wrapper and checks
     * the solver's results.
     */
    @Test
    public void basicTest()
    {
        // Three nodes with edges added between 0-1 and 1-2. These presumably
        // receive edge ids 0 and 1 in insertion order, matching the handler.
        DenseMemoryGraph<Integer> chain = new DenseMemoryGraph<>(3, 2);
        chain.addEdge(0, 1);
        chain.addEdge(1, 2);

        GraphWrappingEnergyFunction<Integer, Integer> energy
            = new GraphWrappingEnergyFunction<>(chain, createChainHandler());

        // The solver runs on the wrapped function. The unwrapped function is
        // what gets passed to testExactResults (presumably as the reference).
        EnergyFunctionSolver<Integer> solver = new SumProductBeliefPropagation<>();
        solver.init(new CostSpeedupEnergyFunction<>(energy));
        assertTrue(solver.solve());
        InferenceHelper.testExactResults(chain, energy, solver,
            EXPECTED_BELIEFS, false);
    }

}