package edu.stanford.nlp.loglinear.learning;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.inference.TableFactor;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorTable;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

import java.util.*;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

/**
 * Created on 8/24/15.
 * <p>
 * Uses the definition of the derivative to verify that the calculated gradients are approximately correct.
 *
 * @author keenon
 */
@RunWith(Theories.class)
public class LogLikelihoodFunctionTest {
    @Theory
    public void testGetSummaryForInstance(@ForAll(sampleSize = 50) @From(GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
                                          @ForAll(sampleSize = 2) @From(WeightsGenerator.class) ConcatVector weights) throws Exception {
        LogLikelihoodDifferentiableFunction fn = new LogLikelihoodDifferentiableFunction();
        for (GraphicalModel model : dataset) {
            double goldLogLikelihood = logLikelihood(model, weights);
            ConcatVector goldGradient = definitionOfDerivative(model, weights);

            ConcatVector gradient = new ConcatVector(0);
            double logLikelihood = fn.getSummaryForInstance(model, weights, gradient);

            // Tolerance: 1.0e-3 absolute or 1% relative, whichever is larger. Math.abs() is needed here
            // because log-likelihoods are generally negative, and a negative relative term would make the
            // max() always pick the absolute floor.
            assertEquals(goldLogLikelihood, logLikelihood, Math.max(1.0e-3, Math.abs(goldLogLikelihood) * 1.0e-2));

            // Our check for gradient similarity uses the distance between the endpoints of the two vectors,
            // rather than elementwise similarity, because it can be controlled as a percentage
            ConcatVector difference = goldGradient.deepClone();
            difference.addVectorInPlace(gradient, -1);
            double distance = Math.sqrt(difference.dotProduct(difference));

            // The tolerance here is pretty large, since the gold gradient is only computed approximately,
            // but a 5% bound still tells us whether everything is working or not
            if (distance > 5.0e-2) {
                System.err.println("Definitional and calculated gradient differ!");
                System.err.println("Definition approx: " + goldGradient);
                System.err.println("Calculated: " + gradient);
            }
            assertEquals(0.0, distance, 5.0e-2);
        }
    }
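    // A minimal sketch of the endpoint-distance check above, factored into a helper for readability. It
    // assumes only the ConcatVector operations this test already uses (deepClone, addVectorInPlace,
    // dotProduct); the helper is illustrative and is not referenced elsewhere in the library.
    @SuppressWarnings("unused")
    private static double euclideanDistance(ConcatVector a, ConcatVector b) {
        ConcatVector difference = a.deepClone(); // copy so we don't mutate the caller's vector
        difference.addVectorInPlace(b, -1);      // difference = a - b
        return Math.sqrt(difference.dotProduct(difference));
    }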
    /**
     * The slowest, but obviously correct, way to get the log-likelihood. We've already tested the partition
     * function in CliqueTreeTest, but in the interest of making things as different as possible to catch any
     * lurking bugs or numerical issues, we use the brute force approach here.
     *
     * @param model   the model to get the log-likelihood of; assumes labels for assignments
     * @param weights the weights to get the log-likelihood at
     * @return the log-likelihood
     */
    private double logLikelihood(GraphicalModel model, ConcatVector weights) {
        Set<TableFactor> tableFactors = model.factors.stream()
                .map(factor -> new TableFactor(weights, factor))
                .collect(Collectors.toSet());
        assert (tableFactors.size() == model.factors.size());

        // This is the super slow but obviously correct way to get global marginals
        TableFactor bruteForce = null;
        for (TableFactor factor : tableFactors) {
            if (bruteForce == null) {
                bruteForce = factor;
            } else {
                bruteForce = bruteForce.multiply(factor);
            }
        }
        assert (bruteForce != null);

        // Observe out all variables that have been registered
        TableFactor observed = bruteForce;
        for (int n : bruteForce.neighborIndices) {
            if (model.getVariableMetaDataByReference(n).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
                int value = Integer.parseInt(model.getVariableMetaDataByReference(n).get(CliqueTree.VARIABLE_OBSERVED_VALUE));
                if (observed.neighborIndices.length > 1) {
                    observed = observed.observe(n, value);
                }
                // If we've observed everything, then just quit
                else {
                    return 0.0;
                }
            }
        }
        bruteForce = observed;

        // Now we can get a partition function
        double partitionFunction = bruteForce.valueSum();

        // For now, we'll assume that all the variables are given for training. EM is another problem altogether
        int[] assignment = new int[bruteForce.neighborIndices.length];
        for (int i = 0; i < assignment.length; i++) {
            assert (!model.getVariableMetaDataByReference(bruteForce.neighborIndices[i]).containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE));
            assignment[i] = Integer.parseInt(model.getVariableMetaDataByReference(bruteForce.neighborIndices[i]).get(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE));
        }

        if (bruteForce.getAssignmentValue(assignment) == 0 || partitionFunction == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        return Math.log(bruteForce.getAssignmentValue(assignment)) - Math.log(partitionFunction);
    }
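    // Written out, the quantity logLikelihood() returns is the standard log-linear objective for a single,
    // fully supervised instance (a sketch, using y for the training assignment and Z for the brute-force
    // partition function):
    //
    //   log p(y) = log( prod_c phi_c(y_c) ) - log Z
    //            = Math.log(bruteForce.getAssignmentValue(assignment)) - Math.log(partitionFunction)
    //
    // where each phi_c is a TableFactor built from the model's feature vectors and the current weights.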
    /**
     * The slowest possible way to calculate a derivative for a model: exhaustive definitional calculation,
     * using the super slow logLikelihood function from this test suite.
     *
     * @param model   the model to get the derivative for
     * @param weights the weights to get the derivative at
     * @return the derivative of the log-likelihood with respect to the weights
     */
    private ConcatVector definitionOfDerivative(GraphicalModel model, ConcatVector weights) {
        double epsilon = 1.0e-7;

        ConcatVector goldGradient = new ConcatVector(CONCAT_VEC_COMPONENTS);
        for (int i = 0; i < CONCAT_VEC_COMPONENTS; i++) {
            double[] component = new double[CONCAT_VEC_COMPONENT_LENGTH];
            for (int j = 0; j < CONCAT_VEC_COMPONENT_LENGTH; j++) {
                // Create a unit vector pointing in the direction of this element of this component
                ConcatVector unitVectorIJ = new ConcatVector(CONCAT_VEC_COMPONENTS);
                unitVectorIJ.setSparseComponent(i, j, 1.0);

                // Create a +eps weight vector
                ConcatVector weightsPlusEpsilon = weights.deepClone();
                weightsPlusEpsilon.addVectorInPlace(unitVectorIJ, epsilon);

                // Create a -eps weight vector
                ConcatVector weightsMinusEpsilon = weights.deepClone();
                weightsMinusEpsilon.addVectorInPlace(unitVectorIJ, -epsilon);

                // Use the definition (f(x+eps) - f(x-eps)) / (2*eps)
                component[j] = (logLikelihood(model, weightsPlusEpsilon) - logLikelihood(model, weightsMinusEpsilon)) / (2 * epsilon);

                // If we encounter an impossible assignment, logLikelihood will return negative infinity,
                // which will screw with the definitional calculation
                if (Double.isNaN(component[j])) component[j] = 0.0;
            }
            goldGradient.setDenseComponent(i, component);
        }
        return goldGradient;
    }
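    // The central-difference rule definitionOfDerivative() applies to each weight element, in equation form:
    //
    //   df/dw_ij ≈ ( f(w + eps * e_ij) - f(w - eps * e_ij) ) / (2 * eps)
    //
    // where e_ij is the unit vector for element j of component i, and eps = 1.0e-7. Central differences
    // carry O(eps^2) truncation error, much finer than the 5.0e-2 endpoint-distance tolerance the test
    // above allows, so a failure there points at the analytic gradient rather than this approximation.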
    public static class GraphicalModelDatasetGenerator extends Generator<GraphicalModel[]> {
        GraphicalModelGenerator modelGenerator = new GraphicalModelGenerator(GraphicalModel.class);

        public GraphicalModelDatasetGenerator(Class<GraphicalModel[]> type) {
            super(type);
        }

        @Override
        public GraphicalModel[] generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            GraphicalModel[] dataset = new GraphicalModel[sourceOfRandomness.nextInt(1, 10)];
            for (int i = 0; i < dataset.length; i++) {
                dataset[i] = modelGenerator.generate(sourceOfRandomness, generationStatus);
                // Attach a random (but in-range) gold training label to every variable in the model
                for (GraphicalModel.Factor f : dataset[i].factors) {
                    for (int j = 0; j < f.neigborIndices.length; j++) {
                        int n = f.neigborIndices[j];
                        int dim = f.featuresTable.getDimensions()[j];
                        dataset[i].getVariableMetaDataByReference(n).put(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE, "" + sourceOfRandomness.nextInt(dim));
                    }
                }
            }
            return dataset;
        }
    }

    /////////////////////////////////////////////////////////////////////////////
    //
    // These generators were COPIED DIRECTLY FROM CliqueTreeTest in the inference module.
    //
    /////////////////////////////////////////////////////////////////////////////

    public static final int CONCAT_VEC_COMPONENTS = 2;
    public static final int CONCAT_VEC_COMPONENT_LENGTH = 3;

    public static class WeightsGenerator extends Generator<ConcatVector> {
        public WeightsGenerator(Class<ConcatVector> type) {
            super(type);
        }

        @Override
        public ConcatVector generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            ConcatVector v = new ConcatVector(CONCAT_VEC_COMPONENTS);
            for (int x = 0; x < CONCAT_VEC_COMPONENTS; x++) {
                if (sourceOfRandomness.nextBoolean()) {
                    v.setSparseComponent(x, sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH), sourceOfRandomness.nextDouble());
                } else {
                    double[] val = new double[sourceOfRandomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
                    for (int y = 0; y < val.length; y++) {
                        val[y] = sourceOfRandomness.nextDouble();
                    }
                    v.setDenseComponent(x, val);
                }
            }
            return v;
        }
    }

    public static class GraphicalModelGenerator extends Generator<GraphicalModel> {
        public GraphicalModelGenerator(Class<GraphicalModel> type) {
            super(type);
        }

        private Map<String, String> generateMetaData(SourceOfRandomness sourceOfRandomness, Map<String, String> metaData) {
            int numPairs = sourceOfRandomness.nextInt(9);
            for (int i = 0; i < numPairs; i++) {
                int key = sourceOfRandomness.nextInt();
                int value = sourceOfRandomness.nextInt();
                metaData.put("key:" + key, "value:" + value);
            }
            return metaData;
        }

        @Override
        public GraphicalModel generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            GraphicalModel model = new GraphicalModel();

            // Create the variables and factors. These are deliberately tiny so that the brute force approach is tractable
            int[] variableSizes = new int[8];
            for (int i = 0; i < variableSizes.length; i++) {
                variableSizes[i] = sourceOfRandomness.nextInt(1, 3);
            }
            // Traverse in a randomized BFS to ensure the generated graph is a tree
            generateCliques(variableSizes, new ArrayList<>(), new HashSet<>(), model, sourceOfRandomness);

            // Add metadata to the variables, factors, and model
            generateMetaData(sourceOfRandomness, model.getModelMetaDataByReference());
            for (int i = 0; i < 20; i++) {
                generateMetaData(sourceOfRandomness, model.getVariableMetaDataByReference(i));
            }
            for (GraphicalModel.Factor factor : model.factors) {
                generateMetaData(sourceOfRandomness, factor.getMetaDataByReference());
            }

            // Observe a few of the variables
            for (GraphicalModel.Factor f : model.factors) {
                for (int i = 0; i < f.neigborIndices.length; i++) {
                    if (sourceOfRandomness.nextDouble() > 0.8) {
                        int obs = sourceOfRandomness.nextInt(f.featuresTable.getDimensions()[i]);
                        model.getVariableMetaDataByReference(f.neigborIndices[i]).put(CliqueTree.VARIABLE_OBSERVED_VALUE, "" + obs);
                    }
                }
            }

            return model;
        }
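        // How generateCliques() (below) keeps the model a tree, as far as this generator is concerned: each
        // call emits exactly one factor over a fresh set of variables, then partitions the variables it just
        // introduced among its recursive children, asserting that no variable is handed to two children.
        // Since CliqueTree runs exact inference on tree-structured models, this keeps the gold values the
        // test compares against exact. (This is a reading of the code below, not a documented contract.)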
        private void generateCliques(int[] variableSizes,
                                     List<Integer> startSet,
                                     Set<Integer> alreadyRepresented,
                                     GraphicalModel model,
                                     SourceOfRandomness randomness) {
            if (alreadyRepresented.size() == variableSizes.length) return;

            // Generate the clique variable set
            List<Integer> cliqueContents = new ArrayList<>();
            cliqueContents.addAll(startSet);
            while (true) {
                if (alreadyRepresented.size() == variableSizes.length) break;
                if (cliqueContents.size() == 0 || randomness.nextDouble(0, 1) < 0.7) {
                    int gen;
                    do {
                        gen = randomness.nextInt(variableSizes.length);
                    } while (alreadyRepresented.contains(gen));
                    alreadyRepresented.add(gen);
                    cliqueContents.add(gen);
                } else {
                    break;
                }
            }

            // Create the actual table
            int[] neighbors = new int[cliqueContents.size()];
            int[] neighborSizes = new int[neighbors.length];
            for (int j = 0; j < neighbors.length; j++) {
                neighbors[j] = cliqueContents.get(j);
                neighborSizes[j] = variableSizes[neighbors[j]];
            }
            ConcatVectorTable table = new ConcatVectorTable(neighborSizes);
            for (int[] assignment : table) {
                // Generate a random feature vector for this assignment, mixing sparse and dense components
                ConcatVector v = new ConcatVector(CONCAT_VEC_COMPONENTS);
                for (int x = 0; x < CONCAT_VEC_COMPONENTS; x++) {
                    if (randomness.nextBoolean()) {
                        v.setSparseComponent(x, randomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH), randomness.nextDouble());
                    } else {
                        double[] val = new double[randomness.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
                        for (int y = 0; y < val.length; y++) {
                            val[y] = randomness.nextDouble();
                        }
                        v.setDenseComponent(x, val);
                    }
                }
                // Set the vector in the table
                table.setAssignmentValue(assignment, () -> v);
            }
            model.addFactor(table, neighbors);

            // Pick the number of children
            List<Integer> availableVariables = new ArrayList<>();
            availableVariables.addAll(cliqueContents);
            availableVariables.removeAll(startSet);

            int numChildren = randomness.nextInt(0, availableVariables.size());
            if (numChildren == 0) return;

            List<List<Integer>> children = new ArrayList<>();
            for (int i = 0; i < numChildren; i++) {
                children.add(new ArrayList<>());
            }

            // Divide up the shared variables across the children
            int cursor = 0;
            while (true) {
                if (availableVariables.size() == 0) break;
                if (children.get(cursor).size() == 0 || randomness.nextBoolean()) {
                    int gen = randomness.nextInt(availableVariables.size());
                    children.get(cursor).add(availableVariables.get(gen));
                    availableVariables.remove(availableVariables.get(gen));
                } else {
                    break;
                }
                cursor = (cursor + 1) % numChildren;
            }

            // Sanity check: no variable may be handed to more than one child
            for (List<Integer> shared1 : children) {
                for (int i : shared1) {
                    for (List<Integer> shared2 : children) {
                        assert (shared1 == shared2 || !shared2.contains(i));
                    }
                }
            }

            for (List<Integer> shared : children) {
                if (shared.size() > 0) {
                    generateCliques(variableSizes, shared, alreadyRepresented, model, randomness);
                }
            }
        }
    }
}