package edu.stanford.nlp.loglinear.inference;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorTable;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

import java.util.*;
import java.util.stream.Collectors;

import static org.junit.Assert.*;

/**
 * Created on 8/11/15.
 * <p>
 * This is a really tricky thing to test in the quickcheck way, since we basically don't know what we want out of random
 * graphs unless we run the routines that we're trying to test. The trick here is to implement exhaustive factor
 * multiplication, which is normally super intractable but easy to get right, as ground truth.
 *
 * @author keenon
 */
@RunWith(Theories.class)
public class CliqueTreeTest {
  @Theory
  public void testCalculateMarginals(@ForAll(sampleSize = 100) @From(GraphicalModelGenerator.class) GraphicalModel model,
                                     @ForAll(sampleSize = 2) @From(WeightsGenerator.class) ConcatVector weights) throws Exception {
    CliqueTree inference = new CliqueTree(model, weights);

    // This is the basic check that inference works when you first construct the model
    checkMarginalsAgainstBruteForce(model, weights, inference);

    // Now we go through several random mutations to the model, and check that everything is still consistent
    Random r = new Random();
    for (int i = 0; i < 10; i++) {
      randomlyMutateGraphicalModel(model, r);
      checkMarginalsAgainstBruteForce(model, weights, inference);
    }
  }
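  /**
   * Applies a random structural mutation to the model: with probability 1/2 (and if more than one factor exists)
   * removes a factor chosen at random; otherwise attaches a new binary factor between a previously unused variable
   * and a randomly chosen existing variable. This mirrors the kind of incremental model change (e.g. adding a human
   * observation in LENSE) that the CliqueTree must remain consistent under.
   */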
  private void randomlyMutateGraphicalModel(GraphicalModel model, Random r) {
    if (r.nextBoolean() && model.factors.size() > 1) {
      // Remove one factor at random
      model.factors.remove(model.factors.toArray(new GraphicalModel.Factor[model.factors.size()])[r.nextInt(model.factors.size())]);
    }
    else {
      // Add a simple binary factor, attaching a variable we haven't touched yet, but do observe, to an
      // existing variable. This represents the human observation operation in LENSE
      int maxVar = 0;
      int attachVar = -1;
      int attachVarSize = 0;
      for (GraphicalModel.Factor f : model.factors) {
        for (int j = 0; j < f.neigborIndices.length; j++) {
          int k = f.neigborIndices[j];
          if (k > maxVar) {
            maxVar = k;
          }
          if (r.nextDouble() > 0.3 || attachVar == -1) {
            attachVar = k;
            attachVarSize = f.featuresTable.getDimensions()[j];
          }
        }
      }

      int newVar = maxVar + 1;
      int newVarSize = 1 + r.nextInt(2);
      if (maxVar >= 8) {
        boolean[] seenVariables = new boolean[maxVar + 1];
        for (GraphicalModel.Factor f : model.factors) {
          for (int n : f.neigborIndices) seenVariables[n] = true;
        }
        for (int j = 0; j < seenVariables.length; j++) {
          if (!seenVariables[j]) {
            newVar = j;
            break;
          }
        }
        // This means the model is already too gigantic to be tractable, so we don't add anything here
        if (newVar == maxVar + 1) {
          return;
        }
      }

      if (model.getVariableMetaDataByReference(newVar).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
        int assignment = Integer.parseInt(model.getVariableMetaDataByReference(newVar).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
        if (assignment >= newVarSize) {
          newVarSize = assignment + 1;
        }
      }

      GraphicalModel.Factor binary = model.addFactor(new int[]{newVar, attachVar}, new int[]{newVarSize, attachVarSize}, (assignment) -> {
        ConcatVector v = new ConcatVector(CONCAT_VEC_COMPONENTS);
        for (int j = 0; j < v.getNumberOfComponents(); j++) {
          if (r.nextBoolean()) {
            v.setSparseComponent(j, r.nextInt(CONCAT_VEC_COMPONENT_LENGTH), r.nextDouble());
          }
          else {
            double[] d = new double[CONCAT_VEC_COMPONENT_LENGTH];
            for (int k = 0; k < d.length; k++) {
              d[k] = r.nextDouble();
            }
            v.setDenseComponent(j, d);
          }
        }
        return v;
      });

      // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
      for (int[] assignment : binary.featuresTable) {
        ConcatVector randomlyGenerated = binary.featuresTable.getAssignmentValue(assignment).get();
        binary.featuresTable.setAssignmentValue(assignment, () -> randomlyGenerated);
      }
    }
  }

  private void checkMarginalsAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference) {
    CliqueTree.MarginalResult result = inference.calculateMarginals();
    double[][] marginals = result.marginals;

    Set<TableFactor> tableFactors = model.factors.stream().map(factor -> new TableFactor(weights, factor)).collect(Collectors.toSet());
    assert (tableFactors.size() == model.factors.size());

    // this is the super slow but obviously correct way to get global marginals
    TableFactor bruteForce = null;
    for (TableFactor factor : tableFactors) {
      if (bruteForce == null) bruteForce = factor;
      else bruteForce = bruteForce.multiply(factor);
    }
    if (bruteForce != null) {

      // observe out all variables that have been registered
      TableFactor observed = bruteForce;
      for (int i = 0; i < bruteForce.neighborIndices.length; i++) {
        int n = bruteForce.neighborIndices[i];
        if (model.getVariableMetaDataByReference(n).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
          int value = Integer.parseInt(model.getVariableMetaDataByReference(n).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
          // Check that the marginals reflect the observation
          for (int j = 0; j < marginals[n].length; j++) {
            assertEquals(j == value ? 1.0 : 0.0, marginals[n][j], 1.0e-9);
          }

          if (observed.neighborIndices.length > 1) {
            observed = observed.observe(n, value);
          }
          // If we've observed everything, then just quit
          else return;
        }
      }
      bruteForce = observed;

      // Spot check each of the marginals in the brute force calculation
      double[][] bruteMarginals = bruteForce.getSummedMarginals();
      int index = 0;
      for (int i : bruteForce.neighborIndices) {
        boolean isEqual = true;
        double[] brute = bruteMarginals[index];
        index++;
        assert (brute != null);
        assert (marginals[i] != null);
        for (int j = 0; j < brute.length; j++) {
          if (Double.isNaN(brute[j])) {
            isEqual = false;
            break;
          }
          if (Math.abs(brute[j] - marginals[i][j]) > 3.0e-2) {
            isEqual = false;
            break;
          }
        }
        if (!isEqual) {
          System.err.println("Arrays not equal! Variable " + i);
          System.err.println("\tGold: " + Arrays.toString(brute));
          System.err.println("\tResult: " + Arrays.toString(marginals[i]));
        }
        assertArrayEquals(brute, marginals[i], 3.0e-2);
      }

      // Spot check the partition function
      double goldPartitionFunction = bruteForce.valueSum();
      // Correct to within 3%
      assertEquals(goldPartitionFunction, result.partitionFunction, goldPartitionFunction * 3.0e-2);

      // Check the joint marginals
      marginals:
      for (GraphicalModel.Factor f : model.factors) {
        assertTrue(result.jointMarginals.containsKey(f));

        TableFactor bruteForceJointMarginal = bruteForce;

        outer:
        for (int n : bruteForce.neighborIndices) {
          for (int i : f.neigborIndices) if (i == n) {
            continue outer;
          }
          if (bruteForceJointMarginal.neighborIndices.length > 1) {
            bruteForceJointMarginal = bruteForceJointMarginal.sumOut(n);
          }
          else {
            int[] fixedAssignment = new int[f.neigborIndices.length];
            for (int i = 0; i < fixedAssignment.length; i++) {
              fixedAssignment[i] = Integer.parseInt(model.getVariableMetaDataByReference(f.neigborIndices[i]).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
            }
            for (int[] assn : result.jointMarginals.get(f)) {
              if (Arrays.equals(assn, fixedAssignment)) {
                assertEquals(1.0, result.jointMarginals.get(f).getAssignmentValue(assn), 1.0e-7);
              }
              else {
                if (result.jointMarginals.get(f).getAssignmentValue(assn) != 0) {
                  TableFactor j = result.jointMarginals.get(f);
                  for (int[] assignment : j) {
                    System.err.println(Arrays.toString(assignment) + ": " + j.getAssignmentValue(assignment));
                  }
                }
                assertEquals(0.0, result.jointMarginals.get(f).getAssignmentValue(assn), 1.0e-7);
              }
            }

            continue marginals;
          }
        }

        // Find the correspondence between the brute force joint marginal, which may be missing variables
        // because they were observed out of the table, and the output joint marginals, which are always an exact
        // match for the original factor
        int[] backPointers = new int[f.neigborIndices.length];
        int[] observedValue = new int[f.neigborIndices.length];
        for (int i = 0; i < backPointers.length; i++) {
          if (model.getVariableMetaDataByReference(f.neigborIndices[i]).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
            observedValue[i] = Integer.parseInt(model.getVariableMetaDataByReference(f.neigborIndices[i]).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
            backPointers[i] = -1;
          }
          else {
            observedValue[i] = -1;
            backPointers[i] = -1;
            for (int j = 0; j < bruteForceJointMarginal.neighborIndices.length; j++) {
              if (bruteForceJointMarginal.neighborIndices[j] == f.neigborIndices[i]) {
                backPointers[i] = j;
              }
            }
            assert (backPointers[i] != -1);
          }
        }

        double sum = bruteForceJointMarginal.valueSum();
        if (sum == 0.0) sum = 1;

        outer:
        for (int[] assignment : result.jointMarginals.get(f)) {
          int[] bruteForceMarginalAssignment = new int[bruteForceJointMarginal.neighborIndices.length];
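          // Translate this assignment over the factor's variables into the brute-force joint marginal's coordinate
          // system, skipping any variables that were observed out of the brute-force table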
          for (int i = 0; i < assignment.length; i++) {
            if (backPointers[i] != -1) {
              bruteForceMarginalAssignment[backPointers[i]] = assignment[i];
            }
            // Make sure all assignments that don't square with observations get 0 weight
            else {
              assert (observedValue[i] != -1);
              if (assignment[i] != observedValue[i]) {
                if (result.jointMarginals.get(f).getAssignmentValue(assignment) != 0) {
                  System.err.println("Joint marginals: " + Arrays.toString(result.jointMarginals.get(f).neighborIndices));
                  System.err.println("Assignment: " + Arrays.toString(assignment));
                  System.err.println("Observed Value: " + Arrays.toString(observedValue));
                  for (int[] assn : result.jointMarginals.get(f)) {
                    System.err.println("\t" + Arrays.toString(assn) + ":" + result.jointMarginals.get(f).getAssignmentValue(assn));
                  }
                }
                assertEquals(0.0, result.jointMarginals.get(f).getAssignmentValue(assignment), 1.0e-7);
                continue outer;
              }
            }
          }

          assertEquals(bruteForceJointMarginal.getAssignmentValue(bruteForceMarginalAssignment) / sum, result.jointMarginals.get(f).getAssignmentValue(assignment), 1.0e-3);
        }
      }
    }
    else {
      for (double[] marginal : marginals) {
        for (double d : marginal) {
          assertEquals(1.0 / marginal.length, d, 3.0e-2);
        }
      }
    }
  }

  @Theory
  public void testCalculateMap(@ForAll(sampleSize = 100) @From(GraphicalModelGenerator.class) GraphicalModel model,
                               @ForAll(sampleSize = 2) @From(WeightsGenerator.class) ConcatVector weights) throws Exception {
    if (model.factors.size() == 0) return;

    CliqueTree inference = new CliqueTree(model, weights);

    // This is the basic check that inference works when you first construct the model
    checkMAPAgainstBruteForce(model, weights, inference);

    // Now we go through several random mutations to the model, and check that everything is still consistent
    Random r = new Random();
    for (int i = 0; i < 10; i++) {
      randomlyMutateGraphicalModel(model, r);
      checkMAPAgainstBruteForce(model, weights, inference);
    }
  }

  public void checkMAPAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference) {
    int[] map = inference.calculateMAP();

    Set<TableFactor> tableFactors = model.factors.stream().map(factor -> new TableFactor(weights, factor)).collect(Collectors.toSet());

    // this is the super slow but obviously correct way to get global marginals
    TableFactor bruteForce = null;
    for (TableFactor factor : tableFactors) {
      if (bruteForce == null) bruteForce = factor;
      else bruteForce = bruteForce.multiply(factor);
    }
    assert (bruteForce != null);

    // observe out all variables that have been registered
    TableFactor observed = bruteForce;
    for (int n : bruteForce.neighborIndices) {
      if (model.getVariableMetaDataByReference(n).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
        int value = Integer.parseInt(model.getVariableMetaDataByReference(n).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
        if (observed.neighborIndices.length > 1) {
          observed = observed.observe(n, value);
        }
        // If we've observed everything, then just quit
        else return;
      }
    }
    bruteForce = observed;

    int largestVariableNum = 0;
    for (GraphicalModel.Factor f : model.factors) {
      for (int i : f.neigborIndices) if (i > largestVariableNum) largestVariableNum = i;
    }

    // this is presented in true order, where 0 corresponds to var 0
    int[] mapValueAssignment = new int[largestVariableNum + 1];
    // this is kept in the order that the factor presents to us
    int[] highestValueAssignment = new int[bruteForce.neighborIndices.length];
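    // Exhaustively scan every assignment of the brute-force product factor for the highest-scoring one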
    for (int[] assignment : bruteForce) {
      if (bruteForce.getAssignmentValue(assignment) > bruteForce.getAssignmentValue(highestValueAssignment)) {
        highestValueAssignment = assignment;
        for (int i = 0; i < assignment.length; i++) {
          mapValueAssignment[bruteForce.neighborIndices[i]] = assignment[i];
        }
      }
    }

    int[] forcedAssignments = new int[largestVariableNum + 1];
    for (int i = 0; i < mapValueAssignment.length; i++) {
      if (model.getVariableMetaDataByReference(i).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
        mapValueAssignment[i] = Integer.parseInt(model.getVariableMetaDataByReference(i).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
        forcedAssignments[i] = mapValueAssignment[i];
      }
    }

    if (!Arrays.equals(mapValueAssignment, map)) {
      System.err.println("---");
      System.err.println("Relevant variables: " + Arrays.toString(bruteForce.neighborIndices));
      System.err.println("Var Sizes: " + Arrays.toString(bruteForce.getDimensions()));
      System.err.println("MAP: " + Arrays.toString(map));
      System.err.println("Brute force map: " + Arrays.toString(mapValueAssignment));
      System.err.println("Forced assignments: " + Arrays.toString(forcedAssignments));
    }

    for (int i : bruteForce.neighborIndices) {
      // Only check defined variables
      assertEquals(mapValueAssignment[i], map[i]);
    }
  }

  /////////////////////////////////////////////////////////////////////////////
  //
  // A copy of these generators exists in GradientSourceTest in the learning module. If any bug fixes are made here,
  // remember to update that code as well by copy-paste.
  //
  /////////////////////////////////////////////////////////////////////////////

  public static final int CONCAT_VEC_COMPONENTS = 2;
  public static final int CONCAT_VEC_COMPONENT_LENGTH = 3;

  public static class WeightsGenerator extends Generator<ConcatVector> {
    public WeightsGenerator(Class<ConcatVector> type) {
      super(type);
    }

    @Override
    public ConcatVector generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
      ConcatVector v = new ConcatVector(CONCAT_VEC_COMPONENTS);
      for (int x = 0; x < CONCAT_VEC_COMPONENTS; x++) {
        if (sourceOfRandomness.nextBoolean()) {
          v.setSparseComponent(x, sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH), sourceOfRandomness.nextDouble());
        }
        else {
          double[] val = new double[sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
          for (int y = 0; y < val.length; y++) {
            val[y] = sourceOfRandomness.nextDouble();
          }
          v.setDenseComponent(x, val);
        }
      }
      return v;
    }
  }

  public static class GraphicalModelGenerator extends Generator<GraphicalModel> {
    public GraphicalModelGenerator(Class<GraphicalModel> type) {
      super(type);
    }

    private Map<String, String> generateMetaData(SourceOfRandomness sourceOfRandomness, Map<String, String> metaData) {
      int numPairs = sourceOfRandomness.nextInt(9);
      for (int i = 0; i < numPairs; i++) {
        int key = sourceOfRandomness.nextInt();
        int value = sourceOfRandomness.nextInt();
        metaData.put("key:" + key, "value:" + value);
      }
      return metaData;
    }

    @Override
    public GraphicalModel generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
      GraphicalModel model = new GraphicalModel();
      // Create the variables and factors. These are deliberately tiny so that the brute force approach is tractable
      int[] variableSizes = new int[8];
      for (int i = 0; i < variableSizes.length; i++) {
        variableSizes[i] = sourceOfRandomness.nextInt(1, 3);
      }

      // Traverse in a randomized BFS to ensure the generated graph is a tree
      if (sourceOfRandomness.nextBoolean()) {
        generateCliques(variableSizes, new ArrayList<>(), new HashSet<>(), model, sourceOfRandomness);
      }
      // Or generate a linear chain CRF, because our random BFS doesn't generate these very often, and they're very
      // common in practice, so worth testing densely
      else {
        for (int i = 0; i < variableSizes.length; i++) {

          // Add unary factor
          GraphicalModel.Factor unary = model.addFactor(new int[]{i}, new int[]{variableSizes[i]}, (assignment) -> {
            ConcatVector features = new ConcatVector(CONCAT_VEC_COMPONENTS);
            for (int j = 0; j < CONCAT_VEC_COMPONENTS; j++) {
              if (sourceOfRandomness.nextBoolean()) {
                features.setSparseComponent(j, sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH), sourceOfRandomness.nextDouble());
              }
              else {
                double[] dense = new double[sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
                for (int k = 0; k < dense.length; k++) {
                  dense[k] = sourceOfRandomness.nextDouble();
                }
                features.setDenseComponent(j, dense);
              }
            }
            return features;
          });

          // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
          for (int[] assignment : unary.featuresTable) {
            ConcatVector randomlyGenerated = unary.featuresTable.getAssignmentValue(assignment).get();
            unary.featuresTable.setAssignmentValue(assignment, () -> randomlyGenerated);
          }

          // Add binary factor
          if (i < variableSizes.length - 1) {
            GraphicalModel.Factor binary = model.addFactor(new int[]{i, i + 1}, new int[]{variableSizes[i], variableSizes[i + 1]}, (assignment) -> {
              ConcatVector features = new ConcatVector(CONCAT_VEC_COMPONENTS);
              for (int j = 0; j < CONCAT_VEC_COMPONENTS; j++) {
                if (sourceOfRandomness.nextBoolean()) {
                  features.setSparseComponent(j, sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH), sourceOfRandomness.nextDouble());
                }
                else {
                  double[] dense = new double[sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
                  for (int k = 0; k < dense.length; k++) {
                    dense[k] = sourceOfRandomness.nextDouble();
                  }
                  features.setDenseComponent(j, dense);
                }
              }
              return features;
            });

            // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
            for (int[] assignment : binary.featuresTable) {
              ConcatVector randomlyGenerated = binary.featuresTable.getAssignmentValue(assignment).get();
              binary.featuresTable.setAssignmentValue(assignment, () -> randomlyGenerated);
            }
          }
        }
      }

      // Add metadata to the variables, factors, and model
      generateMetaData(sourceOfRandomness, model.getModelMetaDataByReference());
      for (int i = 0; i < 20; i++) {
        generateMetaData(sourceOfRandomness, model.getVariableMetaDataByReference(i));
      }
      for (GraphicalModel.Factor factor : model.factors) {
        generateMetaData(sourceOfRandomness, factor.getMetaDataByReference());
      }

      // Observe a few of the variables
      for (GraphicalModel.Factor f : model.factors) {
        for (int i = 0; i < f.neigborIndices.length; i++) {
          if (sourceOfRandomness.nextDouble() > 0.8) {
            int obs = sourceOfRandomness.nextInt(f.featuresTable.getDimensions()[i]);
            model.getVariableMetaDataByReference(f.neigborIndices[i]).put(CliqueTree.VARIABLE_OBSERVED_VALUE, "" + obs);
          }
        }
      }

      return model;
    }
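    /**
     * Recursively generates a tree of cliques. Each call creates one factor over {@code startSet} (the variables
     * shared with the parent clique) plus a few newly introduced variables, then hands disjoint subsets of the new
     * variables to recursive child calls. Because children only ever share variables introduced by their parent,
     * and those shared sets are disjoint, the resulting factor graph is guaranteed to be tree-structured.
     */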
    private void generateCliques(int[] variableSizes,
                                 List<Integer> startSet,
                                 Set<Integer> alreadyRepresented,
                                 GraphicalModel model,
                                 SourceOfRandomness randomness) {
      if (alreadyRepresented.size() == variableSizes.length) return;

      // Generate the clique variable set
      List<Integer> cliqueContents = new ArrayList<>();
      cliqueContents.addAll(startSet);
      alreadyRepresented.addAll(startSet);
      while (true) {
        if (alreadyRepresented.size() == variableSizes.length) break;
        if (cliqueContents.size() == 0 || randomness.nextDouble(0, 1) < 0.7) {
          int gen;
          do {
            gen = randomness.nextInt(variableSizes.length);
          } while (alreadyRepresented.contains(gen));
          alreadyRepresented.add(gen);
          cliqueContents.add(gen);
        }
        else break;
      }

      // Create the actual table
      int[] neighbors = new int[cliqueContents.size()];
      int[] neighborSizes = new int[neighbors.length];
      for (int j = 0; j < neighbors.length; j++) {
        neighbors[j] = cliqueContents.get(j);
        neighborSizes[j] = variableSizes[neighbors[j]];
      }
      ConcatVectorTable table = new ConcatVectorTable(neighborSizes);
      for (int[] assignment : table) {
        // Generate a vector
        ConcatVector v = new ConcatVector(CONCAT_VEC_COMPONENTS);
        for (int x = 0; x < CONCAT_VEC_COMPONENTS; x++) {
          if (randomness.nextBoolean()) {
            v.setSparseComponent(x, randomness.nextInt(32), randomness.nextDouble());
          }
          else {
            double[] val = new double[randomness.nextInt(12)];
            for (int y = 0; y < val.length; y++) {
              val[y] = randomness.nextDouble();
            }
            v.setDenseComponent(x, val);
          }
        }
        // set vec in table
        table.setAssignmentValue(assignment, () -> v);
      }
      model.addFactor(table, neighbors);

      // Pick the number of children
      List<Integer> availableVariables = new ArrayList<>();
      availableVariables.addAll(cliqueContents);
      availableVariables.removeAll(startSet);

      int numChildren = randomness.nextInt(0, availableVariables.size());
      if (numChildren == 0) return;

      List<List<Integer>> children = new ArrayList<>();
      for (int i = 0; i < numChildren; i++) {
        children.add(new ArrayList<>());
      }

      // divide up the shared variables across the children
      int cursor = 0;
      while (true) {
        if (availableVariables.size() == 0) break;

        if (children.get(cursor).size() == 0 || randomness.nextBoolean()) {
          int gen = randomness.nextInt(availableVariables.size());
          children.get(cursor).add(availableVariables.get(gen));
          availableVariables.remove(availableVariables.get(gen));
        }
        else break;

        cursor = (cursor + 1) % numChildren;
      }

      for (List<Integer> shared1 : children) {
        for (int i : shared1) {
          for (List<Integer> shared2 : children) {
            assert (shared1 == shared2 || !shared2.contains(i));
          }
        }
      }

      for (List<Integer> shared : children) {
        if (shared.size() > 0) generateCliques(variableSizes, shared, alreadyRepresented, model, randomness);
      }
    }
  }
}