package edu.stanford.nlp.loglinear.inference;
import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.generator.InRange;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorTable;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;
import java.util.*;
import java.util.stream.Collectors;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
/**
* Created on 8/12/15.
* @author keenon
* <p>
* Quickchecks our factor functions, and also provides unit tests for documentation and simple verification.
*/
@RunWith(Theories.class)
public class TableFactorTest {
/**
 * Checks that passing observations straight to the {@link TableFactor} constructor yields the same neighbors and the
 * same assignment values as building the unobserved factor and then calling {@code observe} once per observed variable.
 */
@Theory
public void testConstructWithObservations(@ForAll(sampleSize = 50) @From(PartiallyObservedConstructorDataGenerator.class) PartiallyObservedConstructorData data,
                                          @ForAll(sampleSize = 2) @From(ConcatVectorGenerator.class) ConcatVector weights) throws Exception {
  // Observed value per variable index; -1 marks a variable that is not observed
  int[] observedByVariable = new int[9];
  Arrays.fill(observedByVariable, -1);
  for (int slot = 0; slot < data.observations.length; slot++) {
    int variable = data.factor.neigborIndices[slot];
    observedByVariable[variable] = data.observations[slot];
  }

  // Reference path: start from the full factor and observe one variable at a time
  TableFactor stepwise = new TableFactor(weights, data.factor);
  for (int variable = 0; variable < observedByVariable.length; variable++) {
    int observedValue = observedByVariable[variable];
    if (observedValue == -1) {
      continue;
    }
    stepwise = stepwise.observe(variable, observedValue);
  }

  // Path under test: hand the observations to the constructor
  TableFactor direct = new TableFactor(weights, data.factor, data.observations);

  assertArrayEquals(stepwise.neighborIndices, direct.neighborIndices);
  for (int[] assn : stepwise) {
    double expected = stepwise.getAssignmentValue(assn);
    double actual = direct.getAssignmentValue(assn);
    assertEquals(expected, actual, 1.0e-9);
  }
}
/**
 * Checks that observing one variable of a factor keeps the value of every assignment that agrees with the observation.
 * Samples are skipped when the factor does not touch the variable, or when it has only a single neighbor.
 */
@Theory
public void testObserve(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
                        @ForAll(sampleSize = 2) @InRange(minInt = 0, maxInt = 3) int observe,
                        @ForAll(sampleSize = 2) @InRange(minInt = 0, maxInt = 1) int value) throws Exception {
  boolean touchesObserved = Arrays.stream(factor.neighborIndices).anyMatch(n -> n == observe);
  if (!touchesObserved || factor.neighborIndices.length == 1) {
    return;
  }

  TableFactor observedOut = factor.observe(observe, value);

  // Position of the observed variable in the factor's neighbor list (the last match, if there were several)
  int observeIndex = -1;
  for (int i = factor.neighborIndices.length - 1; i >= 0; i--) {
    if (factor.neighborIndices[i] == observe) {
      observeIndex = i;
      break;
    }
  }

  for (int[] assignment : factor) {
    if (assignment[observeIndex] != value) {
      continue;
    }
    double expected = factor.getAssignmentValue(assignment);
    int[] reduced = subsetAssignment(assignment, factor, observedOut);
    assertEquals(expected, observedOut.getAssignmentValue(reduced), 1.0e-7);
  }
}
/**
 * Compares {@code getMaxedMarginals()} for one variable against a brute-force maximum over all assignments,
 * normalized with {@link #normalize(double[])}.
 */
@Theory
public void testGetMaxedMarginals(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
                                  @ForAll(sampleSize = 10) @InRange(minInt = 0, maxInt = 3) int marginalizeTo) throws Exception {
  if (Arrays.stream(factor.neighborIndices).noneMatch(n -> n == marginalizeTo)) {
    return;
  }

  // First position at which the requested variable appears among the factor's neighbors
  int indexOfVariable = -1;
  for (int i = 0; indexOfVariable == -1 && i < factor.neighborIndices.length; i++) {
    if (factor.neighborIndices[i] == marginalizeTo) {
      indexOfVariable = i;
    }
  }
  assumeTrue(indexOfVariable > -1);

  double[] gold = new double[factor.getDimensions()[indexOfVariable]];
  Arrays.fill(gold, Double.NEGATIVE_INFINITY);
  for (int[] assignment : factor) {
    int slot = assignment[indexOfVariable];
    gold[slot] = Math.max(gold[slot], factor.getAssignmentValue(assignment));
  }
  normalize(gold);

  assertArrayEquals(gold, factor.getMaxedMarginals()[indexOfVariable], 1.0e-5);
}
/**
 * Compares {@code getSummedMarginals()} for one variable against a brute-force sum over all assignments,
 * normalized with {@link #normalize(double[])}.
 */
@Theory
public void testGetSummedMarginals(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
                                   @ForAll(sampleSize = 10) @InRange(minInt = 0, maxInt = 3) int marginalizeTo) throws Exception {
  if (Arrays.stream(factor.neighborIndices).noneMatch(n -> n == marginalizeTo)) {
    return;
  }

  // First position at which the requested variable appears among the factor's neighbors
  int indexOfVariable = -1;
  for (int i = 0; indexOfVariable == -1 && i < factor.neighborIndices.length; i++) {
    if (factor.neighborIndices[i] == marginalizeTo) {
      indexOfVariable = i;
    }
  }
  assumeTrue(indexOfVariable > -1);

  // Accumulate the total value seen for each setting of the chosen variable
  double[] gold = new double[factor.getDimensions()[indexOfVariable]];
  for (int[] assignment : factor) {
    int slot = assignment[indexOfVariable];
    gold[slot] += factor.getAssignmentValue(assignment);
  }
  normalize(gold);

  assertArrayEquals(gold, factor.getSummedMarginals()[indexOfVariable], 1.0e-5);
}
/**
 * Rescales the array in place so that its entries sum to one. When the entries sum to exactly zero, every entry is
 * instead set to {@code 1.0 / arr.length}.
 *
 * @param arr the values to normalize; modified in place
 */
private void normalize(double[] arr) {
  double total = 0;
  for (int i = 0; i < arr.length; i++) {
    total += arr[i];
  }

  if (total != 0) {
    for (int i = 0; i < arr.length; i++) {
      arr[i] /= total;
    }
    return;
  }

  // Nothing to scale by, so fall back to a uniform distribution
  Arrays.fill(arr, 1.0 / arr.length);
}
/**
 * Checks {@code valueSum()} against a manual sum of the value of every assignment in the factor.
 */
@Theory
public void testValueSum(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor) throws Exception {
  double expected = 0.0;
  Iterator<int[]> assignments = factor.iterator();
  while (assignments.hasNext()) {
    expected += factor.getAssignmentValue(assignments.next());
  }
  assertEquals(expected, factor.valueSum(), 1.0e-5);
}
/**
 * Checks {@code maxOut}: the result drops exactly the requested variable, no original value exceeds the matching
 * maxed-out value, and each maxed-out value equals the brute-force maximum over the assignments that map onto it.
 * Samples are skipped when the factor does not touch the variable or has at most one neighbor.
 */
@Theory
public void testMaxOut(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
                       @ForAll(sampleSize = 10) @InRange(minInt = 0, maxInt = 3) int marginalize) throws Exception {
  if (Arrays.stream(factor.neighborIndices).noneMatch(n -> n == marginalize)) {
    return;
  }
  if (factor.neighborIndices.length <= 1) {
    return;
  }

  TableFactor maxedOut = factor.maxOut(marginalize);

  assertEquals(factor.neighborIndices.length - 1, maxedOut.neighborIndices.length);
  assertFalse(Arrays.stream(maxedOut.neighborIndices).anyMatch(n -> n == marginalize));

  for (int[] assignment : factor) {
    assertTrue(factor.getAssignmentValue(assignment) >= Double.NEGATIVE_INFINITY);
    double original = factor.getAssignmentValue(assignment);
    int[] reduced = subsetAssignment(assignment, factor, maxedOut);
    assertTrue(original <= maxedOut.getAssignmentValue(reduced));
  }

  Map<List<Integer>, List<int[]>> subsetToSuperset = subsetToSupersetAssignments(factor, maxedOut);
  for (Map.Entry<List<Integer>, List<int[]>> entry : subsetToSuperset.entrySet()) {
    double max = Double.NEGATIVE_INFINITY;
    for (int[] supersetAssignment : entry.getValue()) {
      max = Math.max(max, factor.getAssignmentValue(supersetAssignment));
    }
    // Unbox the list key back into the int[] form the factor expects
    int[] subsetAssignment = entry.getKey().stream().mapToInt(Integer::intValue).toArray();
    assertEquals(max, maxedOut.getAssignmentValue(subsetAssignment), 1.0e-5);
  }
}
/**
 * Checks {@code sumOut}: the result drops exactly the requested variable, and each summed-out value equals the
 * brute-force sum over the assignments that map onto it. Samples are skipped when the factor does not touch the
 * variable or has at most one neighbor.
 */
@Theory
public void testSumOut(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
                       @ForAll(sampleSize = 10) @InRange(minInt = 0, maxInt = 3) int marginalize) throws Exception {
  if (Arrays.stream(factor.neighborIndices).noneMatch(n -> n == marginalize)) {
    return;
  }
  if (factor.neighborIndices.length <= 1) {
    return;
  }

  TableFactor summedOut = factor.sumOut(marginalize);

  assertEquals(factor.neighborIndices.length - 1, summedOut.neighborIndices.length);
  assertFalse(Arrays.stream(summedOut.neighborIndices).anyMatch(n -> n == marginalize));

  Map<List<Integer>, List<int[]>> subsetToSuperset = subsetToSupersetAssignments(factor, summedOut);
  for (Map.Entry<List<Integer>, List<int[]>> entry : subsetToSuperset.entrySet()) {
    double sum = 0.0;
    for (int[] supersetAssignment : entry.getValue()) {
      sum += factor.getAssignmentValue(supersetAssignment);
    }
    // Unbox the list key back into the int[] form the factor expects
    int[] subsetAssignment = entry.getKey().stream().mapToInt(Integer::intValue).toArray();
    assertEquals(sum, summedOut.getAssignmentValue(subsetAssignment), 1.0e-5);
  }
}
/**
 * Checks {@code multiply}: every value in the product equals the product of the two input factors' values for the
 * corresponding assignments, and the product factor lists no neighbor twice.
 */
@Theory
public void testMultiply(@ForAll(sampleSize = 10) @From(TableFactorGenerator.class) TableFactor factor1,
                         @ForAll(sampleSize = 10) @From(TableFactorGenerator.class) TableFactor factor2) throws Exception {
  TableFactor result = factor1.multiply(factor2);

  for (int[] assignment : result) {
    int[] inFirst = subsetAssignment(assignment, result, factor1);
    double factor1Value = factor1.getAssignmentValue(inFirst);
    int[] inSecond = subsetAssignment(assignment, result, factor2);
    double factor2Value = factor2.getAssignmentValue(inSecond);
    double expectedProduct = factor1Value * factor2Value;
    assertEquals(expectedProduct, result.getAssignmentValue(assignment), 1.0e-5);
  }

  // No neighbor may appear twice; inequality is symmetric, so each unordered pair is compared once
  int[] resultNeighbors = result.neighborIndices;
  for (int i = 0; i < resultNeighbors.length; i++) {
    for (int j = i + 1; j < resultNeighbors.length; j++) {
      assertNotEquals(resultNeighbors[i], resultNeighbors[j]);
    }
  }
}
/**
 * Dimension used for each candidate variable index; {@code TableFactorGenerator} looks up the size of a chosen
 * neighbor here.
 */
public static int[] variableSizes = {2, 4, 2, 3};
/**
 * Quickcheck generator producing {@link TableFactor} instances over distinct neighbor variables, with each variable's
 * dimension taken from {@code variableSizes} and every table entry filled with a random value.
 */
public static class TableFactorGenerator extends Generator<TableFactor> {
  public TableFactorGenerator(Class<TableFactor> type) {
    super(type);
  }

