package gov.sandia.cognition.learning.algorithm.semisupervised.valence;
import gov.sandia.cognition.learning.algorithm.minimization.matrix.ConjugateGradientMatrixSolver;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Basic tests for the MultipartiteValenceMatrix class.
*
* @author jdwendt
*/
public class MultipartiteValenceMatrixTest
{
/**
* Creates a simple small graph w/ initial values on two nodes.
*
* @param mvm The datastructure to load the graph into
*/
private static void fillMvm(MultipartiteValenceMatrix mvm)
{
mvm.addRelationship(0, 0, 1, 0, 1);
mvm.addRelationship(0, 0, 1, 1, 1);
mvm.addRelationship(0, 0, 1, 2, 1);
mvm.addRelationship(0, 1, 1, 1, 1);
mvm.addRelationship(0, 1, 1, 4, 1);
mvm.addRelationship(0, 2, 1, 1, 1);
mvm.addRelationship(0, 2, 1, 4, 1);
mvm.addRelationship(0, 3, 1, 3, 1);
mvm.addRelationship(0, 3, 1, 4, 1);
mvm.addRelationship(0, 3, 1, 5, 1);
mvm.setElementsScore(1, 0, 1, 1);
mvm.setElementsScore(1, 5, 1, -1);
}
/**
* Simple tests that makes sure the spreading works at all.
*/
@Test
public void simpleTest()
{
List<Integer> sizes = new ArrayList<Integer>(2);
sizes.add(4);
sizes.add(6);
MultipartiteValenceMatrix mvm = new MultipartiteValenceMatrix(sizes, 10);
fillMvm(mvm);
Vector rhs = mvm.init();
ConjugateGradientMatrixSolver s = new ConjugateGradientMatrixSolver(rhs,
rhs, 1e-1);
Vector result = s.learn(mvm).getOutput();
// First make sure the group 0 nodes make sense
assertTrue(result.getElement(0) > 0.5);
assertTrue(Math.abs(result.getElement(1)) < 0.1);
assertTrue(Math.abs(result.getElement(2)) < 0.1);
assertTrue(result.getElement(3) < -0.5);
// Now make sure the group 1 nodes make sense
assertTrue(result.getElement(4) > 0.5);
assertTrue(result.getElement(5) > 0.0);
assertTrue(result.getElement(6) > 0.5);
assertTrue(result.getElement(7) < -0.5);
assertTrue(result.getElement(8) < -0.0);
assertTrue(result.getElement(9) < -0.5);
}
/**
* Tests that as you increase the power, the spread of the scores increases
*/
@Test
public void spreadTest()
{
List<Integer> sizes = new ArrayList<Integer>(2);
sizes.add(4);
sizes.add(6);
MultipartiteValenceMatrix mvm = new MultipartiteValenceMatrix(sizes, 0);
fillMvm(mvm);
Vector rhs = mvm.init();
ConjugateGradientMatrixSolver s = new ConjugateGradientMatrixSolver(
rhs, rhs, 1e-1);
Vector result = s.learn(mvm).getOutput();
// In the zero-spread case, only the seeded-nodes should have value
// First make sure the group 0 nodes make sense
assertTrue(Math.abs(result.getElement(0)) < 0.01);
assertTrue(Math.abs(result.getElement(1)) < 0.01);
assertTrue(Math.abs(result.getElement(2)) < 0.01);
assertTrue(Math.abs(result.getElement(3)) < 0.01);
// Now make sure the group 1 nodes make sense
assertTrue(result.getElement(4) >= 0.5);
assertTrue(Math.abs(result.getElement(5)) < 0.01);
assertTrue(Math.abs(result.getElement(6)) < 0.01);
assertTrue(Math.abs(result.getElement(7)) < 0.01);
assertTrue(Math.abs(result.getElement(8)) < 0.01);
assertTrue(result.getElement(9) <= -0.5);
mvm = new MultipartiteValenceMatrix(sizes, 1);
fillMvm(mvm);
rhs = mvm.init();
s = new ConjugateGradientMatrixSolver(rhs, rhs, 1e-1);
result = s.learn(mvm).getOutput();
// In the one-spread case, only the nodes connected to the seeded-nodes
// should have value
// First make sure the group 0 nodes make sense
assertTrue(result.getElement(0) >= 0.2);
assertTrue(Math.abs(result.getElement(1)) < 0.01);
assertTrue(Math.abs(result.getElement(2)) < 0.01);
assertTrue(result.getElement(3) <= -0.2);
// Now make sure the group 1 nodes make sense
assertTrue(result.getElement(4) >= 0.5);
assertTrue(Math.abs(result.getElement(5)) < 0.01);
assertTrue(Math.abs(result.getElement(6)) < 0.01);
assertTrue(Math.abs(result.getElement(7)) < 0.01);
assertTrue(Math.abs(result.getElement(8)) < 0.01);
assertTrue(result.getElement(9) <= -0.5);
mvm = new MultipartiteValenceMatrix(sizes, 2);
fillMvm(mvm);
rhs = mvm.init();
s = new ConjugateGradientMatrixSolver(rhs, rhs, 1e-1);
result = s.learn(mvm).getOutput();
// In the two-spread case, we're in the ballpark of final solution with
// this graph
// First make sure the group 0 nodes make sense
assertTrue(result.getElement(0) > 0.5);
assertTrue(Math.abs(result.getElement(1)) < 0.1);
assertTrue(Math.abs(result.getElement(2)) < 0.1);
assertTrue(result.getElement(3) < -0.5);
// Now make sure the group 1 nodes make sense
assertTrue(result.getElement(4) > 0.5);
assertTrue(result.getElement(5) > 0.0);
assertTrue(result.getElement(6) > 0.5);
assertTrue(result.getElement(7) < -0.5);
assertTrue(result.getElement(8) < -0.0);
assertTrue(result.getElement(9) < -0.5);
}
}