/*
* File: InferenceHelper.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.graph.DenseMemoryGraph;
import java.util.Collection;
import java.util.Map;
import static org.junit.Assert.*;
/**
* Helper that tests results in a standardized manner used by various tests in
* this package.
*
* @author jdwendt
*/
class InferenceHelper
{

    /**
     * Asserts that the beliefs reported for every node/label pair equal the
     * expected values, using a fixed tolerance of 1e-4.
     * <p>
     * Nodes are visited by index from 0 to {@code graph.getNumNodes() - 1};
     * for each node the labels are visited in the iteration order of
     * {@code fn.getPossibleLabels(i)}. The expected values are consumed from
     * {@code expectedResultsInOrder} in exactly that order. Entries beyond
     * the ones consumed are not checked.
     *
     * @param <NodeLabelType> The type used to name the graph's nodes
     * @param <LabelType> The type of the labels that can be assigned to nodes
     * @param graph The graph whose nodes are visited
     * @param fn The energy function queried for beliefs and possible labels
     * @param solver The solver handed to {@code fn.getBeliefs}
     * @param expectedResultsInOrder The expected beliefs, flattened by node
     * and then by label
     * @param printResults If true, each node's beliefs are written to
     * standard output before they are checked
     */
    public static <NodeLabelType, LabelType> void testExactResults(
        DenseMemoryGraph<NodeLabelType> graph,
        NodeNameAwareEnergyFunction<LabelType, NodeLabelType> fn,
        EnergyFunctionSolver<LabelType> solver,
        double[] expectedResultsInOrder,
        boolean printResults)
    {
        int idx = 0;
        for (int i = 0; i < graph.getNumNodes(); ++i)
        {
            NodeLabelType node = graph.getNode(i);
            Map<LabelType, Double> b = fn.getBeliefs(node, solver);
            Collection<LabelType> lbls = fn.getPossibleLabels(i);
            if (printResults)
            {
                printBeliefs(node, lbls, b);
            }
            for (LabelType lbl : lbls)
            {
                Double belief = b.get(lbl);
                // Fail with a readable message rather than a bare
                // NullPointerException from unboxing a missing belief.
                assertNotNull("No belief for node " + node + ", label " + lbl,
                    belief);
                assertEquals("Belief for node " + node + ", label " + lbl,
                    expectedResultsInOrder[idx], belief.doubleValue(), 1e-4);
                ++idx;
            }
        }
    }

    /**
     * Asserts that the beliefs reported for every node/label pair differ
     * from the expected values by strictly less than
     * {@code acceptableError}.
     * <p>
     * Nodes are visited by index from 0 to {@code graph.getNumNodes() - 1};
     * for each node the labels are visited in the iteration order of
     * {@code fn.getPossibleLabels(i)}. The expected values are consumed from
     * {@code expectedResultsInOrder} in exactly that order. Entries beyond
     * the ones consumed are not checked.
     *
     * @param <NodeLabelType> The type used to name the graph's nodes
     * @param <LabelType> The type of the labels that can be assigned to nodes
     * @param graph The graph whose nodes are visited
     * @param fn The energy function queried for beliefs and possible labels
     * @param solver The solver handed to {@code fn.getBeliefs}
     * @param expectedResultsInOrder The expected beliefs, flattened by node
     * and then by label
     * @param acceptableError The exclusive upper bound on the absolute
     * difference between an expected and a reported belief
     * @param printResults If true, each node's beliefs are written to
     * standard output before they are checked
     */
    public static <NodeLabelType, LabelType> void testApproximateResults(
        DenseMemoryGraph<NodeLabelType> graph,
        NodeNameAwareEnergyFunction<LabelType, NodeLabelType> fn,
        EnergyFunctionSolver<LabelType> solver,
        double[] expectedResultsInOrder,
        double acceptableError,
        boolean printResults)
    {
        int idx = 0;
        for (int i = 0; i < graph.getNumNodes(); ++i)
        {
            NodeLabelType node = graph.getNode(i);
            Map<LabelType, Double> b = fn.getBeliefs(node, solver);
            Collection<LabelType> lbls = fn.getPossibleLabels(i);
            if (printResults)
            {
                printBeliefs(node, lbls, b);
            }
            for (LabelType lbl : lbls)
            {
                Double belief = b.get(lbl);
                // Fail with a readable message rather than a bare
                // NullPointerException from unboxing a missing belief.
                assertNotNull("No belief for node " + node + ", label " + lbl,
                    belief);
                double expected = expectedResultsInOrder[idx];
                assertTrue("Node " + node + ", label " + lbl + ": "
                    + expected + " - " + belief,
                    Math.abs(expected - belief.doubleValue())
                    < acceptableError);
                ++idx;
            }
        }
    }

    /**
     * Writes one line to standard output listing the node followed by the
     * belief of each label, comma separated, in the iteration order of
     * {@code lbls}.
     *
     * @param <NodeLabelType> The type used to name the graph's nodes
     * @param <LabelType> The type of the labels that can be assigned to nodes
     * @param node The node whose beliefs are printed
     * @param lbls The labels to look up, in print order
     * @param b The beliefs keyed by label
     */
    private static <NodeLabelType, LabelType> void printBeliefs(
        NodeLabelType node,
        Collection<LabelType> lbls,
        Map<LabelType, Double> b)
    {
        StringBuilder line = new StringBuilder("Node " + node + ": ");
        // The prefix already ends with ": ", so the first value needs no
        // separator of its own.
        String sep = "";
        for (LabelType lbl : lbls)
        {
            line.append(sep).append(b.get(lbl));
            sep = ", ";
        }
        System.out.println(line);
    }

}