  @Override
  public TableFactor generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
    int numNeighbors = sourceOfRandomness.nextInt(1, 3);
    int[] neighbors = new int[numNeighbors];
    int[] dimensions = new int[numNeighbors];

    Set<Integer> usedNeighbors = new HashSet<>();
    for (int slot = 0; slot < numNeighbors; slot++) {
      // Keep drawing until we hit a variable this factor does not use yet; add() reports false on a repeat
      int candidate = sourceOfRandomness.nextInt(0, 3);
      while (!usedNeighbors.add(candidate)) {
        candidate = sourceOfRandomness.nextInt(0, 3);
      }
      neighbors[slot] = candidate;
      dimensions[slot] = variableSizes[candidate];
    }

    // One shared scale factor for the whole table. The original note here said this is meant to produce some
    // all-0 factor tables. NOTE(review): that only happens if a draw comes back as exactly 0 - confirm intent.
    double multiple = sourceOfRandomness.nextDouble();

    TableFactor factor = new TableFactor(neighbors, dimensions);
    for (int[] assignment : factor) {
      double entry = multiple * sourceOfRandomness.nextDouble();
      factor.setAssignmentValue(assignment, entry);
    }
    return factor;
  }
}
/**
 * Quickcheck generator producing a {@link ConcatVector} with a single dense component of 20 random doubles.
 */
