/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.FactorFunctions;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicReference;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.google.common.math.DoubleMath;
/**
 * Contains utility methods for testing {@link FactorFunction} implementations.
 * <p>
 * The {@code assertEvalEnergy*} methods check {@link FactorFunction#evalEnergy} for a single set of
 * arguments. The {@code testEvalDeterministic} overloads exercise the deterministic-directed API of a
 * function over a sequence of test cases, each holding both the inputs and the expected outputs.
 */
public class FactorFunctionTester extends DimpleTestBase
{
	/** Absolute tolerance used when comparing energies in the {@code assertEvalEnergy*} methods. */
	private static final double ENERGY_TOLERANCE = 1e-15;

	/** Absolute tolerance used when comparing numeric outputs produced by deterministic updates. */
	private static final double OUTPUT_TOLERANCE = 1e-10;

	/**
	 * Asserts that {@code function.evalEnergy(values)} equals {@code expectedEnergy} to within an
	 * absolute tolerance of 1e-15.
	 */
	public static void assertEvalEnergy(FactorFunction function, double expectedEnergy, Value ... values)
	{
		assertEquals(expectedEnergy, function.evalEnergy(values), ENERGY_TOLERANCE);
	}

	/**
	 * Like {@link #assertEvalEnergy(FactorFunction, double, Value...)}, but each argument is wrapped
	 * using {@code Value.createReal}.
	 */
	public static void assertEvalEnergyReal(FactorFunction function, double expectedEnergy, double ... values)
	{
		final Value[] reals = new Value[values.length];
		for (int i = 0; i < values.length; ++i)
		{
			reals[i] = Value.createReal(values[i]);
		}
		assertEvalEnergy(function, expectedEnergy, reals);
	}

	/**
	 * Like {@link #assertEvalEnergy(FactorFunction, double, Value...)}, but each argument is wrapped
	 * using {@code Value.create(domain, value)} with the given discrete {@code domain}.
	 */
	public static void assertEvalEnergyDiscrete(FactorFunction function, double expectedEnergy, DiscreteDomain domain,
		Object ... values)
	{
		final Value[] discretes = new Value[values.length];
		for (int i = 0; i < values.length; ++i)
		{
			discretes[i] = Value.create(domain, values[i]);
		}
		assertEvalEnergy(function, expectedEnergy, discretes);
	}

	/**
	 * Like {@link #assertEvalEnergy(FactorFunction, double, Value...)}, but each argument is wrapped
	 * as a value of {@link DiscreteDomain#bit()}.
	 */
	public static void assertEvalEnergyBit(FactorFunction function, double expectedEnergy, Object ... values)
	{
		final Value[] discretes = new Value[values.length];
		for (int i = 0; i < values.length; ++i)
		{
			discretes[i] = Value.create(DiscreteDomain.bit(), values[i]);
		}
		assertEvalEnergy(function, expectedEnergy, discretes);
	}

	/**
	 * Like {@link #assertEvalEnergy(FactorFunction, double, Value...)}, but each argument is wrapped
	 * as a value of {@link DiscreteDomain#bool()}.
	 */
	public static void assertEvalEnergyBool(FactorFunction function, double expectedEnergy, boolean ... values)
	{
		final Value[] discretes = new Value[values.length];
		for (int i = 0; i < values.length; ++i)
		{
			discretes[i] = Value.create(DiscreteDomain.bool(), values[i]);
		}
		assertEvalEnergy(function, expectedEnergy, discretes);
	}

