package edu.stanford.nlp.loglinear.model; import com.pholser.junit.quickcheck.ForAll; import com.pholser.junit.quickcheck.From; import com.pholser.junit.quickcheck.generator.GenerationStatus; import com.pholser.junit.quickcheck.generator.Generator; import com.pholser.junit.quickcheck.random.SourceOfRandomness; import org.junit.contrib.theories.Theories; import org.junit.contrib.theories.Theory; import org.junit.runner.RunWith; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created on 8/11/15. * @author keenon * <p> * Quickchecks a couple of pieces of functionality, but mostly the serialization and deserialization (basically the only * non-trivial section). */ @RunWith(Theories.class) public class GraphicalModelTest { @Theory public void testProtoModel(@ForAll(sampleSize = 50) @From(GraphicalModelGenerator.class) GraphicalModel graphicalModel) throws IOException { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); graphicalModel.writeToStream(byteArrayOutputStream); byteArrayOutputStream.close(); byte[] bytes = byteArrayOutputStream.toByteArray(); ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes); GraphicalModel recovered = GraphicalModel.readFromStream(byteArrayInputStream); assertTrue(graphicalModel.valueEquals(recovered, 1.0e-5)); } @Theory public void testClone(@ForAll(sampleSize = 50) @From(GraphicalModelGenerator.class) GraphicalModel graphicalModel) throws IOException { GraphicalModel clone = graphicalModel.cloneModel(); assertTrue(graphicalModel.valueEquals(clone, 1.0e-5)); } @Theory public void testGetVariableSizes(@ForAll(sampleSize = 50) @From(GraphicalModelGenerator.class) GraphicalModel graphicalModel) throws IOException { int[] sizes = graphicalModel.getVariableSizes(); for (GraphicalModel.Factor f : graphicalModel.factors) { for (int i = 0; i < f.neigborIndices.length; i++) { assertEquals(f.featuresTable.getDimensions()[i], sizes[f.neigborIndices[i]]); } } } public static class GraphicalModelGenerator extends Generator<GraphicalModel> { public GraphicalModelGenerator(Class<GraphicalModel> type) { super(type); } private Map<String, String> generateMetaData(SourceOfRandomness sourceOfRandomness, Map<String, String> metaData) { int numPairs = sourceOfRandomness.nextInt(9); for (int i = 0; i < numPairs; i++) { int key = sourceOfRandomness.nextInt(); int value = sourceOfRandomness.nextInt(); metaData.put("key:" + key, "value:" + value); } return metaData; } @Override public GraphicalModel generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) { GraphicalModel model = new GraphicalModel(); // Create the variables and factors int[] variableSizes = new int[20]; for (int i = 0; i < 20; i++) { variableSizes[i] = sourceOfRandomness.nextInt(1, 5); } int numFactors = sourceOfRandomness.nextInt(12); for (int i = 0; i < numFactors; i++) { int[] neighbors = new int[sourceOfRandomness.nextInt(1, 3)]; int[] neighborSizes = new int[neighbors.length]; for (int j = 0; j < neighbors.length; j++) { neighbors[j] = sourceOfRandomness.nextInt(20); neighborSizes[j] = variableSizes[neighbors[j]]; } ConcatVectorTable table = new ConcatVectorTable(neighborSizes); for (int[] assignment : table) { int numComponents = sourceOfRandomness.nextInt(7); // Generate a vector ConcatVector v = new ConcatVector(numComponents); for (int x = 0; x < numComponents; x++) { if (sourceOfRandomness.nextBoolean()) { v.setSparseComponent(x, sourceOfRandomness.nextInt(32), sourceOfRandomness.nextDouble()); } else { double[] val = new double[sourceOfRandomness.nextInt(12)]; for (int y = 0; y < val.length; y++) { val[y] = sourceOfRandomness.nextDouble(); } v.setDenseComponent(x, val); } } // set vec in table table.setAssignmentValue(assignment, () -> v); } model.addFactor(table, neighbors); } // Add metadata to the variables, factors, and model generateMetaData(sourceOfRandomness, model.getModelMetaDataByReference()); for (int i = 0; i < 20; i++) { generateMetaData(sourceOfRandomness, model.getVariableMetaDataByReference(i)); } for (GraphicalModel.Factor factor : model.factors) { generateMetaData(sourceOfRandomness, factor.getMetaDataByReference()); } return model; } } }