/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.FactorFunctions.core;
import static com.analog.lyric.math.Utilities.*;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.junit.Ignore;
import org.junit.Test;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.collect.Comparators;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableEntry;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTableBase;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTableIterator;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.DiscreteIndicesIterator;
import com.analog.lyric.dimple.model.domains.JointDiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.util.test.SerializationTester;
import com.google.common.base.Stopwatch;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import cern.colt.map.OpenIntIntHashMap;
public class TestFactorTable extends DimpleTestBase
{
// ---- Shared fixtures used by the tests in this class ----
// Seeded with a constant so the pseudo-random sequence drawn by the tests is deterministic.
final Random rand = new Random(42);
// Integer range domains. The numeric suffix is the number of elements: testFactorTable()
// builds a domain2 x domain3 table and that table accepts exactly six dense weights.
final DiscreteDomain domain2 = DiscreteDomain.range(0,1);
final DiscreteDomain domain3 = DiscreteDomain.range(0,2);
final DiscreteDomain domain6 = DiscreteDomain.range(0,5);
final DiscreteDomain domain32 = DiscreteDomain.range(0,31);
final DiscreteDomain domain256 = DiscreteDomain.range(0, 255);
// Very large domain; testSparseFactorTable() expects tables built over two of these
// to report that they do not support joint indexing.
final DiscreteDomain domainMax = DiscreteDomain.range(0,Integer.MAX_VALUE - 1);
/**
 * Broad functional test of {@link IFactorTable}: representation switching, bulk dense/sparse
 * setters (including their argument validation), directed/conditional/deterministic state
 * transitions on a three-variable XOR table, implicit representation changes triggered by
 * getters, cloning, and finally a batch of random operations.
 * <p>
 * The statements are order dependent: each step relies on the table state left by the previous one.
 */
@Test
public void testFactorTable()
{
// --- Freshly created 2x3 table: undirected, sparse energies, no entries ---
IFactorTable t2x3 = FactorTable.create(domain2, domain3);
assertEquals(2, t2x3.getDimensions());
assertEquals(domain2, t2x3.getDomainIndexer().get(0));
assertEquals(domain3, t2x3.getDomainIndexer().get(1));
assertFalse(t2x3.getRepresentation().hasDense());
assertFalse(t2x3.isDirected());
assertEquals(FactorTableRepresentation.SPARSE_ENERGY, t2x3.getRepresentation());
assertInvariants(t2x3);
assertEquals(0.0, t2x3.density(), 0.0);
// --- Round trip through a dense representation and back to sparse ---
t2x3.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
assertEquals(FactorTableRepresentation.DENSE_ENERGY, t2x3.getRepresentation());
assertTrue(t2x3.getRepresentation().hasDense());
assertInvariants(t2x3);
t2x3.setRepresentation(FactorTableRepresentation.SPARSE_ENERGY);
assertFalse(t2x3.getRepresentation().hasDense());
assertEquals(0, t2x3.sparseSize());
assertInvariants(t2x3);
// --- Randomized weights are expected to fill the table completely (density 1.0) ---
t2x3.setRepresentation(FactorTableRepresentation.ALL_ENERGY);
t2x3.randomizeWeights(rand);
assertInvariants(t2x3);
assertEquals(1.0, t2x3.density(), 0.0);
assertEquals(FactorTableRepresentation.ALL_ENERGY, t2x3.getRepresentation());
t2x3.setRepresentation(FactorTableRepresentation.ALL_VALUES);
assertEquals(FactorTableRepresentation.ALL_VALUES, t2x3.getRepresentation());
assertInvariants(t2x3);
t2x3.normalize();
assertInvariants(t2x3);
t2x3.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
assertEquals(FactorTableRepresentation.SPARSE_WEIGHT, t2x3.getRepresentation());
assertInvariants(t2x3);
// --- Dense setters must reject arrays whose length is not the joint size (6) ---
try
{
t2x3.setWeightsDense(new double[] { 1,2,3 });
fail("expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Bad dense length"));
}
try
{
t2x3.setEnergiesDense(new double[] { 1,2,3 });
fail("expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Bad dense length"));
}
// --- Valid dense setters: values are stored by joint index and the representation follows ---
t2x3.setWeightsDense(new double[] { 1, 2, 3, 4, 5, 6});
assertEquals(FactorTableRepresentation.DENSE_WEIGHT, t2x3.getRepresentation());
for (int ji = 0; ji < 6; ++ji)
{
assertEquals(ji + 1, t2x3.getWeightForJointIndex(ji), 0.0);
}
assertInvariants(t2x3);
t2x3.setEnergiesDense(new double[] { 2,4,6,8, 10, 12});
assertEquals(FactorTableRepresentation.DENSE_ENERGY, t2x3.getRepresentation());
for (int ji = 0; ji < 6; ++ji)
{
assertEquals(2 * (ji + 1), t2x3.getEnergyForJointIndex(ji), 0.0);
}
assertInvariants(t2x3);
// --- Sparse setter validation: mismatched lengths, out-of-range and duplicate joint indices ---
try
{
t2x3.setWeightsSparse(new int[] {1}, new double[] {2,3});
fail("Expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Arrays have different sizes"));
}
try
{
t2x3.setWeightsSparse(new int[] {1, 6}, new double[] {2,3});
fail("Expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Joint index 6 is out of range"));
}
try
{
t2x3.setWeightsSparse(new int[] {1, -1}, new double[] {2,3});
fail("Expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Joint index -1 is out of range"));
}
try
{
// Joint index 1 appears twice; the error message reports it as element indices [1, 0].
t2x3.setWeightsSparse(new int[] {1, 2, 1}, new double[] {2,3,4});
fail("Expected IllegalArgumentException");
}
catch (IllegalArgumentException ex)
{
assertThat(ex.getMessage(), containsString("Multiple entries with same set of indices [1, 0]"));
}
// --- Valid sparse setters: sparse index i maps to joint index i+1 ---
t2x3.setWeightsSparse(new int[] { 1,2,3}, new double[] { 1,2,3});
assertEquals(3, t2x3.sparseSize());
assertEquals(FactorTableRepresentation.SPARSE_WEIGHT, t2x3.getRepresentation());
for (int i = 0; i < 3; ++i)
{
assertEquals(i+1, t2x3.getWeightForSparseIndex(i), 0.0);
assertEquals(i+1, t2x3.getWeightForJointIndex(i+1), 0.0);
}
assertInvariants(t2x3);
// Same entries supplied out of order must produce the same sparse ordering.
t2x3.setWeightsSparse(new int[] { 3,1,2}, new double[] { 3,1,2});
assertEquals(3, t2x3.sparseSize());
assertEquals(FactorTableRepresentation.SPARSE_WEIGHT, t2x3.getRepresentation());
for (int i = 0; i < 3; ++i)
{
assertEquals(i+1, t2x3.getWeightForSparseIndex(i), 0.0);
assertEquals(i+1, t2x3.getWeightForJointIndex(i+1), 0.0);
}
assertInvariants(t2x3);
t2x3.setEnergiesSparse(new int[] { 1,2,3}, new double[] { 1,2,3});
assertEquals(3, t2x3.sparseSize());
assertEquals(FactorTableRepresentation.SPARSE_ENERGY, t2x3.getRepresentation());
for (int i = 0; i < 3; ++i)
{
assertEquals(i+1, t2x3.getEnergyForSparseIndex(i), 0.0);
assertEquals(i+1, t2x3.getEnergyForJointIndex(i+1), 0.0);
}
assertInvariants(t2x3);
// --- Directed 2x2x2 table whose middle dimension (bit 1) is the output ---
BitSet xor3Output = new BitSet(3);
xor3Output.set(1);
IFactorTable xor2 = FactorTable.create(xor3Output, domain2, domain2, domain2);
assertInvariants(xor2);
assertTrue(xor2.isDirected());
assertFalse(xor2.isDeterministicDirected());
// Give weight 1 to exactly the entries where the output equals the XOR of the two inputs.
for (int i = 0; i < 2; ++i)
{
for (int j = 0; j < 2; ++j)
{
xor2.setWeightForIndices(1.0, i, i^j, j);
}
}
assertTrue(xor2.isDeterministicDirected());
// Four of the eight joint entries are populated.
assertEquals(.5, xor2.density(), 0.0);
assertInvariants(xor2);
// --- Dropping and restoring the direction ---
xor2.setDirected(null);
assertFalse(xor2.isDirected());
assertInvariants(xor2);
xor2.setConditional(xor3Output);
assertTrue(xor2.isConditional());
assertTrue(xor2.isDeterministicDirected());
// Changing a single energy breaks the deterministic property; restoring it brings it back.
xor2.setEnergyForSparseIndex(23.0, 1);
assertEquals(23.0, xor2.getEnergyForSparseIndex(1), 0.0);
assertFalse(xor2.isDeterministicDirected());
xor2.setEnergyForSparseIndex(0.0, 1);
assertTrue(xor2.isDeterministicDirected());
assertInvariants(xor2);
// --- Same checks in the sparse weight representation ---
xor2.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
xor2.setDirected(null);
xor2.setDirected(xor3Output);
assertTrue(xor2.isDeterministicDirected());
xor2.setWeightForSparseIndex(23.0, 1);
assertEquals(23.0, xor2.getWeightForSparseIndex(1), 0.0);
assertFalse(xor2.isDeterministicDirected());
assertFalse(xor2.isConditional());
// setConditional must refuse a table whose weights are not conditionally normalized.
try
{
xor2.setConditional(Objects.requireNonNull(xor2.getDomainIndexer().getOutputSet()));
fail("expected exception");
}
catch (DimpleException ex)
{
assertThat(ex.getMessage(), containsString("weights must be normalized correctly for directed"));
}
// Explicit normalization restores the conditional/deterministic state and the unit weight.
xor2.normalizeConditional();
assertTrue(xor2.isConditional());
assertTrue(xor2.isDeterministicDirected());
assertEquals(1.0, xor2.getWeightForSparseIndex(1), 0.0);
xor2.setWeightForSparseIndex(23.0, 1);
xor2.makeConditional(Objects.requireNonNull(xor2.getDomainIndexer().getOutputSet()));
assertTrue(xor2.isDeterministicDirected());
// --- Same checks in the dense energy representation, addressed by joint index ---
xor2.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
xor2.setDirected(null);
xor2.setDirected(xor3Output);
assertTrue(xor2.isDeterministicDirected());
xor2.setEnergyForJointIndex(23.0, 0);
assertEquals(23.0, xor2.getEnergyForJointIndex(0), 0.0);
assertFalse(xor2.isDeterministicDirected());
xor2.setEnergyForJointIndex(0.0, 0);
assertTrue(xor2.isDeterministicDirected());
xor2.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
xor2.setDirected(null);
xor2.setDirected(xor3Output);
assertTrue(xor2.isDeterministicDirected());
testRandomOperations(xor2, 10000);
// Test automatic representation changes by get* methods:
// reading a sparse weight is expected to leave SPARSE_ENERGY untouched, but to widen
// DENSE_WEIGHT to ALL_WEIGHT and DENSE_ENERGY to ALL_ENERGY.
xor2.setRepresentation(FactorTableRepresentation.SPARSE_ENERGY);
xor2.setEnergyForSparseIndex(2.3, 0);
assertEquals(energyToWeight(2.3), xor2.getWeightForSparseIndex(0), 1e-12);
assertEquals(FactorTableRepresentation.SPARSE_ENERGY, xor2.getRepresentation());
xor2.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
assertEquals(energyToWeight(2.3), xor2.getWeightForSparseIndex(0), 1e-12);
assertEquals(FactorTableRepresentation.ALL_WEIGHT, xor2.getRepresentation());
xor2.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
assertEquals(energyToWeight(2.3), xor2.getWeightForSparseIndex(0), 1e-12);
assertEquals(FactorTableRepresentation.ALL_ENERGY, xor2.getRepresentation());
// --- A clone compares equal to its source and can then be modified on its own ---
IFactorTable t2x2x2 = xor2.clone();
assertBaseEqual(t2x2x2, xor2);
t2x2x2.setWeightForIndices(.5, 1, 1, 1);
assertFalse(t2x2x2.isDeterministicDirected());
assertInvariants(t2x2x2);
// --- Random operations on a larger undirected table ---
// NOTE(review): despite its name this table is 2x3x6 (the third domain is domain6).
IFactorTable t2x3x4 = FactorTable.create(domain2, domain3, domain6);
// NOTE(review): this changes t2x3, not the t2x3x4 table created on the previous line;
// it looks like a typo for t2x3x4 -- confirm intent before changing.
t2x3.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
t2x3x4.randomizeWeights(rand);
assertInvariants(t2x3x4);
testRandomOperations(t2x3x4, 10000);
// A null output set is expected to be rejected with a NullPointerException.
expectThrow(NullPointerException.class, xor2, "setConditional", new Object[] { null } );
}
/**
 * Tests factor tables whose joint cardinality is too large for joint indexing, so that only
 * sparse access is available: random operations, directed/conditional handling with
 * element-index based sparse weights, and rejection of all joint-index based methods.
 */