	/**
	 * Tests the given deterministic-directed factor function over a sequence of test cases.
	 * <p>
	 * Exercises the following methods:
	 * <ul>
	 * <li>{@link FactorFunction#evalDeterministic(Value[])} and {@code evalDeterministicToCopy}
	 * <li>{@code evalEnergy} and {@code eval}, for both consistent and inconsistent assignments
	 * <li>{@link FactorFunction#updateDeterministicLimit(int)}
	 * <li>{@link FactorFunction#updateDeterministic(Value[], Collection, AtomicReference)}
	 * </ul>
	 * <p>
	 * For every test case after the first, the current inputs are also paired with the previous case's
	 * outputs. If those outputs differ, the function must report zero weight / infinite energy. The
	 * function is then updated towards the current case. When {@code updateDeterministicLimit} is not
	 * positive, a single full update is done. Otherwise the update is incremental, starting from the
	 * previous case.
	 * <p>
	 * @param function is the function to be tested.
	 * @param domains specifies the domains of all of the values in each test case. The domains list must have
	 * at least one entry, but may be shorter than the test case length, in which case the last domain in the
	 * list is used for all remaining entries.
	 * @param outputIndices specifies the indices within each test case that hold output values. All
	 * other indices hold input values.
	 * @param testCases one or more value arrays specifying the inputs and expected outputs for deterministic
	 * evaluation. Each test case must be the same length.
	 *
	 * @see #testEvalDeterministic(FactorFunction, Domain, int[], Value[][])
	 * @see #testEvalDeterministic(FactorFunction, Domain, Value[][])
	 */
	public static void testEvalDeterministic(FactorFunction function, Domain[] domains, int[] outputIndices,
		Value[]... testCases)
	{
		assertTrue(function.isDeterministicDirected());
		assertTrue(function.isDirected());

		final AtomicReference<int[]> changedOutputsHolder = new AtomicReference<int[]>();

		for (int i = 0, end = testCases.length; i < end; ++i)
		{
			final Object[] prevTestCase = i > 0 ? Value.toObjects(testCases[i - 1]) : null;
			final Value[] testCaseValues = testCases[i];
			final int caseSize = testCaseValues.length;

			// Every index not listed in outputIndices is an input.
			final BitSet inputSet = BitSetUtil.bitsetFromIndices(caseSize, outputIndices);
			inputSet.flip(0, caseSize);
			final int[] inputIndices = setBitIndices(inputSet);

			// Evaluating in place (on a copy that holds only the inputs) and evaluating to a fresh copy
			// must both reproduce the full test case.
			final Value[] objectValues = copyInputs(inputIndices, testCaseValues);
			final Value[] resultValues = function.evalDeterministicToCopy(testCaseValues);
			function.evalDeterministic(objectValues);
			final Object[] testCase = Value.toObjects(testCaseValues);
			final Object[] objects = Value.toObjects(objectValues);
			assertArrayEquals(testCase, objects);
			assertArrayEquals(Value.toObjects(resultValues), objects);

			// The consistent assignment must have weight 1 / energy 0.
			assertEquals(0.0, function.evalEnergy(testCase), 0.0);
			assertEquals(1.0, function.eval(testCase), 0.0);

			if (prevTestCase == null)
			{
				continue;
			}

			// Pair the current inputs with the previous case's outputs.
			boolean outputsDiffer = false;
			for (int outputIndex : outputIndices)
			{
				objects[outputIndex] = prevTestCase[outputIndex];
				if (!prevTestCase[outputIndex].equals(testCase[outputIndex]))
				{
					outputsDiffer = true;
				}
			}
			if (outputsDiffer)
			{
				// If one of the outputs is different, then using the inputs from one
				// test case with the outputs from another should result in a zero weight/ infinite energy.
				assertEquals(Double.POSITIVE_INFINITY, function.evalEnergy(objects), 0.0);
				assertEquals(0.0, function.eval(objects), 0.0);
			}

			final Value[] values;
			if (function.updateDeterministicLimit(caseSize) <= 0)
			{
				// No incremental support: an update with no old values should just do a full update
				// and report no specific changed outputs.
				values = Value.createFromObjects(objects, domains);
				changedOutputsHolder.set(null);
				function.updateDeterministic(values, new HashSet<IndexedValue>(), changedOutputsHolder);
				assertNull(changedOutputsHolder.get());
			}
			else
			{
				// Incrementally update starting with previous case.
				values = Value.createFromObjects(prevTestCase, domains);
				probeOutputIndicesAsChangedInputs(function, values, outputIndices, changedOutputsHolder);
				applyIncrementalUpdates(function, values, testCase, caseSize, inputIndices, inputSet,
					changedOutputsHolder);
			}
			assertValuesMatchTestCase(values, testCase, inputIndices, outputIndices);
		}
	}

	/**
	 * Shorthand for {@link #testEvalDeterministic(FactorFunction, Domain[], int[], Value[][])} using the
	 * same domain for every value, equivalent to:
	 * <pre>
	 * testEvalDeterministic(function, new Domain[] { domain }, outputIndices, testCases)
	 * </pre>
	 */
	public static void testEvalDeterministic(FactorFunction function, Domain domain, int[] outputIndices,
		Value[]... testCases)
	{
		testEvalDeterministic(function, new Domain[] { domain }, outputIndices, testCases);
	}

	/**
	 * Shorthand for {@link #testEvalDeterministic(FactorFunction, Domain[], int[], Value[][])}
	 * with a single domain and a single output at index 0, equivalent to:
	 * <pre>
	 * testEvalDeterministic(function, new Domain[] { domain }, new int[] { 0 }, testCases)
	 * </pre>
	 */
	public static void testEvalDeterministic(FactorFunction function, Domain domain, Value[]... testCases)
	{
		testEvalDeterministic(function, domain, new int[] { 0 }, testCases);
	}

	/**
	 * Shorthand for {@link #testEvalDeterministic(FactorFunction, Domain[], int[], Value[][])}
	 * with a single output at index 0, equivalent to:
	 * <pre>
	 * testEvalDeterministic(function, domains, new int[] { 0 }, testCases)
	 * </pre>
	 */
	public static void testEvalDeterministic(FactorFunction function, Domain[] domains, Value[]... testCases)
	{
		testEvalDeterministic(function, domains, new int[] { 0 }, testCases);
	}

	/*-----------------
	 * Private methods
	 */

	/**
	 * Returns the indices of the set bits of {@code bits} in increasing order.
	 * <p>
	 * The array is sized by {@link BitSet#cardinality()}, so duplicate entries in the original index list
	 * cannot cause an overflow.
	 */
	private static int[] setBitIndices(BitSet bits)
	{
		final int[] indices = new int[bits.cardinality()];
		for (int k = 0, j = bits.nextSetBit(0); j >= 0; j = bits.nextSetBit(j + 1), ++k)
		{
			indices[k] = j;
		}
		return indices;
	}

