package org.nd4j.linalg.activations; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.shade.jackson.databind.*; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; /** * Created by Alex on 30/12/2016. */ @RunWith(Parameterized.class) public class TestActivationJson extends BaseNd4jTest { public TestActivationJson(Nd4jBackend backend) { super(backend); } @Override public char ordering() { return 'c'; } private ObjectMapper mapper; @Before public void initMapper() { mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true); mapper.enable(SerializationFeature.INDENT_OUTPUT); } @Test public void testJson() throws Exception { IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25), new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(), new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(), new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(), new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH()}; String[][] expectedFields = new String[][] {{}, //Cube {"alpha"}, //ELU {}, //Hard sigmoid {}, //Hard TanH {}, //Identity {"alpha"}, //Leaky Relu {}, //rational tanh {}, //relu {"l", "u"}, //rrelu {}, //sigmoid {}, //Softmax {}, //Softplus {}, //Softsign {} //Tanh }; for (int i = 0; i < activations.length; i++) { String asJson = mapper.writeValueAsString(activations[i]); System.out.println(asJson); JsonNode node = mapper.readTree(asJson); JsonNode content = node.elements().next(); Iterator<String> fieldNamesIter = content.fieldNames(); List<String> actualFieldsByName = new ArrayList<>(); while (fieldNamesIter.hasNext()) { actualFieldsByName.add(fieldNamesIter.next()); } String[] expFields = expectedFields[i]; String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields) + "\tActual fields: " + actualFieldsByName; assertEquals(msg, expFields.length, actualFieldsByName.size()); for (String s : expFields) { msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString(); assertTrue(msg, actualFieldsByName.contains(s)); } //Test conversion from JSON: IActivation act = mapper.readValue(asJson, IActivation.class); assertEquals(activations[i], act); } } }