@Test
public void testSparseFactorTable()
{
// --- 32 binary dimensions ---
final DiscreteDomain[] d2by32 = new DiscreteDomain[32];
Arrays.fill(d2by32, domain2);
IFactorTable table = FactorTable.create(d2by32);
assertEquals(0, table.sparseSize());
assertFalse(table.hasMaximumDensity());
assertFalse(table.isConditional());
assertFalse(table.supportsJointIndexing());
assertInvariants(table);
testRandomOperations(table, 500);
// --- Two dimensions over the very large domainMax ---
final DiscreteDomain[] dmaxby2 = new DiscreteDomain[2];
Arrays.fill(dmaxby2, domainMax);
table = FactorTable.create(dmaxby2);
assertEquals(0, table.sparseSize());
assertFalse(table.hasMaximumDensity());
assertFalse(table.isConditional());
assertFalse(table.supportsJointIndexing());
assertInvariants(table);
testRandomOperations(table, 500);
// Fresh table over the same domains, this time also used for the directed tests below.
table = FactorTable.create(dmaxby2);
assertEquals(0, table.sparseSize());
assertFalse(table.hasMaximumDensity());
assertFalse(table.isConditional());
assertFalse(table.supportsJointIndexing());
assertInvariants(table);
testRandomOperations(table, 1000);
// --- Make dimension 0 the output ---
table.setDirected(BitSetUtil.bitsetFromIndices(12, 0));
assertTrue(table.isDirected());
assertArrayEquals(new int[] {0}, table.getDomainIndexer().getOutputDomainIndices());
assertInvariants(table);
// Build nInputs random index tuples and, for each, nOutputsPerInput entries that differ
// only in the output dimension (index 0).
int nOutputsPerInput = 3;
int nInputs = 33;
int n = nOutputsPerInput * nInputs;
int[][] sparseIndices = new int[n][];
for (int i = 0; i < nInputs; ++i)
{
int[] indices = table.getDomainIndexer().randomIndices(rand, null);
for (int j = 0; j < nOutputsPerInput; ++j)
{
indices[0] = j;
// clone because the same indices array is reused for every j
sparseIndices[i*nOutputsPerInput + j] = indices.clone();
}
}
double[] sparseWeights = new double[n];
for (int i = 0; i < n; ++i)
{
sparseWeights[i] = rand.nextDouble();
}
table.setWeightsSparse(sparseIndices, sparseWeights);
// Random weights are not conditionally normalized until normalizeConditional() is called.
assertTrue(table.isDirected());
// NOTE(review): the next two assertions are identical; the second one is redundant.
assertFalse(table.isConditional());
assertFalse(table.isConditional());
table.normalizeConditional();
assertTrue(table.isConditional());
// Requesting dimension 1 as the conditional output must fail because the weights were
// normalized for dimension 0. The assertions that follow expect the table to nevertheless
// end up directed on dimension 1 (and not conditional) after the failed call.
expectThrow(DimpleException.class, ".*weights must be normalized.*", table, "setConditional",
BitSetUtil.bitsetFromIndices(2, 1));
assertArrayEquals(new int[] {1}, table.getDomainIndexer().getOutputDomainIndices());
assertTrue(table.isDirected());
assertFalse(table.isConditional());
table.normalizeConditional();
assertTrue(table.isConditional());
table.makeConditional(BitSetUtil.bitsetFromIndices(2, 0));
assertTrue(table.isConditional());
// --- Every method that depends on joint indexing must be rejected for this table ---
expectNotDense(table, "fullIterator");
expectNotDense(table, "getEnergyForIndicesDense", 1,2,3);
expectNotDense(table, "getWeightForIndicesDense", 1,2,3);
expectNotDense(table, "getEnergyForJointIndex", 42);
expectNotDense(table, "getWeightForJointIndex", 42);
expectNotDense(table, "sparseIndexFromJointIndex", 42);
expectNotDense(table, "sparseIndexToJointIndex", 42);
expectNotDense(table, "setEnergiesDense", ArrayUtil.EMPTY_DOUBLE_ARRAY);
expectNotDense(table, "setWeightsDense", ArrayUtil.EMPTY_DOUBLE_ARRAY);
expectNotDense(table, "setEnergyForJointIndex", 0.0, 42);
expectNotDense(table, "setWeightForJointIndex", 0.0, 42);
expectNotDense(table, "setEnergiesSparse", ArrayUtil.EMPTY_INT_ARRAY, ArrayUtil.EMPTY_DOUBLE_ARRAY);
expectNotDense(table, "setWeightsSparse", ArrayUtil.EMPTY_INT_ARRAY, ArrayUtil.EMPTY_DOUBLE_ARRAY);
}
/**
 * Test for {@link FactorTable#product}
 * <p>
 * Covers the empty input case, the argument validation errors, and then delegates to
 * {@link #testProduct(Map)} to verify the weights of products built from one, two and three
 * tables (the last one sparse).
 */
