/*
* File: SumProductBeliefPropagationTest.java
* Authors: Jeremy D. Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.graph.DenseMemoryGraph;
import gov.sandia.cognition.util.DefaultKeyValuePair;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Test case for sum-product belief propagation.
*
* @author tong
*
*/
public class SumProductBeliefPropagationTest
{
@Test
public void testSolve()
{
// We'll use our simple example consisting of 3 nodes.
EnergyFunctionSolver<Integer> bp = new SumProductBeliefPropagation<>();
bp.init(new ThreeNodeEnergyFunction());
assertTrue(bp.solve());
// for (int node = 0; node < 3; node++)
// {
// System.out.print(bp.getBelief(node, 0));
// System.out.print('\t');
// System.out.print(bp.getBelief(node, 1));
// System.out.println();
// }
double delta = 0.0001;
assertEquals(0.375, bp.getBelief(0, 0), delta);
assertEquals(0.625, bp.getBelief(0, 1), delta);
assertEquals(0.625, bp.getBelief(1, 0), delta);
assertEquals(0.375, bp.getBelief(1, 1), delta);
assertEquals(0.4375, bp.getBelief(2, 0), delta);
assertEquals(0.5625, bp.getBelief(2, 1), delta);
}
@Test
public void testSolveDiffEdgeOrder()
{
// We'll use our simple example consisting of 3 nodes.
EnergyFunctionSolver<Integer> bp
= new SumProductBeliefPropagation<>();
ThreeNodeEnergyFunction fn = new ThreeNodeEnergyFunction();
fn.flipOrder = true;
bp.init(fn);
assertTrue(bp.solve());
// for (int node = 0; node < 3; node++)
// {
// System.out.print(bp.getBelief(node, 0));
// System.out.print('\t');
// System.out.print(bp.getBelief(node, 1));
// System.out.println();
// }
double delta = 0.0001;
assertEquals(0.375, bp.getBelief(0, 0), delta);
assertEquals(0.625, bp.getBelief(0, 1), delta);
assertEquals(0.625, bp.getBelief(1, 0), delta);
assertEquals(0.375, bp.getBelief(1, 1), delta);
assertEquals(0.4375, bp.getBelief(2, 0), delta);
assertEquals(0.5625, bp.getBelief(2, 1), delta);
}
/**
* Three nodes, 2 edges. 0-1-2.
*
* @author tong
*
*/
class ThreeNodeEnergyFunction
implements EnergyFunction<Integer>
{
private double[][] pairwisePotentials0 =
{
{
1, 2
},
{
4, 1
}
};
private final double[][] pairwisePotentials1 =
{
{
1, 3
},
{
3, 1
}
};
private boolean flipOrder = false;
@Override
public int numNodes()
{
return 3;
}
@Override
public int numEdges()
{
return 2;
}
@Override
public Pair<Integer, Integer> getEdge(int edge)
{
if (edge == 0)
{
if (flipOrder)
{
return new DefaultKeyValuePair<>(1, 0);
}
else
{
return new DefaultKeyValuePair<>(0, 1);
}
}
else if (edge == 1)
{
return new DefaultKeyValuePair<>(1, 2);
}
else
{
throw new IllegalArgumentException("Edge " + edge
+ " is not in the graph.");
}
}
@Override
public List<Integer> getPossibleLabels(int node)
{
ArrayList<Integer> labels = new ArrayList<>();
labels.add(0);
labels.add(1);
return labels;
}
@Override
public double getUnaryPotential(int node,
Integer label)
{
return 1;
}
@Override
public double getPairwisePotential(int edge,
Integer ilabel,
Integer jlabel)
{
if (flipOrder)
{
pairwisePotentials0 = new double[][]
{
{
1, 4
},
{
2, 1
}
};
}
if (edge == 0)
{
return pairwisePotentials0[ilabel][jlabel];
}
else if (edge == 1)
{
return pairwisePotentials1[ilabel][jlabel];
}
else
{
throw new IllegalArgumentException("Edge " + edge
+ " is not in the graph.");
}
}
@Override
public double getUnaryCost(int i,
Integer label)
{
return -Math.log(getUnaryPotential(i, label));
}
@Override
public double getPairwiseCost(int edgeId,
Integer ilabel,
Integer jlabel)
{
return -Math.log(getPairwisePotential(edgeId, ilabel, jlabel));
}
}
@Test
public void exactTest()
{
DenseMemoryGraph<Integer> graph = new DenseMemoryGraph<>(3, 2);
graph.addEdge(1, 2);
graph.addEdge(2, 3);
BasicHomogeneousHandler<Integer> handler = new BasicHomogeneousHandler<>();
NodeNameAwareEnergyFunction<Integer, Integer> fn
= new GraphWrappingEnergyFunction<>(graph, handler);
fn.setLabel(1, 0);
EnergyFunctionSolver<Integer> bp = new SumProductBeliefPropagation<>(
100, 1e-6, 1);
bp.init(fn);
assertTrue(bp.solve());
InferenceHelper.testExactResults(graph, fn, bp, new double[]
{
1.0, 0.0, 0.99, 0.01, 0.9802, 0.0198
}, false);
handler.setSpecialUnaryPotential(3, graph, 0.8);
fn.setLabel(3, 1);
bp.init(fn);
assertTrue(bp.solve());
InferenceHelper.testExactResults(graph, fn, bp, new double[]
{
1.0, 0.0, 0.9625, 0.0375, 0.9252, 0.0748
}, false);
}
}