/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.beliefs.bayes;
import junit.framework.AssertionFailedError;
import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;
import org.drools.core.util.bitmask.OpenBitSet;
import org.junit.Test;
import java.math.BigDecimal;
import java.util.Arrays;
import static org.drools.beliefs.bayes.GraphTest.addNode;
import static org.drools.beliefs.bayes.GraphTest.bitSet;
import static org.drools.beliefs.bayes.GraphTest.connectParentToChildren;
import static org.drools.beliefs.bayes.PotentialMultiplier.indexToKey;
import static org.drools.beliefs.bayes.PotentialMultiplier.keyToIndex;
import static org.junit.Assert.assertEquals;
public class JunctionTreeTest {
@Test
public void testIndextoKeyMapping1() {
// tests simple index to key mapping for a 2x2 array.
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2"}, null);
BayesVariable[] vars = new BayesVariable[] {a, b};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] indexMultipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 4, numberOfStates );
assertIndexToKeyMapping(numberOfStates, indexMultipliers);
}
@Test
public void testIndextoKeyMapping2() {
// tests simple index to key mapping for a 2x3 array.
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2", "A3"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2", "B3"}, null);
BayesVariable[] vars = new BayesVariable[] {a, b};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] indexMultipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 9, numberOfStates );
assertIndexToKeyMapping(numberOfStates, indexMultipliers);
}
@Test
public void testIndextoKeyMapping3() {
// tests a slightly more complex array, which has different lengths for rows. This maps to the Year2000 problem, which uses this array size and shape.
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2", "A3"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2", "B3"}, null);
BayesVariable c = new BayesVariable<String>( "C", 0, new String[] {"C1", "C2", "C3", "C4"}, null);
BayesVariable d = new BayesVariable<String>( "D", 0, new String[] {"D1", "D2", "D3"}, null);
BayesVariable[] vars = new BayesVariable[] {a, b, c, d};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] indexMultipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 108, numberOfStates);
assertIndexToKeyMapping(numberOfStates, indexMultipliers);
}
@Test
public void testPotentialMultiplication1() {
// This tests a simple clique, where the variable being multiplied only has one parent.
// There are no gaps in the variable key, compared to the path
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2"}, new double[][] {{0.1, 0.2}, { 0.3, 0.4 }});
BayesVariable[] vars = new BayesVariable[] {a, b};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] multipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 4, numberOfStates);
assertIndexToKeyMapping(numberOfStates, multipliers);
double[] potentials = new double[numberOfStates];
Arrays.fill(potentials, 1);
BayesVariable[] parents = new BayesVariable[] { a };
int[] parentVarPos = PotentialMultiplier.createSubsetVarPos(vars, parents);
int parentsNumberOfStates = PotentialMultiplier.createNumberOfStates(parents);
int[] parentIndexMultipliers = PotentialMultiplier.createIndexMultipliers(parents, parentsNumberOfStates);
PotentialMultiplier m = new PotentialMultiplier(b.getProbabilityTable(), 1, parentVarPos, parentIndexMultipliers, vars, multipliers, potentials);
m.multiple();
assertArray(new double[]{0.1, 0.2, 0.3, 0.4}, potentials);
// test that it's applying variable multiplications correctly ontop of each other. This simulates the application of project variabe multiplications
m.multiple();
assertArray(new double[]{0.01, 0.04, 0.09, 0.16}, scaleDouble( 3, potentials ));
}
@Test
public void testPotentialMultiplication2() {
// This clique has 4 variables. The variable being multiplied has two parents, directly above it.
// There is a non parent, after it. While d is not part of the key, it's still part of over all path, iterated through by the cross products,
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2"}, null);
BayesVariable c = new BayesVariable<String>( "C", 0, new String[] {"C1", "C2"}, new double[][] {{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}, { 0.7, 0.8 }});
BayesVariable d = new BayesVariable<String>( "D", 0, new String[] {"D1", "D2"}, null);
BayesVariable[] vars = new BayesVariable[] {a, b, c, d};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] multipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 16, numberOfStates);
assertIndexToKeyMapping(numberOfStates, multipliers);
double[] potentials = new double[numberOfStates];
Arrays.fill(potentials, 1);
BayesVariable[] parents = new BayesVariable[] { a, b };
int[] parentVarPos = PotentialMultiplier.createSubsetVarPos(vars, parents);
int parentsNumberOfStates = PotentialMultiplier.createNumberOfStates(parents);
int[] parentIndexMultipliers = PotentialMultiplier.createIndexMultipliers(parents, parentsNumberOfStates);
PotentialMultiplier m = new PotentialMultiplier(c.getProbabilityTable(), 2, parentVarPos, parentIndexMultipliers, vars, multipliers, potentials);
m.multiple();
assertArray(new double[]{0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8}, scaleDouble( 3, potentials ));
// test that it's applying variable multiplications correctly ontop of each other. This simulates the application of project variabe multiplications
m.multiple();
assertArray(new double[]{0.01, 0.01, 0.04, 0.04, 0.09, 0.09, 0.16, 0.16, 0.25, 0.25, 0.36, 0.36, 0.49, 0.49, 0.64, 0.64}, scaleDouble( 3, potentials ) );
}
@Test
public void testPotentialMultiplication3() {
// This clique has 4 variables. One parent is before and the other parent is after the variable being multiplied.
// While a is not part of the parent key, it's still part of over all path, iterated through by the cross products,
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2"}, null);
BayesVariable b = new BayesVariable<String>( "B", 0, new String[] {"B1", "B2"}, null);
BayesVariable c = new BayesVariable<String>( "C", 0, new String[] {"C1", "C2"}, new double[][] {{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}, { 0.7, 0.8 }});
BayesVariable d = new BayesVariable<String>( "D", 0, new String[] {"D1", "D2"}, null);
BayesVariable[] vars = new BayesVariable[] {a, b, c, d};
int numberOfStates = PotentialMultiplier.createNumberOfStates(vars);
int[] multipliers = PotentialMultiplier.createIndexMultipliers(vars, numberOfStates);
assertEquals( 16, numberOfStates);
assertIndexToKeyMapping(numberOfStates, multipliers);
double[] potentials = new double[numberOfStates];
Arrays.fill(potentials, 1);
BayesVariable[] parents = new BayesVariable[] { b, d };
int[] parentVarPos = PotentialMultiplier.createSubsetVarPos(vars, parents);
int parentsNumberOfStates = PotentialMultiplier.createNumberOfStates(parents);
int[] parentIndexMultipliers = PotentialMultiplier.createIndexMultipliers(parents, parentsNumberOfStates);
PotentialMultiplier m = new PotentialMultiplier(c.getProbabilityTable(), 2, parentVarPos, parentIndexMultipliers, vars, multipliers, potentials);
m.multiple();
assertArray(new double[]{0.1, 0.3, 0.2, 0.4, 0.5, 0.7, 0.6, 0.8, 0.1, 0.3, 0.2, 0.4, 0.5, 0.7, 0.6, 0.8}, potentials);
// test that it's applying variable multiplications correctly ontop of each other. This simulates the application of project variabe multiplications
m.multiple();
assertArray(new double[]{0.01, 0.09, 0.04, 0.16, 0.25, 0.49, 0.36, 0.64, 0.01, 0.09, 0.04, 0.16, 0.25, 0.49, 0.36, 0.64}, scaleDouble( 3, potentials ) );
}
@Test
public void testJunctionTreeInitialisation() {
// creates JunctionTree where node1 has only B as a family memory.
// node 2 has both c and d as family, and c is the parent of d.
BayesVariable a = new BayesVariable<String>( "A", 0, new String[] {"A1", "A2"}, new double[][] {{0.1, 0.2}});
BayesVariable b = new BayesVariable<String>( "B", 1, new String[] {"B1", "B2"}, new double[][] {{0.1, 0.2}});
BayesVariable c = new BayesVariable<String>( "C", 2, new String[] {"C1", "C2"}, new double[][] {{0.1, 0.2}});
BayesVariable d = new BayesVariable<String>( "D", 3, new String[] {"D1", "D2"}, new double[][] {{0.1, 0.2}, {0.3, 0.4}});
Graph<BayesVariable> graph = new BayesNetwork();
GraphNode x0 = addNode(graph);
GraphNode x1 = addNode(graph);
GraphNode x2 = addNode(graph);
GraphNode x3 = addNode(graph);
//connectParentToChildren(x0, x2);
connectParentToChildren(x2, x3);
x0.setContent( a );
x1.setContent( b );
x2.setContent( c );
x3.setContent( d );
JunctionTreeClique node1 = new JunctionTreeClique(0, graph, bitSet("0011") );
JunctionTreeClique node2 = new JunctionTreeClique(1, graph, bitSet("1100") );
new JunctionTreeSeparator(0, node1, node2, new OpenBitSet(), graph);
node1.addToFamily( b );
b.setFamily( node1.getId() );
node2.addToFamily( c );
c.setFamily( node2.getId() );
node2.addToFamily( d );
d.setFamily( node2.getId() );
JunctionTree jtree = new JunctionTree(graph, node1, new JunctionTreeClique[] { node1, node2 }, null );
assertArray(new double[]{0.1, 0.2, 0.1, 0.2}, scaleDouble( 3, node1.getPotentials() ));
assertArray(new double[]{0.01, 0.02, 0.06, 0.08}, scaleDouble( 3, node2.getPotentials() ));
}
public static void assertArray(double[] expected, double[] actual) {
if ( !Arrays.equals(expected, actual) ) {
System.err.print( "expected " );
for ( int i = 0; i <expected.length; i++ ) {
System.err.format("%.7f ", expected[i]);
}
System.err.println("");
System.err.print( "actual " );
for ( int i = 0; i <actual.length; i++ ) {
System.err.format("%.7f ", actual[i]);
}
System.err.println("");
throw new AssertionFailedError("Arrays are not Equal");
}
}
public static void assertIndexToKeyMapping(int numberOfStates, int[] indexMultipliers) {
for (int i = 0; i < numberOfStates; i++) {
int[] key = indexToKey(i, indexMultipliers);
int index = keyToIndex(key, indexMultipliers);
assertEquals(i, index);
}
}
public static double scaleDouble(int scale, double d) {
return new BigDecimal(d).setScale(scale, BigDecimal.ROUND_HALF_UP).doubleValue();
}
public static double[] scaleDouble(int scale, double[] array) {
for ( int i = 0; i < array.length; i++ ) {
array[i] = scaleDouble(scale, array[i]);
}
return array;
}
}