@Test
public void testProduct()
{
// Maps each input table to its dimension mapping: entry i is the product dimension
// that the table's dimension i is mapped to.
final Map<IFactorTable, int[]> tables = new HashMap<IFactorTable, int[]>();
final IFactorTable AxB = FactorTable.create(domain2, domain3);
AxB.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
AxB.randomizeWeights(rand);
final IFactorTable BxC = FactorTable.create(domain3, domain2);
BxC.setRepresentation(FactorTableRepresentation.ALL);
BxC.randomizeWeights(rand);
// Sparse table: up to 20 random entries for each of the three values of B.
final IFactorTable BxD = FactorTable.create(domain3, domain256);
BxD.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
for (int n = 20; --n>=0;)
{
for (int b = 0; b < 3; ++b)
{
int d = rand.nextInt(256);
BxD.setWeightForIndices(rand.nextDouble(), b, d);
}
}
// Empty table list: product is expected to return null
ArrayList<Tuple2<IFactorTable,int[]>> list = new ArrayList<Tuple2<IFactorTable,int[]>>();
assertNull(FactorTable.product(list, FactorTableRepresentation.SPARSE_WEIGHT));
// Errors
// - mapping with too few entries
list.add(Tuple2.create(AxB, new int[] {}));
expectThrow(IllegalArgumentException.class,
FactorTable.class, "product", list, FactorTableRepresentation.ALL);
// - mapping with too many entries
list.set(0, Tuple2.create(AxB, new int[] { 0, 1, 2}));
expectThrow(IllegalArgumentException.class, ".*does not match table dimensions.*",
FactorTable.class, "product", list, FactorTableRepresentation.ALL);
// - negative target dimension
list.set(0, Tuple2.create(AxB, new int[] { 0, -1}));
expectThrow(IllegalArgumentException.class, "Negative index mapping.*",
FactorTable.class, "product", list, FactorTableRepresentation.ALL);
// - two tables mapping different domains (domain2 and domain3) onto product dimension 0
list.set(0, Tuple2.create(AxB, new int[] { 0, 1 }));
list.add(Tuple2.create(BxC, new int[] { 0, 2 }));
expectThrow(IllegalArgumentException.class, "Conflicting domain mapping for entry.*",
FactorTable.class, "product", list, FactorTableRepresentation.ALL);
tables.clear();
// Single table with same domain order produces a clone
tables.put(AxB, new int[] { 0 , 1 });
testProduct(tables);
// Try swapping order.
tables.put(AxB, new int[] { 1, 0 });
testProduct(tables);
// Two tables sharing the B dimension (product dimension 1).
tables.put(AxB, new int[] { 2, 1});
tables.put(BxC, new int[] { 1, 0});
testProduct(tables);
// Test a sparse case
tables.put(BxD, new int[] { 1, 3 });
testProduct(tables);
}
/**
 * Builds the product of the tables in {@code entryMap} and verifies it against a brute-force
 * computation.
 * <p>
 * For every combination of element indices of the product table, the expected weight is the
 * product of the weights that each input table assigns to the corresponding subset of indices.
 *
 * @param entryMap maps each input table to an array whose i-th entry is the dimension of the
 * product table that the input table's i-th dimension is mapped to.
 */
private void testProduct(Map<IFactorTable, int[]> entryMap)
{
    // Identifies one dimension of one input table.
    class Origin
    {
        final IFactorTable source;
        final int sourceDimension;

        Origin(IFactorTable source, int sourceDimension)
        {
            this.source = source;
            this.sourceDimension = sourceDimension;
        }
    }

    // Convert the map into the list-of-pairs form that FactorTable.product takes.
    final ArrayList<Tuple2<IFactorTable, int[]>> factors =
        new ArrayList<Tuple2<IFactorTable, int[]>>(entryMap.size());
    for (Map.Entry<IFactorTable, int[]> mapping : entryMap.entrySet())
    {
        factors.add(new Tuple2<IFactorTable,int[]>(mapping));
    }

    final IFactorTable product =
        requireNonNull(FactorTable.product(factors, FactorTableRepresentation.ALL_SPARSE));
    assertEquals(FactorTableRepresentation.ALL_SPARSE, product.getRepresentation());

    // Per-input scratch index arrays, the expected domain list of the product, and for each
    // product dimension the input table dimensions that feed it.
    final Map<IFactorTable, int[]> scratchIndices = new HashMap<IFactorTable, int[]>();
    final Multimap<Integer, Origin> originsByDimension = HashMultimap.create();
    final ArrayList<DiscreteDomain> expectedDomains = new ArrayList<DiscreteDomain>();

    for (Tuple2<IFactorTable,int[]> factor : factors)
    {
        final IFactorTable source = factor.first;
        final int[] dimensionMap = factor.second;
        final JointDomainIndexer sourceIndexer = source.getDomainIndexer();

        scratchIndices.put(source, sourceIndexer.allocateIndices(null));

        int from = 0;
        for (int to : dimensionMap)
        {
            // Grow the list with placeholders until position 'to' exists.
            while (expectedDomains.size() <= to)
            {
                expectedDomains.add(null);
            }
            expectedDomains.set(to, sourceIndexer.get(from));
            originsByDimension.put(to, new Origin(source, from));
            ++from;
        }
    }

    final JointDomainIndexer productIndexer = product.getDomainIndexer();
    assertArrayEquals(productIndexer.toArray(), expectedDomains.toArray());

    for (DiscreteIndicesIterator iter = new DiscreteIndicesIterator(productIndexer); iter.hasNext();)
    {
        final int[] productIndices = iter.next();

        // Project the product indices back onto each input table's scratch array.
        for (int to = 0; to < productIndices.length; ++to)
        {
            final int elementIndex = productIndices[to];
            for (Origin origin : originsByDimension.get(to))
            {
                scratchIndices.get(origin.source)[origin.sourceDimension] = elementIndex;
            }
        }

        double expected = 1.0;
        for (Tuple2<IFactorTable,int[]> factor : factors)
        {
            final IFactorTable source = factor.first;
            expected *= source.getWeightForIndices(scratchIndices.get(source));
        }

        final double actual = product.getWeightForIndices(productIndices);
        assertEquals(expected, actual, 1e-10);
    }
}
/**
 * Test for {@link IFactorTable#createTableConditionedOn(int[])} method.
 * <p>
 * Checks the argument validation and then conditions tables in several representations
 * (default, sparse with indices, dense, sparse, deterministic directed, and a large sparse
 * four dimensional table) using {@link #testConditionOn(IFactorTable, int...)}.
 * In the value index arguments a negative entry leaves the dimension unconditioned.
 */
@Test
public void testConditionOn()
{
IFactorTable table = FactorTable.create(domain3, domain3, domain3);
// Wrong number of value indices, and out-of-range value indices, must be rejected.
// NOTE(review): the next two statements are identical; the second may have been meant
// to cover a different case (e.g. an array that is too long) -- confirm.
expectThrow(ArrayIndexOutOfBoundsException.class, table, "createTableConditionedOn", new int[] { 1, 1} );
expectThrow(ArrayIndexOutOfBoundsException.class, table, "createTableConditionedOn", new int[] { 1, 1} );
expectThrow(IndexOutOfBoundsException.class, table, "createTableConditionedOn", new int[] { -1, 1, 3 } );
// Table with four explicitly set entries (first argument is the energy).
table.setEnergyForIndices(2, 0, 0, 0);
table.setEnergyForIndices(3, 1, 1, 1);
table.setEnergyForIndices(4, 2, 2, 2);
table.setEnergyForIndices(5, 0, 1, 2);
testConditionOn(table, -1, -1, 2);
testConditionOn(table, 0, -1, -1);
table.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT_WITH_INDICES);
testConditionOn(table, -1, -1, 2);
testConditionOn(table, 0, -1, -1);
// Dense table with random weights.
table.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
table.randomizeWeights(rand);
testConditionOn(table, 1, -1, -1);
testConditionOn(table, -1, 0, -1);
testConditionOn(table, -1, -1, 2);
table.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
testConditionOn(table, 1, -1, 0);
testConditionOn(table, -1, 2, -1);
// Deterministic directed table: dimension 2 is the output, with one output index for
// each of the 3x3 input combinations.
table.setDirected(BitSetUtil.bitsetFromIndices(3, 2));
table.setDeterministicOutputIndices(new int[] { 0, 1, 2, 0, 1, 2, 0, 1, 2});
testConditionOn(table, 1, -1, -1);
testConditionOn(table, -1, 1, -1);
// Large four dimensional table populated only on the cross product of four randomly
// chosen values per dimension (at most 4^4 entries; random values may repeat).
table = FactorTable.create(domain256, domain256, domain256, domain256);
int[][] valueMatrix = new int[4][4];
for (int i = 0; i < 4; ++i)
{
for (int j = 0; j < 4; ++j)
{
valueMatrix[i][j] = rand.nextInt(256);
}
}
for (int a : valueMatrix[0])
for (int b : valueMatrix[1])
for (int c : valueMatrix[2])
for (int d : valueMatrix[3])
table.setWeightForIndices(rand.nextDouble(), a, b, c, d);
// Condition on values that are known to have entries.
testConditionOn(table, valueMatrix[0][0], -1, -1, -1);
testConditionOn(table, -1, valueMatrix[1][1], -1, -1);
testConditionOn(table, valueMatrix[0][3], -1, valueMatrix[2][2], valueMatrix[3][1]);
table.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
testConditionOn(table, -1, -1, -1, valueMatrix[3][0]);
testConditionOn(table, -1, valueMatrix[1][3], valueMatrix[2][2], valueMatrix[3][1]);
}
/**
 * Conditions {@code table} on {@code valueIndices} in both available ways and verifies the results.
 * <p>
 * Two conditioned tables are created: one via {@code createTableConditionedOn(valueIndices)},
 * which is expected to drop the conditioned dimensions, and one via
 * {@code createTableConditionedOn(valueIndices, true)}, which is expected to keep the number of
 * dimensions and replace each conditioned domain by a single-element domain.
 *
 * @param table the table to condition; it is only read.
 * @param valueIndices one entry per dimension of {@code table}; a negative entry leaves the
 * dimension free, a non-negative entry is the element index to condition that dimension on.
 */