public static class ConcatVectorGenerator extends Generator<ConcatVector> {
  public ConcatVectorGenerator(Class<ConcatVector> type) {
    super(type);
  }

  @Override
  public ConcatVector generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
    double[] dense = new double[20];
    for (int i = 0; i < dense.length; i++) {
      dense[i] = sourceOfRandomness.nextDouble();
    }

    ConcatVector vec = new ConcatVector(1);
    vec.setDenseComponent(0, dense);
    return vec;
  }
}
/**
 * Plain holder for one generated test case of {@code testConstructWithObservations}: a model factor together with
 * one observation entry per neighbor of that factor.
 */
private static class PartiallyObservedConstructorData {
// The factor whose table and neighbor list are handed to the TableFactor constructors under test
public GraphicalModel.Factor factor;
// Parallel to the factor's neighbor list: the observed value for that neighbor, or -1 when it is not observed
public int[] observations;
}
/**
 * Quickcheck generator producing a {@link GraphicalModel.Factor} over distinct neighbor variables, paired with an
 * observation array in which some neighbors carry an observed value and the rest are marked -1. At least one
 * neighbor is always left unobserved.
 */
public static class PartiallyObservedConstructorDataGenerator extends Generator<PartiallyObservedConstructorData> {
  public PartiallyObservedConstructorDataGenerator(Class<PartiallyObservedConstructorData> type) {
    super(type);
  }

