package edu.stanford.nlp.loglinear.storage;
import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.loglinear.model.GraphicalModelTest;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Created on 10/17/15.
* @author keenon
* <p>
* This just double checks that we can write and read these model batches without loss.
*/
@RunWith(Theories.class)
public class ModelBatchTest {
@Theory
public void testProtoBatch(@ForAll(sampleSize = 50) @From(BatchGenerator.class) ModelBatch batch) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
batch.writeToStream(byteArrayOutputStream);
byteArrayOutputStream.close();
byte[] bytes = byteArrayOutputStream.toByteArray();
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
ModelBatch recovered = new ModelBatch(byteArrayInputStream);
byteArrayInputStream.close();
assertEquals(batch.size(), recovered.size());
for (int i = 0; i < batch.size(); i++) {
assertTrue(batch.get(i).valueEquals(recovered.get(i), 1.0e-5));
}
}
@Theory
public void testProtoBatchModifier(@ForAll(sampleSize = 50) @From(BatchGenerator.class) ModelBatch batch) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
batch.writeToStream(byteArrayOutputStream);
byteArrayOutputStream.close();
byte[] bytes = byteArrayOutputStream.toByteArray();
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
ModelBatch recovered = new ModelBatch(byteArrayInputStream, (model) -> {
model.getModelMetaDataByReference().put("testing", "true");
});
byteArrayInputStream.close();
assertEquals(batch.size(), recovered.size());
for (int i = 0; i < batch.size(); i++) {
assertEquals("true", recovered.get(i).getModelMetaDataByReference().get("testing"));
}
}
@Theory
public void testProtoBatchWithoutFactors(@ForAll(sampleSize = 50) @From(BatchGenerator.class) ModelBatch batch) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
batch.writeToStreamWithoutFactors(byteArrayOutputStream);
byteArrayOutputStream.close();
byte[] bytes = byteArrayOutputStream.toByteArray();
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
ModelBatch recovered = new ModelBatch(byteArrayInputStream);
byteArrayInputStream.close();
assertEquals(batch.size(), recovered.size());
for (int i = 0; i < batch.size(); i++) {
assertEquals(0, recovered.get(i).factors.size());
assertTrue(batch.get(i).getModelMetaDataByReference().equals(recovered.get(i).getModelMetaDataByReference()));
for (int j = 0; j < batch.get(i).getVariableSizes().length; j++) {
assertTrue(batch.get(i).getVariableMetaDataByReference(j).equals(recovered.get(i).getVariableMetaDataByReference(j)));
}
}
}
public static class BatchGenerator extends Generator<ModelBatch> {
GraphicalModelTest.GraphicalModelGenerator modelGenerator = new GraphicalModelTest.GraphicalModelGenerator(GraphicalModel.class);
public BatchGenerator(Class<ModelBatch> type) {
super(type);
}
@Override
public ModelBatch generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
int length = sourceOfRandomness.nextInt(0, 50);
ModelBatch batch = new ModelBatch();
for (int i = 0; i < length; i++) {
batch.add(modelGenerator.generate(sourceOfRandomness, generationStatus));
}
return batch;
}
}
}