private void testConditionOn(IFactorTable table, int ... valueIndices)
{
final JointDomainIndexer oldDomains = table.getDomainIndexer();
assertEquals(oldDomains.size(), valueIndices.length);
final IFactorTable newTable = table.createTableConditionedOn(valueIndices);
// retain same # of dimensions
final IFactorTable newTable2 = table.createTableConditionedOn(valueIndices, true);
final JointDomainIndexer newDomains = newTable.getDomainIndexer();
final JointDomainIndexer newDomains2 = newTable2.getDomainIndexer();
assertEquals(oldDomains.size(), newDomains2.size());
// Build the dimension mappings between the original table and the reduced table:
// oldToNew[i] is the reduced dimension for original dimension i, or -1 if it was removed;
// newToOld is the inverse mapping.
int nRemoved = 0;
int[] oldToNew = new int[oldDomains.size()];
int[] newToOld = new int[newDomains.size()];
for (int i = 0, j = 0; i < valueIndices.length; ++i)
{
final DiscreteDomain oldDomain = oldDomains.get(i);
final DiscreteDomain newDomain2 = newDomains2.get(i);
int valueIndex = valueIndices[i];
if (valueIndex < 0)
{
// Free dimension: the domain object itself must be carried over.
assertSame(newDomain2, oldDomain);
oldToNew[i] = j;
newToOld[j] = i;
++j;
}
else
{
// Conditioned dimension: only the selected element remains in the same-size variant.
assertEquals(1, newDomain2.size());
assertEquals(oldDomain.getElement(valueIndex), newDomain2.getElement(0));
oldToNew[i] = -1;
++nRemoved;
}
}
assertEquals(nRemoved, oldDomains.size() - newDomains.size());
int[] oldIndices = oldDomains.allocateIndices(null);
int[] newIndices = newDomains.allocateIndices(null);
int[] newIndices2 = newDomains2.allocateIndices(null);
// Pass 1: walk both conditioned tables in lockstep over all of their index combinations
// and compare each weight with the corresponding entry of the original table.
// The iterators are handed the newIndices/newIndices2 arrays, which the loop body reads
// after each next() call.
DiscreteIndicesIterator newIndicesIterator = new DiscreteIndicesIterator(newDomains, newIndices);
DiscreteIndicesIterator newIndicesIterator2 = new DiscreteIndicesIterator(newDomains2, newIndices2);
while (newIndicesIterator.hasNext())
{
assertTrue(newIndicesIterator2.hasNext());
newIndicesIterator.next();
newIndicesIterator2.next();
for (int i = 0; i < oldIndices.length; ++i)
{
final int vi = valueIndices[i];
// Free dimensions take the index from the reduced table, conditioned ones the fixed value.
oldIndices[i] = oldToNew[i] >=0 ? newIndices[oldToNew[i]] : vi;
assertEquals(oldIndices[i], vi < 0 ? newIndices2[i] : vi);
}
double oldWeight = table.getWeightForIndices(oldIndices);
double newWeight = newTable.getWeightForIndices(newIndices);
double newWeight2 = newTable2.getWeightForIndices(newIndices2);
assertEquals(oldWeight, newWeight, 1e-12);
assertEquals(oldWeight, newWeight2, 1e-12); }
// Pass 2: walk the entries of the original table; every entry that is consistent with
// the conditioning values must have the same weight in the reduced table.
IFactorTableIterator oldIter = table.iterator();
outer:
while (oldIter.hasNext())
{
FactorTableEntry entry = oldIter.next();
requireNonNull(entry);
entry.indices(oldIndices);
for (int i = 0; i < oldIndices.length; ++i)
{
if (oldToNew[i] >= 0)
{
newIndices[oldToNew[i]] = oldIndices[i];
}
else
{
if (valueIndices[i] != oldIndices[i])
{
// Entry does not match the conditioning value: not represented in the new table.
continue outer;
}
}
}
double oldWeight = entry.weight();
double newWeight = newTable.getWeightForIndices(newIndices);
assertEquals(oldWeight, newWeight, 1e-12);
}
}
/**
 * Exercises the {@link FactorTable#createMarginal} factory on the joint domain
 * formed from a two, a three and a six element domain.
 */
@Test
public void testCreateMarginal()
{
    final DiscreteDomain[] components = { domain2, domain3, domain6 };
    testCreateMarginal(components);
}
/**
 * Verifies the marginal tables created for every component of the joint domain built from
 * {@code domains}.
 * <p>
 * Each marginal table must be deterministic directed, have one sparse entry per joint element,
 * use the component domain as dimension 0 and the joint domain as dimension 1, and assign
 * weight one to every (component index, joint index) pair that belongs together.
 *
 * @param domains the component domains of the joint domain to test.
 */
private void testCreateMarginal(DiscreteDomain ... domains)
{
    final JointDiscreteDomain<?> joint = DiscreteDomain.joint(domains);
    final JointDomainIndexer componentIndexer = joint.getDomainIndexer();
    final int jointCardinality = joint.size();
    final int[] elementIndices = componentIndexer.allocateIndices(null);

    assertEquals(domains.length, joint.getDimensions());

    int component = 0;
    while (component < domains.length)
    {
        final IFactorTable marginal = FactorTable.createMarginal(component, joint);

        assertTrue(marginal.isDeterministicDirected());
        assertEquals(joint.size(), marginal.sparseSize());
        assertSame(joint, marginal.getDomainIndexer().get(1));
        assertSame(componentIndexer.get(component), marginal.getDomainIndexer().get(0));

        int jointIndex = 0;
        while (jointIndex < jointCardinality)
        {
            // Decompose the joint element into its per-component indices.
            joint.getElementIndices(jointIndex, elementIndices);
            final double weight = marginal.getWeightForIndices(elementIndices[component], jointIndex);
            assertEquals(1.0, weight, 0.0);
            ++jointIndex;
        }

        ++component;
    }
}
/**
 * Creates a factor table from a two dimensional weight matrix containing zeros and checks
 * that only the non-zero weights end up in the sparse representation.
 */
@Test
public void testCreateFromMultidimensionalArray()
{
    // Regression case for bug 417
    final DiscreteDomain[] domains = { DiscreteDomain.range(1, 3), DiscreteDomain.range(1, 4) };
    final double[][] weights = {
        { 1, 3, 0, 0 },
        { 2, 0, 5, 0 },
        { 0, 4, 0, 6 }
    };

    final IFactorTable table = FactorTable.create(weights, domains);

    // The expected values list the non-zero matrix entries column by column.
    final double[] expectedSparseWeights = { 1, 2, 3, 4, 5, 6 };
    assertArrayEquals(expectedSparseWeights, table.getWeightsSparseUnsafe(), 0.0);
    assertEquals(table.countNonZeroWeights(), table.sparseSize());
}
/**
 * Manual performance benchmark, disabled by {@code @Ignore} so that it does not run with the
 * regular tests.
 * <p>
 * Runs the {@code FactorTablePerformanceTester} measurements against a dense 10x20x5 table,
 * the same table after roughly half of its entries have been zeroed out, and a deterministic
 * directed table mapping two dice to their sum. Section headers are printed to standard output.
 */
