package org.deeplearning4j.regressiontest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.*;
import org.junit.Test;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Created by Alex on 08/05/2017.
*/
public class TestDistributionDeserializer {
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:
Distribution[] distributions = new Distribution[]{
new NormalDistribution(3,0.5),
new UniformDistribution(-2,1),
new GaussianDistribution(2,1.0),
new BinomialDistribution(10,0.3)
};
ObjectMapper om = NeuralNetConfiguration.mapper();
for( Distribution d : distributions){
String json = om.writeValueAsString(d);
Distribution fromJson = om.readValue(json, Distribution.class);
assertEquals(d, fromJson);
}
}
@Test
public void testDistributionDeserializerLegacyFormat() throws Exception {
ObjectMapper om = NeuralNetConfiguration.mapper();
String normalJson = "{\n" +
" \"normal\" : {\n" +
" \"mean\" : 0.1,\n" +
" \"std\" : 1.2\n" +
" }\n" +
" }";
Distribution nd = om.readValue(normalJson, Distribution.class);
assertTrue(nd instanceof NormalDistribution);
NormalDistribution normDist = (NormalDistribution)nd;
assertEquals(0.1, normDist.getMean(), 1e-6);
assertEquals(1.2, normDist.getStd(), 1e-6);
String uniformJson = "{\n" +
" \"uniform\" : {\n" +
" \"lower\" : -1.1,\n" +
" \"upper\" : 2.2\n" +
" }\n" +
" }";
Distribution ud = om.readValue(uniformJson, Distribution.class);
assertTrue(ud instanceof UniformDistribution);
UniformDistribution uniDist = (UniformDistribution) ud;
assertEquals(-1.1, uniDist.getLower(), 1e-6);
assertEquals(2.2, uniDist.getUpper(), 1e-6);
String gaussianJson = "{\n" +
" \"gaussian\" : {\n" +
" \"mean\" : 0.1,\n" +
" \"std\" : 1.2\n" +
" }\n" +
" }";
Distribution gd = om.readValue(gaussianJson, Distribution.class);
assertTrue(gd instanceof GaussianDistribution);
GaussianDistribution gDist = (GaussianDistribution)gd;
assertEquals(0.1, gDist.getMean(), 1e-6);
assertEquals(1.2, gDist.getStd(), 1e-6);
String bernoulliJson = "{\n" +
" \"binomial\" : {\n" +
" \"numberOfTrials\" : 10,\n" +
" \"probabilityOfSuccess\" : 0.3\n" +
" }\n" +
" }";
Distribution bd = om.readValue(bernoulliJson, Distribution.class);
assertTrue(bd instanceof BinomialDistribution);
BinomialDistribution binDist = (BinomialDistribution)bd;
assertEquals(10, binDist.getNumberOfTrials());
assertEquals(0.3, binDist.getProbabilityOfSuccess(), 1e-6);
}
}