  @Override
  public PartiallyObservedConstructorData generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
    int len = sourceOfRandomness.nextInt(1, 4);
    int[] neighborIndices = new int[len];
    int[] dimensions = new int[len];
    int[] observations = new int[len];

    Set<Integer> taken = new HashSet<>();
    int numObserved = 0;
    for (int i = 0; i < len; i++) {
      // Redraw until the variable index is one this factor has not used; add() reports false on a repeat
      int variable;
      do {
        variable = sourceOfRandomness.nextInt(8);
      } while (!taken.add(variable));
      neighborIndices[i] = variable;

      dimensions[i] = sourceOfRandomness.nextInt(1, 3);

      // The count check keeps the number of observed neighbors strictly below the number of neighbors
      boolean observeThis = sourceOfRandomness.nextBoolean() && numObserved + 1 < len;
      if (!observeThis) {
        observations[i] = -1;
        continue;
      }
      observations[i] = sourceOfRandomness.nextInt(dimensions[i]);
      numObserved++;
    }

    // Give every cell of the factor's table its own freshly generated feature vector
    ConcatVectorTable table = new ConcatVectorTable(dimensions);
    ConcatVectorGenerator vectorGenerator = new ConcatVectorGenerator(ConcatVector.class);
    for (int[] assn : table) {
      ConcatVector vec = vectorGenerator.generate(sourceOfRandomness, generationStatus);
      table.setAssignmentValue(assn, () -> vec);
    }

    PartiallyObservedConstructorData data = new PartiallyObservedConstructorData();
    data.factor = new GraphicalModel.Factor(table, neighborIndices);
    data.observations = observations;
    return data;
  }
}
/**
 * Projects an assignment over a superset factor's variables down onto a subset factor's variables. Handy for checking
 * that values line up across product and marginalization steps.
 *
 * @param supersetAssignment one value per neighbor of the superset factor, in that factor's neighbor order
 * @param superset the factor the assignment belongs to; expected to contain every variable of the subset
 * @param subset the factor to project onto
 * @return one value per neighbor of the subset factor, in the subset's neighbor order
 */
private int[] subsetAssignment(int[] supersetAssignment, TableFactor superset, TableFactor subset) {
  int[] supersetVars = superset.neighborIndices;
  int[] subsetVars = subset.neighborIndices;

  int[] projected = new int[subsetVars.length];
  for (int i = 0; i < subsetVars.length; i++) {
    int var = subsetVars[i];

    // Take the value from the first superset position that holds this variable; -1 means "not found"
    int value = -1;
    for (int j = 0; j < supersetVars.length; j++) {
      if (supersetVars[j] != var) {
        continue;
      }
      value = supersetAssignment[j];
      break;
    }

    projected[i] = value;
    assert (value != -1);
  }
  return projected;
}
/**
 * Groups the assignments of a superset factor by the subset assignment they project onto. Every subset assignment
 * gets an entry, holding all superset assignments whose projection equals it.
 *
 * @param superset the factor whose assignments are being grouped
 * @param subset the factor whose assignments serve as the grouping keys
 * @return a map from each subset assignment (as a list) to the superset assignments that project onto it
 */
private Map<List<Integer>, List<int[]>> subsetToSupersetAssignments(TableFactor superset, TableFactor subset) {
  Map<List<Integer>, List<int[]>> grouped = new HashMap<>();
  for (int[] assignment : subset) {
    // Collect every superset assignment that collapses to this subset assignment
    List<int[]> matches = new ArrayList<>();
    for (int[] candidate : superset) {
      int[] projection = subsetAssignment(candidate, superset, subset);
      if (Arrays.equals(assignment, projection)) {
        matches.add(candidate);
      }
    }

    // int[] has identity-based equals/hashCode, so box the assignment into a list to use it as a map key
    List<Integer> key = new ArrayList<>(assignment.length);
    for (int value : assignment) {
      key.add(value);
    }
    grouped.put(key, matches);
  }
  return grouped;
}
}