@Test
@Ignore
public void performanceComparison()
{
int iterations = 10000;
DiscreteDomain domain10 = DiscreteDomain.range(0,9);
DiscreteDomain domain20 = DiscreteDomain.range(0,19);
DiscreteDomain domain5 = DiscreteDomain.range(0,4);
DiscreteDomain oneDie = DiscreteDomain.range(1,6);
DiscreteDomain twoDice = DiscreteDomain.range(2,12);
IFactorTable newTable = FactorTable.create(domain10, domain20, domain5);
newTable.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
newTable.randomizeWeights(rand);
// The tester is given the table once; the representation is switched on the table
// object between measurements.
FactorTablePerformanceTester tester = new FactorTablePerformanceTester(newTable, iterations);
System.out.println("==== dense 10x20x5 table ==== ");
tester.testGetWeightIndexFromTableIndices();
tester.testGetWeightForIndices();
tester.testEvalAsFactorFunction();
newTable.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
tester.testGibbsUpdateMessage();
newTable.setRepresentation(FactorTableRepresentation.ALL_WEIGHT_WITH_INDICES);
tester.testSumProductUpdate();
System.out.println("\n==== sparse 10x20x5 table ==== ");
// Randomly sparsify the tables: zero out jointSize/2 randomly chosen entries
// (the same entry may be picked more than once).
for (int i = newTable.jointSize() / 2; --i>=0;)
{
newTable.setWeightForJointIndex(0.0, rand.nextInt(newTable.jointSize()));
}
newTable.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT);
tester.testEvalAsFactorFunction();
newTable.setRepresentation(FactorTableRepresentation.ALL_WEIGHT);
tester.testGetWeightIndexFromTableIndices();
tester.testGetWeightForIndices();
newTable.setRepresentation(FactorTableRepresentation.ALL_ENERGY);
tester.testGibbsUpdateMessage();
newTable.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT_WITH_INDICES);
tester.testSumProductUpdate();
System.out.println("\n==== deterministic 11x6x6 table ====");
// Directed table with dimension 0 (the sum of two dice) as output of the two single dice.
newTable = FactorTable.create(BitSetUtil.bitsetFromIndices(3, 0), twoDice, oneDie, oneDie);
JointDomainIndexer indexer = newTable.getDomainIndexer();
int[] deterministicIndices = new int[indexer.getInputCardinality()];
int[] indices = indexer.allocateIndices(null);
for (int ii = 0, iiend = indexer.getInputCardinality(); ii < iiend; ++ii)
{
indexer.inputIndexToIndices(ii, indices);
// Element indices are zero based: die indices a and b denote values a+1 and b+1, whose
// sum a+b+2 sits at index a+b of the twoDice domain, which starts at 2.
deterministicIndices[ii] = indices[1] + indices[2];
}
newTable.setDeterministicOutputIndices(deterministicIndices);
assertTrue(newTable.isDeterministicDirected());
tester = new FactorTablePerformanceTester(newTable, iterations);
tester.testGetWeightIndexFromTableIndices();
tester.testGetWeightForIndices();
tester.testEvalAsFactorFunction();
newTable.setRepresentation(FactorTableRepresentation.DETERMINISTIC_WITH_INDICES);
// NOTE(review): return value unused; presumably called to force the sparse weights to be
// materialized before timing -- confirm.
newTable.getWeightsSparseUnsafe();
tester.testSumProductUpdate();
System.out.println("\n==== DONE ====");
}
/**
 * Rough timing comparison between {@link Arrays#binarySearch(int[], int)} and Colt's
 * {@link OpenIntIntHashMap}, used to find the table size at which the hash map starts to pay off.
 * <p>
 * On the original author's machine, binary search was not much slower than the hash table
 * until the size exceeded about 120-150 elements.
 * <p>
 * Disabled by {@code @Ignore}; results are printed to standard output.
 */
@Test
@Ignore
public void binarySearchVsHashTable()
{
    final int iterations = 1000;

    // Initial run at a fixed size before the sweep (presumably to warm up the JIT).
    binarySearchVsHashTable(100, 1000);

    int size = 10;
    while (size < 200)
    {
        binarySearchVsHashTable(size, iterations);
        ++size;
    }
}
/**
 * Times looking up every element of a random int array of length {@code size}, repeated
 * {@code iterations} times, first with {@link Arrays#binarySearch(int[], int)} and then with an
 * {@link OpenIntIntHashMap}, and prints the average nanoseconds per lookup for both.
 * <p>
 * The hash map timing includes the one-time cost of populating the map; sorting the array
 * for the binary search is done before the timer is started.
 *
 * @param size number of random keys to generate and look up.
 * @param iterations number of times the full set of keys is looked up.
 */
private void binarySearchVsHashTable(int size, int iterations)
{
    final int[] array = new int[size];
    for (int i = 0; i < size; ++i)
    {
        array[i] = rand.nextInt();
    }
    // Arrays.binarySearch is only specified for sorted input; on the unsorted random
    // array its behavior is undefined and the measurement would not reflect real searches.
    Arrays.sort(array);

    final Stopwatch timer = Stopwatch.createUnstarted();

    // --- binary search ---
    timer.start();
    for (int i = iterations; --i>=0;)
    {
        for (int j = size; --j>=0;)
        {
            Arrays.binarySearch(array, array[j]);
        }
    }
    timer.stop();
    final long bsns = timer.elapsed(TimeUnit.NANOSECONDS);

    // --- hash map ---
    timer.reset();
    timer.start();
    final OpenIntIntHashMap map = new OpenIntIntHashMap(size);
    for (int i = 0; i < size; ++i)
    {
        map.put(array[i], i);
    }
    for (int i = iterations; --i>=0;)
    {
        for (int j = size; --j>=0;)
        {
            map.get(array[j]);
        }
    }
    timer.stop();
    final long hmns = timer.elapsed(TimeUnit.NANOSECONDS);

    // Widen before multiplying to avoid int overflow; clamp to one so that a zero size or
    // iteration count cannot cause a division by zero.
    final long calls = Math.max(1L, (long)size * (long)iterations);
    final long nsPerBS = bsns / calls;
    final long nsPerHM = hmns / calls;
    System.out.format("Size %d: BS %d vs HT %d ns/call\n", size, nsPerBS, nsPerHM);
}
private void testRandomOperations(IFactorTable table, int nOperations)
{
final boolean supportsJoint = table.supportsJointIndexing();
final JointDomainIndexer domains = table.getDomainIndexer();
int[] indices = new int[table.getDimensions()];
Object[] arguments = new Object[table.getDimensions()];
while (--nOperations >= 0)
{
int operation = rand.nextInt(12);
// For debugging:
// if (nOperations == badIterationNumber)
// {
// Misc.breakpoint();
// }
// System.out.format("%d: operation %d\n", nOperations, operation);
switch (operation)
{
case 0:
// Randomly zero out an entry.
if (table.hasSparseRepresentation())
{
if (table.sparseSize() > 0)
{
int si = rand.nextInt(table.sparseSize());
if (table.hasSparseWeights() || table.getRepresentation().isDeterministic())
{
table.setWeightForSparseIndex(0.0, si);
}
else
{
table.setEnergyForSparseIndex(Double.POSITIVE_INFINITY, si);
}
}
}
else // dense
{
int ji = rand.nextInt(table.jointSize());
if (table.hasDenseWeights())
{
table.setWeightForJointIndex(0.0, ji);
}
else
{
table.setEnergyForJointIndex(Double.POSITIVE_INFINITY, ji);
}
}
break;
case 1:
// Randomly set the representation
int nReps = FactorTableRepresentation.values().length;
FactorTableRepresentation oldRep = table.getRepresentation();
FactorTableRepresentation newRep = FactorTableRepresentation.values()[rand.nextInt(nReps)];
try
{
int nonZeroCount = table.countNonZeroWeights();
table.setRepresentation(newRep);
FactorTableRepresentation actualNewRep = table.getRepresentation();
assertEquals(newRep, actualNewRep);
if (!oldRep.hasSparse() && nonZeroCount < table.jointSize() && newRep.hasSparse())
{
// Make sure that conversion does weed out the zero weights
assertEquals(table.sparseSize(), table.countNonZeroWeights());
}
}
catch (DimpleException ex)
{
assertEquals(oldRep, table.getRepresentation());
if (table.supportsJointIndexing())
{
assertTrue(newRep.isDeterministic());
}
else
{
assertTrue(newRep.isDeterministic() || newRep.hasDense());
}
}
break;
case 3:
{
// Normalize
try
{
table.normalize();
assertTrue(table.isNormalized());
}
catch (UnsupportedOperationException ex)
{
assertTrue(table.isDirected());
assertFalse(table.isNormalized());
assertThat(ex.getMessage(), containsString("not supported for directed factor table"));
}
catch (DimpleException ex)
{
assertThat(ex.getMessage(), containsString("Cannot normalize undirected factor table with zero"));
assertFalse(table.isNormalized());
}
boolean expectZeros = false;
try
{
table.normalizeConditional();
assertTrue(table.isConditional());
}
catch (UnsupportedOperationException ex)
{
assertFalse(table.isConditional());
assertFalse(table.isDirected());
}
catch (DimpleException ex)
{
assertThat(ex.getMessage(), containsString("Cannot normalize directed factor table with zero"));
assertFalse(table.isConditional());
expectZeros = true;
}
try
{
int nNotNormalized = table.normalizeConditional(true);
if (expectZeros)
{
assertTrue(nNotNormalized > 0);
assertFalse(table.isConditional());
assertTrue(table.isDirected());
}
else
{
assertEquals(0, nNotNormalized);
}
}
catch (UnsupportedOperationException ex)
{
assertFalse(table.isConditional());
assertFalse(table.isDirected());
}
break;
}
case 4:
// Compact
int expectedCompacted =
table.hasSparseRepresentation() ? table.sparseSize() - table.countNonZeroWeights() : 0;
int actualCompacted = table.compact();
assertEquals(expectedCompacted, actualCompacted);
assertEquals(0, table.compact());
break;
case 5:
{
// Test get*Slice methods
// Randomly select a set of indices to condition on.
JointDomainIndexer indexer = table.getDomainIndexer();
if (table.hasSparseRepresentation())
{
if (table.sparseSize() == 0)
{
continue;
}
table.sparseIndexToIndices(rand.nextInt(table.sparseSize()), indices);
}
else
{
indexer.jointIndexToIndices(rand.nextInt(table.jointSize()), indices);
}
for (int i = 0; i < indices.length; ++i)
{
final int domainSize = indexer.getDomainSize(i);
if (domainSize > 100000)
{
// Skip test for large domain sizes to avoid taxing heap.
continue;
}
int saved = indices[i];
double[] slice1 = table.getEnergySlice(i, indices);
for (int j = 0; j < domainSize; ++j)
{
indices[i] = j;
assertEquals(slice1[j], table.getEnergyForIndices(indices), 0.0);
}
double[] slice2 = table.getWeightSlice(slice1, i, indices);
assertSame(slice1, slice2);
for (int j = 0; j < domainSize; ++j)
{
indices[i] = j;
assertEquals(slice2[j], table.getWeightForIndices(indices), 0.0);
}
slice2 = table.getWeightSlice(i, indices);
assertNotSame(slice1, slice2);
assertArrayEquals(slice1, slice2, 0.0);
slice2 = table.getWeightSlice(ArrayUtil.EMPTY_DOUBLE_ARRAY, i, indices);
assertArrayEquals(slice1, slice2, 0.0);
slice1 = table.getEnergySlice(ArrayUtil.EMPTY_DOUBLE_ARRAY, i, indices);
for (int j = 0; j < domainSize; ++j)
{
indices[i] = j;
assertEquals(slice1[j], table.getEnergyForIndices(indices), 0.0);
}
assertSame(slice1, table.getEnergySlice(slice1, i, indices));
indices[i] = saved;
}
break;
}
case 6:
// Replace sparse values
{
if (table.hasSparseRepresentation())
{
int [][] oldIndices = table.getIndicesSparseUnsafe();
int newSize = oldIndices.length;
int[][] newIndices = null;
if (oldIndices.length > 1)
{
// Remove one at random: FIXME what if empty?
--newSize;
int i = rand.nextInt(oldIndices.length);
newIndices = Arrays.copyOf(oldIndices, newSize);
System.arraycopy(oldIndices, i + 1, newIndices, i, newIndices.length - i);
}
else
{
++newSize;
// Add one at random
do
{
domains.randomIndices(rand, indices);
} while (table.sparseIndexFromIndices(indices) >= 0);
newIndices = Arrays.copyOf(oldIndices, newSize);
newIndices[newSize - 1] = indices.clone();
}
// Sort using different comparator than one used by table to test that method
// handles out-of-order data.
Arrays.sort(newIndices, Comparators.lexicalIntArray());
double[] newValues = new double[newSize];
for (int i = 0; i < newSize; ++i)
{
newValues[i] = rand.nextDouble();
}
FactorTableRepresentation expectedRep = rand.nextBoolean() ?
FactorTableRepresentation.SPARSE_ENERGY : FactorTableRepresentation.SPARSE_WEIGHT;
if (expectedRep == FactorTableRepresentation.SPARSE_WEIGHT)
{
table.setWeightsSparse(newIndices, newValues);
}
else
{
table.setEnergiesSparse(newIndices, newValues);
}
assertEquals(newSize, table.sparseSize());
assertEquals(expectedRep, table.getRepresentation());
for (int i = 0; i < newSize; ++i)
{
if (expectedRep == FactorTableRepresentation.SPARSE_WEIGHT)
{
assertEquals(newValues[i], table.getWeightForIndices(newIndices[i]), 1e-12);
}
else
{
assertEquals(newValues[i], table.getEnergyForIndices(newIndices[i]), 1e-12);
}
}
}
break;
}
default:
// Random assignments
double weight = rand.nextDouble();
int jointIndex = 0, location = 0;
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
table.setWeightForJointIndex(weight, jointIndex);
assertWeight(table, weight, jointIndex);
}
else
{
domains.randomIndices(rand, indices);
table.setWeightForIndices(weight, indices);
assertWeight(table, weight, indices);
}
assertFalse(table.isNormalized());
weight = rand.nextDouble();
if (supportsJoint)
{
table.setEnergyForJointIndex(-Math.log(weight), jointIndex);
assertWeight(table, weight, jointIndex);
}
else
{
table.setEnergyForIndices(-Math.log(weight), indices);
assertWeight(table, weight, indices);
}
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
location = table.sparseIndexFromJointIndex(jointIndex);
}
else
{
domains.randomIndices(rand, indices);
location = table.sparseIndexFromIndices(indices);
}
if (location >= 0)
{
int si = table.sparseIndexFromJointIndex(jointIndex);
table.setWeightForSparseIndex(weight, si);
assertWeight(table, weight, jointIndex);
weight = weight + 1;
table.setWeightForSparseIndex(weight, si);
assertWeight(table, weight, jointIndex);
}
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
location = table.sparseIndexFromJointIndex(jointIndex);
}
else
{
domains.randomIndices(rand, indices);
location = table.sparseIndexFromIndices(indices);
}
if (location >= 0)
{
table.setEnergyForSparseIndex(-Math.log(weight), location);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
}
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
domains.jointIndexToIndices(jointIndex, indices);
}
else
{
domains.randomIndices(rand, indices);
}
table.setWeightForIndices(weight, indices);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
weight = weight + 1;
table.setWeightForIndices(weight, indices);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
domains.jointIndexToIndices(jointIndex, indices);
}
else
{
domains.randomIndices(rand, indices);
}
table.setEnergyForIndices(-Math.log(weight), indices);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
domains.jointIndexToElements(jointIndex, arguments);
}
else
{
domains.randomIndices(rand, indices);
domains.elementsFromIndices(indices, arguments);
}
table.setWeightForElements(weight, arguments);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
weight = rand.nextDouble();
if (supportsJoint)
{
jointIndex = domains.randomJointIndex(rand);
domains.jointIndexToElements(jointIndex, arguments);
}
else
{
domains.randomIndices(rand, indices);
domains.elementsFromIndices(indices, arguments);
}
table.setEnergyForElements(-Math.log(weight), arguments);
if (supportsJoint)
assertWeight(table, weight, jointIndex);
else
assertWeight(table, weight, indices);
}
assertInvariants(table);
}
}
/**
 * Asserts that {@code table} reports {@code weight}, and the corresponding energy
 * {@code -log(weight)}, for the entry at {@code jointIndex} through every lookup path:
 * joint index, sparse index (only when the joint index maps to a non-negative sparse index),
 * per-dimension indices and domain elements.
 *
 * @param table the table to query
 * @param weight the weight expected at {@code jointIndex}
 * @param jointIndex joint index of the entry to check
 */