	/**
	 * Calls {@code updateDeterministic} once per output index, each time claiming that the output
	 * is a changed input.
	 * <p>
	 * That request is invalid. Implementations are presumably expected to reject it with
	 * {@link IndexOutOfBoundsException}. This is only a best-effort probe: the exception is not required,
	 * so a function that tolerates the bad index still passes.
	 */
	private static void probeOutputIndicesAsChangedInputs(FactorFunction function, Value[] values,
		int[] outputIndices, AtomicReference<int[]> changedOutputsHolder)
	{
		for (int out : outputIndices)
		{
			final IndexedValue.SingleList badIndexes = IndexedValue.SingleList.create(out, values[out]);
			try
			{
				function.updateDeterministic(values, badIndexes, changedOutputsHolder);
			}
			catch (IndexOutOfBoundsException ignored)
			{
				// Expected (but not required) rejection of an output index used as an input.
			}
			finally
			{
				badIndexes.release();
			}
		}
	}

	/**
	 * Moves {@code values} from the previous test case towards {@code testCase} by setting the inputs in
	 * alternating groups of one and two, calling {@code updateDeterministic} after each group.
	 * <p>
	 * Each changed-output index the function reports must be in range and must not be an input.
	 */
	private static void applyIncrementalUpdates(FactorFunction function, Value[] values, Object[] testCase,
		int caseSize, int[] inputIndices, BitSet inputSet, AtomicReference<int[]> changedOutputsHolder)
	{
		final int nInputs = inputIndices.length;
		final Collection<IndexedValue> oldValues = new HashSet<IndexedValue>();

		// step alternates 1, 2, 1, 2, ... (1 ^ 3 == 2, 2 ^ 3 == 1) to exercise multi-value updates.
		for (int step = 1, j = 0; j < nInputs; j += step, step ^= 3)
		{
			oldValues.clear();
			for (int k = j, endk = Math.min(nInputs, j + step); k < endk; ++k)
			{
				final int index = inputIndices[k];
				final Value oldValue = values[index].clone();
				values[index].setObject(testCase[index]);
				oldValues.add(new IndexedValue(index, oldValue));
			}

			changedOutputsHolder.set(null);
			function.updateDeterministic(values, oldValues, changedOutputsHolder);

			final int[] changedOutputs = changedOutputsHolder.get();
			if (changedOutputs != null)
			{
				for (int outputIndex : changedOutputs)
				{
					assertTrue("changed output index " + outputIndex + " is negative", outputIndex >= 0);
					assertTrue("changed output index " + outputIndex + " is out of range", outputIndex < caseSize);
					assertFalse("changed output index " + outputIndex + " is an input", inputSet.get(outputIndex));
				}
			}
		}
	}

	/**
	 * Asserts that {@code values} agrees with {@code testCase}. Inputs must be exactly equal. Outputs
	 * must be equal, or, when both are numbers, within {@link #OUTPUT_TOLERANCE} of each other.
	 */
	private static void assertValuesMatchTestCase(Value[] values, Object[] testCase, int[] inputIndices,
		int[] outputIndices)
	{
		for (int inputIndex : inputIndices)
		{
			assertEquals("input " + inputIndex, testCase[inputIndex], values[inputIndex].getObject());
		}
		for (int outputIndex : outputIndices)
		{
			if (!valueFuzzyEquals(values[outputIndex], testCase[outputIndex], OUTPUT_TOLERANCE))
			{
				fail(String.format("output %d: expected %s but was %s", outputIndex, testCase[outputIndex],
					values[outputIndex].getObject()));
			}
		}
	}

	/**
	 * Compares the object held by {@code value} with {@code object}. Two numbers are compared as doubles
	 * within the absolute {@code tolerance}; anything else is compared with {@link Object#equals}.
	 *
	 * @throws NullPointerException if {@code value} holds a null object.
	 */
	private static boolean valueFuzzyEquals(Value value, Object object, double tolerance)
	{
		final Object valueObj = requireNonNull(value.getObject());
		if (valueObj instanceof Number && object instanceof Number)
		{
			return DoubleMath.fuzzyEquals(((Number)valueObj).doubleValue(), ((Number)object).doubleValue(), tolerance);
		}
		return valueObj.equals(object);
	}

	/**
	 * Returns a new array, the same length as {@code objects}, of freshly created values.
	 * <p>
	 * Each value is created with {@code Value.create} on the domain of the corresponding entry of
	 * {@code objects}. Only the entries at {@code inputIndices} are then set from {@code objects} (via
	 * {@code setFrom}). The others keep whatever initial value {@code Value.create(Domain)} gives them.
	 */
	private static Value[] copyInputs(int[] inputIndices, Value[] objects)
	{
		final Value[] copy = new Value[objects.length];
		for (int i = 0; i < objects.length; i++)
		{
			copy[i] = Value.create(objects[i].getDomain());
		}
		for (int i : inputIndices)
		{
			copy[i].setFrom(objects[i]);
		}
		return copy;
	}
}