private static void assertWeight(IFactorTable table, double weight, int jointIndex)
{
    // One tight tolerance for every lookup path. It must be the floating-point literal 1e-12:
    // an integer expression such as 13-12 evaluates to 1 and would accept almost any weight.
    final double tolerance = 1e-12;
    final double energy = -Math.log(weight);

    assertEquals(energy, table.getEnergyForJointIndex(jointIndex), tolerance);
    assertEquals(weight, table.getWeightForJointIndex(jointIndex), tolerance);

    // The sparse lookup is only meaningful when the joint index has a sparse location.
    final int sparseIndex = table.sparseIndexFromJointIndex(jointIndex);
    if (sparseIndex >= 0)
    {
        assertEquals(energy, table.getEnergyForSparseIndex(sparseIndex), tolerance);
        assertEquals(weight, table.getWeightForSparseIndex(sparseIndex), tolerance);
    }

    final int[] indices = table.getDomainIndexer().jointIndexToIndices(jointIndex, null);
    assertEquals(energy, table.getEnergyForIndices(indices), tolerance);
    assertEquals(weight, table.getWeightForIndices(indices), tolerance);

    final Object[] arguments = table.getDomainIndexer().jointIndexToElements(jointIndex, null);
    assertEquals(energy, table.getEnergyForElements(arguments), tolerance);
    assertEquals(weight, table.getWeightForElements(arguments), tolerance);
}
/**
 * Asserts that {@code table} reports {@code weight}, and the corresponding energy
 * {@code -log(weight)}, for the entry addressed by {@code indices} through every lookup path
 * that does not need a joint index: per-dimension indices, sparse index (only when
 * {@code indices} maps to a non-negative sparse index) and domain elements.
 *
 * @param table the table to query
 * @param weight the weight expected at {@code indices}
 * @param indices per-dimension element indices of the entry to check
 */
private static void assertWeight(IFactorTable table, double weight, int[] indices)
{
    // One tight tolerance for every lookup path. It must be the floating-point literal 1e-12:
    // an integer expression such as 13-12 evaluates to 1 and would accept almost any weight.
    final double tolerance = 1e-12;
    final double energy = -Math.log(weight);

    assertEquals(energy, table.getEnergyForIndices(indices), tolerance);
    assertEquals(weight, table.getWeightForIndices(indices), tolerance);

    // The sparse lookup is only meaningful when the indices have a sparse location.
    final int sparseIndex = table.sparseIndexFromIndices(indices);
    if (sparseIndex >= 0)
    {
        assertEquals(energy, table.getEnergyForSparseIndex(sparseIndex), tolerance);
        assertEquals(weight, table.getWeightForSparseIndex(sparseIndex), tolerance);
    }

    final Object[] elements = table.getDomainIndexer().elementsFromIndices(indices);
    assertEquals(energy, table.getEnergyForElements(elements), tolerance);
    assertEquals(weight, table.getWeightForElements(elements), tolerance);
}
/**
 * Asserts invariants of an {@link IFactorTable}: first everything checked by
 * {@code assertBaseInvariants}, then that re-applying the current representation and output
 * set changes nothing, that the unsafe sparse arrays agree with the per-index getters, and
 * that converting through a permuter built from the table's own indexer yields a table that
 * passes {@code assertBaseEqual}.
 *
 * @param table the table to validate
 */
public static void assertInvariants(IFactorTable table)
{
    final FactorTableRepresentation currentRep = table.getRepresentation();

    assertBaseInvariants(table);

    // Setting the representation the table already has must leave it as it was.
    table.setRepresentation(currentRep);
    assertEquals(currentRep, table.getRepresentation());
    assertEquals(currentRep.isDeterministic(), table.hasDeterministicRepresentation());

    // Calling setDirected is fine as long as it does not change the output set.
    final BitSet outputs = table.getDomainIndexer().getOutputSet();
    table.setDirected(outputs);
    assertEquals(outputs, table.getDomainIndexer().getOutputSet());

    // Copying a table onto itself is expected to do nothing.
    table.copy(table);

    if (table.hasSparseEnergies())
    {
        final double[] rawEnergies = table.getEnergiesSparseUnsafe();
        int index = table.sparseSize();
        while (--index >= 0)
        {
            assertEquals(table.getEnergyForSparseIndex(index), rawEnergies[index], 0.0);
        }
    }

    if (table.hasSparseWeights())
    {
        final double[] rawWeights = table.getWeightsSparseUnsafe();
        int index = table.sparseSize();
        while (--index >= 0)
        {
            assertEquals(table.getWeightForSparseIndex(index), rawWeights[index], 0.0);
        }
    }

    // A permuter from the indexer to itself; the converted table must match the source.
    final JointDomainReindexer identityConverter =
        JointDomainReindexer.createPermuter(table.getDomainIndexer(), table.getDomainIndexer());
    final IFactorTable converted = table.convert(identityConverter);
    assertBaseEqual(table, converted);
}
/**
 * Asserts invariants expected of any {@link IFactorTableBase}.
 * <p>
 * In order, the checks cover: dimension count and domain sizes (and the joint size when joint
 * indexing is supported); agreement between {@code isDirected()} and the input set; agreement
 * between the aggregate and per-kind representation flags; the for-each iteration; the full
 * iterator (joint indexing only); normalization and conditional-normalization totals; sparse
 * index round-trips; deterministic evaluation; and finally that {@code clone()} and
 * {@code SerializationTester.clone} produce tables that pass {@code assertBaseEqual}.
 * <p>
 * Note that the sparse section calls {@code setEnergyForSparseIndex} and
 * {@code setWeightForSparseIndex} on {@code table}, writing back the values it has just read.
 *
 * @param table the table to validate
 */
public static void assertBaseInvariants(IFactorTableBase table)
{
int nDomains = table.getDimensions();
assertTrue(nDomains >= 0);
JointDomainIndexer domains = table.getDomainIndexer();
assertEquals(nDomains, domains.size());
final boolean supportsJoint = table.supportsJointIndexing();
assertEquals(table.getDomainIndexer().supportsJointIndexing(), supportsJoint);
// Expected joint size is the product of the domain sizes; only accumulated when joint
// indexing is supported.
int expectedJointSize = 1;
int[] domainSizes = new int[nDomains];
for (int i = 0; i < nDomains; ++i)
{
int domainSize = domains.getDomainSize(i);
assertTrue(domainSize > 0);
if (supportsJoint)
{
expectedJointSize *= domainSize;
}
domainSizes[i] = domainSize;
}
// A directed table must report a non-empty input set; an undirected table must report null.
BitSet fromSet = table.getInputSet();
if (table.isDirected())
{
requireNonNull(fromSet);
assertTrue(fromSet.cardinality() > 0);
}
else
{
assertNull(fromSet);
}
int size = table.sparseSize();
assertTrue(size >= 0);
// Stays -1 when joint indexing is unsupported; it is only read under supportsJoint below.
int jointSize = -1;
if (supportsJoint)
{
jointSize = table.jointSize();
assertTrue(size <= jointSize);
assertEquals(expectedJointSize, jointSize);
}
Value[] arguments = new Value[nDomains];
int[] indices = new int[nDomains];
// The aggregate representation flags must agree with the per-kind flags.
assertEquals(table.hasDenseRepresentation(), table.hasDenseWeights() || table.hasDenseEnergies());
assertEquals(table.hasSparseRepresentation(),
table.hasSparseWeights() || table.hasSparseEnergies() || table.isDeterministicDirected());
// Test iteration
{
// i counts the entries visited by the for-each iteration.
int i = 0;
for (FactorTableEntry entry : table)
{
final int si = entry.sparseIndex();
final int ji = supportsJoint ? entry.jointIndex() : -1;
assertSame(domains, entry.domains());
// Every entry produced by the for-each iteration must have non-zero weight and finite energy.
assertNotEquals(0.0, entry.weight(), 0.0);
assertFalse(Double.isInfinite(entry.energy()));
assertEquals(entry.energy(), -Math.log(entry.weight()), 1e-12);
if (supportsJoint)
{
assertEquals(entry.weight(), table.getWeightForJointIndex(ji), 1e-12);
}
if (table.hasSparseRepresentation())
{
assertTrue(si >= 0);
assertTrue(si < table.sparseSize());
// The number of entries visited so far never exceeds the current sparse index
// (presumably because zero-weight sparse entries are skipped -- confirm).
assertTrue(i <= si);
assertEquals(table.getEnergyForSparseIndex(si), entry.energy(), 0.0);
if (supportsJoint)
{
assertEquals(table.sparseIndexToJointIndex(si), ji);
assertEquals(si, table.sparseIndexFromJointIndex(ji));
}
assertArrayEquals(table.sparseIndexToIndices(si, null), entry.indices());
assertArrayEquals(table.sparseIndexToElements(si, null), entry.values());
}
else
{
// Without a sparse representation, entries must report a negative sparse index.
assertTrue(si < 0);
}
++i;
}
// The number of visited entries must equal the reported count of non-zero weights.
assertEquals(table.countNonZeroWeights(), i);
double totalWeight = 0.0;
int nonZeroCount = 0;
IFactorTableIterator iter = null;
if (supportsJoint)
{
// The full iterator must visit every joint index in increasing order, including
// entries with zero weight.
i = 0;
iter = table.fullIterator();
assertFalse(iter.skipsZeroWeights());
while (iter.hasNext())
{
FactorTableEntry entry = iter.next();
requireNonNull(entry);
assertEquals(i, entry.jointIndex());
assertArrayEquals(entry.indices(), iter.indicesUnsafe());
if (entry.weight() != 0.0)
{
totalWeight += entry.weight();
++nonZeroCount;
}
if (table.hasDenseEnergies())
{
assertEquals(entry.energy(), table.getEnergyForIndicesDense(entry.indices()), 1e-12);
}
if (table.hasDenseWeights())
{
assertEquals(entry.weight(), table.getWeightForIndicesDense(entry.indices()), 1e-12);
}
++i;
}
assertEquals(table.jointSize(), i);
assertEquals(nonZeroCount, table.countNonZeroWeights());
// density() must equal the fraction of joint entries that have non-zero weight.
assertEquals((double)nonZeroCount/(double)i, table.density(), 1e-12);
if (table.isNormalized())
{
// A normalized table must be undirected, and its weights must sum to one unless
// they are all zero.
assertFalse(table.isDirected());
if (totalWeight != 0.0)
{
assertEquals(1.0, totalWeight, 1e-12);
}
}
}
if (table.isConditional())
{
assertTrue(table.isDirected());
if (table.supportsJointIndexing())
{
// For a conditional table, the total weight over all outputs must be the same for
// every input index; each input's total is compared against the previous input's.
double totalWeightForPrevInput = 0.0;
for (int ii = 0, isize = domains.getInputCardinality(); ii < isize; ++ii)
{
double totalWeightForInput = 0.0;
for (int oi = 0, osize = domains.getOutputCardinality(); oi < osize; ++oi)
{
int ji = domains.jointIndexFromInputOutputIndices(ii, oi);
totalWeightForInput += table.getWeightForJointIndex(ji);
}
if (ii > 0)
{
assertEquals(totalWeightForPrevInput, totalWeightForInput, 1e-12);
}
totalWeightForPrevInput = totalWeightForInput;
}
}
else
{
// FIXME - test invariants for conditional SparseFactorTable
}
}
if (supportsJoint)
{
// iter is the exhausted full iterator from above: next() past the end must return null.
assertNull(requireNonNull(iter).next());
// Walk a fresh full iterator using advance() and the iterator's own accessors.
i = 0;
iter = table.fullIterator();
while (iter.advance())
{
assertEquals(i, iter.jointIndex());
if (table.hasSparseRepresentation())
{
int si = table.sparseIndexFromJointIndex(i);
// A negative result is decoded as -1-si before comparing with the iterator
// (presumably an insertion-point encoding, as in binary search -- confirm).
if (si < 0) si = -1-si;
assertEquals(si, iter.sparseIndex());
}
else
{
assertEquals(-1, iter.sparseIndex());
}
assertEquals(table.getWeightForJointIndex(i), iter.weight(), 1e-12);
assertEquals(table.getEnergyForJointIndex(i), iter.energy(), 1e-12);
++i;
}
assertFalse(iter.advance());
assertFalse(iter.hasNext());
try
{
iter.remove();
fail("should not get here");
}
catch (UnsupportedOperationException ex)
{
// Expected: remove() must not be supported by the iterator.
}
}
}
if (table.hasSparseRepresentation())
{
// The default iterator must skip zero weights. It is advanced in lockstep with the loop
// over sparse indices below, once for each sparse entry whose weight is non-zero.
IFactorTableIterator iter = table.iterator();
assertTrue(iter.skipsZeroWeights());
for (int si = 0; si < size; ++si)
{
// Round trip: sparse index -> indices -> sparse index, with every index in range.
table.sparseIndexToIndices(si, indices);
for (int j = 0; j < nDomains; ++j)
{
assertTrue(indices[j] >= 0);
assertTrue(indices[j] < domainSizes[j]);
}
assertEquals(si, table.sparseIndexFromIndices(indices));
// Round trip: sparse index -> elements -> sparse index; the elements must match the
// domain elements selected by the indices.
Object[] elements = new Object[arguments.length];
table.sparseIndexToElements(si, elements);
for (int j = 0; j < nDomains; j++)
arguments[j] = Value.create(domains.get(j), elements[j]);
for (int j = 0; j < nDomains; ++j)
{
assertEquals(arguments[j].getObject(), domains.get(j).getElement(indices[j]));
}
assertEquals(si, table.sparseIndexFromElements(elements));
if (supportsJoint)
{
int joint = table.sparseIndexToJointIndex(si);
assertTrue(joint >= 0);
assertTrue(joint < jointSize);
assertEquals(si, table.sparseIndexFromJointIndex(joint));
}
// Write back the energy just read, then check that lookups by indices and by
// elements still return it.
double energy = table.getEnergyForSparseIndex(si);
table.setEnergyForSparseIndex(energy, si);
assertEquals(energy, table.getEnergyForIndices(indices), 1e-12);
assertEquals(energy, table.getEnergyForElements(elements), 1e-12);
// Same for the weight, which must also be consistent with the energy.
double weight = table.getWeightForSparseIndex(si);
table.setWeightForSparseIndex(weight, si);
assertEquals(weight, table.getWeightForIndices(indices), 1e-12);
assertEquals(weight, table.getWeightForElements(elements), 1e-12);
assertEquals(energy, -Math.log(weight), 1e-12);
if (weight != 0.0)
{
// The next entry from the zero-skipping iterator must be this sparse entry.
assertTrue(iter.hasNext());
FactorTableEntry entry = iter.next();
requireNonNull(entry);
assertEquals(si, entry.sparseIndex());
assertEquals(energy, entry.energy(), 1e-12);
assertEquals(weight, entry.weight(), 1e-12);
assertArrayEquals(iter.indicesUnsafe(), entry.indices());
}
}
}
if (table.isDeterministicDirected())
{
assertTrue(table.isDirected());
// For each input index, after evalDeterministic (presumably filling in the output
// values -- confirm) the arguments must have weight one and energy zero, and the
// sparse index of the arguments must equal the input index.
for (int inputIndex = 0, end = domains.getInputCardinality(); inputIndex < end; ++inputIndex)
{
for (int i = 0; i < nDomains; i++)
arguments[i] = Value.create(domains.get(i)); // Empty the values
domains.inputIndexToValues(inputIndex, arguments);
table.evalDeterministic(arguments);
assertEquals(1.0, table.getWeightForValues(arguments), 0.0);
assertEquals(0.0, table.getEnergyForValues(arguments), 0.0);
assertEquals(1.0, table.getWeightForSparseIndex(inputIndex), 0.0);
assertEquals(0.0, table.getEnergyForSparseIndex(inputIndex), 0.0);
assertEquals(inputIndex, table.sparseIndexFromValues(arguments));
}
}
else
{
// A table that is not deterministic directed must reject evalDeterministic.
try
{
table.evalDeterministic(arguments);
fail("expected exception");
}
catch (DimpleException ex)
{
// Expected.
}
}
// Copies made by clone() and by SerializationTester.clone must compare equal to the original.
IFactorTableBase table2 = table.clone();
assertBaseEqual(table, table2);
IFactorTableBase table3 = SerializationTester.clone(table);
assertBaseEqual(table, table3);
}
/**
 * Asserts that two {@link IFactorTable}s are equal, including the checks performed by
 * {@code assertBaseEqual}.
 *
 * @param table1 the reference table
 * @param table2 the table compared against {@code table1}
 */
public static void assertEqual(IFactorTable table1, IFactorTable table2)
{
    // Entry point used directly by callers, so the base comparison has not been run yet.
    final boolean runBaseComparison = true;
    assertEqualImpl(table1, table2, runBaseComparison);
}
/**
 * Compares the {@link IFactorTable}-specific state of two tables: sparse size, dimensions,
 * representation and, when {@code table1} has sparse indices, the sparse index arrays row by row.
 *
 * @param table1 the reference table
 * @param table2 the table compared against {@code table1}
 * @param checkBaseEqual if true, {@code assertBaseEqual} is run first; callers that have
 *        already run it pass false
 */
private static void assertEqualImpl(IFactorTable table1, IFactorTable table2, boolean checkBaseEqual)
{
    if (checkBaseEqual)
    {
        assertBaseEqual(table1, table2);
    }

    assertEquals(table1.sparseSize(), table2.sparseSize());
    assertEquals(table1.getDimensions(), table2.getDimensions());
    assertEquals(table1.getRepresentation(), table2.getRepresentation());

    if (!table1.hasSparseIndices())
    {
        // Nothing further to compare without sparse indices.
        return;
    }

    final int[][] expectedRows = table1.getIndicesSparseUnsafe();
    final int[][] actualRows = table2.getIndicesSparseUnsafe();
    assertEquals(expectedRows.length, actualRows.length);

    // Rows are compared from last to first.
    int row = expectedRows.length;
    while (--row >= 0)
    {
        assertArrayEquals(expectedRows[row], actualRows[row]);
    }
}
/**
 * Asserts that two {@link IFactorTableBase} instances are equal: same class, input set,
 * dimensions and domains; same directed/normalized/deterministic flags; same sparse size and
 * per-sparse-index weights and energies (when {@code table1} has a sparse representation);
 * same joint size and per-joint-index weights and energies (when both support joint indexing);
 * and the same non-zero count and density. If {@code table1} is an {@link IFactorTable}, the
 * {@code IFactorTable}-specific comparison is run as well.
 *
 * @param table1 the reference table
 * @param table2 the table compared against {@code table1}
 */
public static void assertBaseEqual(IFactorTableBase table1, IFactorTableBase table2)
{
    final boolean jointOnBoth = table1.supportsJointIndexing() && table2.supportsJointIndexing();

    assertEquals(table1.getClass(), table2.getClass());
    assertEquals(table1.getInputSet(), table2.getInputSet());
    assertEquals(table1.getDimensions(), table2.getDimensions());

    // Domains must match dimension by dimension, in both size and identity.
    final int dimensionCount = table1.getDimensions();
    int dim = 0;
    while (dim < dimensionCount)
    {
        assertEquals(table1.getDomainIndexer().getDomainSize(dim), table2.getDomainIndexer().getDomainSize(dim));
        assertEquals(table1.getDomainIndexer().get(dim), table2.getDomainIndexer().get(dim));
        ++dim;
    }

    assertEquals(table1.isDirected(), table2.isDirected());
    assertEquals(table1.isNormalized(), table2.isNormalized());
    assertEquals(table1.isDeterministicDirected(), table2.isDeterministicDirected());
    assertEquals(table1.sparseSize(), table2.sparseSize());

    if (table1.hasSparseRepresentation())
    {
        final int sparseCount = table1.sparseSize();
        int si = 0;
        while (si < sparseCount)
        {
            // The sparse size is re-checked on every pass, before the getters are called.
            assertEquals(sparseCount, table1.sparseSize());
            assertEquals(table1.getWeightForSparseIndex(si), table2.getWeightForSparseIndex(si), 1e-12);
            assertEquals(table1.getEnergyForSparseIndex(si), table2.getEnergyForSparseIndex(si), 1e-12);
            if (jointOnBoth)
            {
                assertEquals(table1.sparseIndexToJointIndex(si), table2.sparseIndexToJointIndex(si));
            }
            ++si;
        }
    }

    if (jointOnBoth)
    {
        assertEquals(table1.jointSize(), table2.jointSize());
        final int jointCount = table1.jointSize();
        int joint = 0;
        while (joint < jointCount)
        {
            assertEquals(table1.getWeightForJointIndex(joint), table2.getWeightForJointIndex(joint), 1e-12);
            assertEquals(table1.getEnergyForJointIndex(joint), table2.getEnergyForJointIndex(joint), 1e-12);
            ++joint;
        }
    }

    assertEquals(table1.countNonZeroWeights(), table2.countNonZeroWeights());
    assertEquals(table1.density(), table2.density(), 1e-12);

    if (table1 instanceof IFactorTable)
    {
        // The base comparison has just been done, so it is not repeated here.
        assertEqualImpl((IFactorTable)table1, (IFactorTable)table2, false);
    }
}
/**
 * Expects that invoking {@code methodName} on {@code table} with {@code args} throws a
 * {@link DimpleException} whose message matches a pattern containing the method name followed
 * by "dense representation not supported".
 *
 * @param table the table on which the method is invoked
 * @param methodName name of the method expected to fail
 * @param args arguments passed to the method
 */
private static void expectNotDense(IFactorTable table, String methodName, Object ... args)
{
    final String expectedMessagePattern = ".*" + methodName + ".*dense representation not supported.";
    expectThrow(DimpleException.class, expectedMessagePattern, table, methodName, args);
}
}