package org.nd4j.serde.gson; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; public class GsonDeserializationUtilsTest { @Test public void deserializeRawJson_PassInInRank3Array_ExpectCorrectDeserialization() { String serializedRawArray = "[[[1.00, 11.00, 3.00],\n" + "[13.00, 5.00, 15.00],\n" + "[7.00, 17.00, 9.00]]]"; INDArray expectedArray = buildExpectedArray(1, 1, 3, 3); INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); assertEquals(expectedArray, indArray); } @Test public void deserializeRawJson_ArrayHasOnlyOneRowWithColumns_ExpectCorrectDeserialization() { String serializedRawArray = "[1.00, 11.00, 3.00]"; INDArray expectedArray = Nd4j.create(new double[] {1, 11, 3}); INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); assertEquals(expectedArray, indArray); } @Test public void deserializeRawJson_ArrayIsRankFive_ExpectCorrectDeserialization() { String serializedRawArray = "[[[[[1.00, 11.00],\n" + " [3.00, 13.00]],\n" + " [[5.00, 15.00],\n" + " [7.00, 17.00]]],\n" + " [[[9.00, 1.00],\n" + " [11.00, 3.00]],\n" + " [[13.00, 5.00],\n" + " [15.00, 7.00]]],\n" + " [[[17.00, 9.00],\n" + " [1.00, 11.00]],\n" + " [[3.00, 13.00],\n" + " [5.00, 15.00]]]],\n" + " [[[[7.00, 17.00],\n" + " [9.00, 1.00]],\n" + " [[11.00, 3.00],\n" + " [13.00, 5.00]]],\n" + " [[[15.00, 7.00],\n" + " [17.00, 9.00]],\n" + " [[1.00, 11.00],\n" + " [3.00, 13.00]]],\n" + " [[[5.00, 15.00],\n" + " [7.00, 17.00]],\n" + " [[9.00, 1.00],\n" + " [11.00, 3.00]]]],\n" + " [[[[13.00, 5.00],\n" + " [15.00, 7.00]],\n" + " [[17.00, 9.00],\n" + " [1.00, 11.00]]],\n" + " [[[3.00, 13.00],\n" + " [5.00, 15.00]],\n" + " [[7.00, 17.00],\n" + " [9.00, 1.00]]],\n" + " [[[11.00, 3.00],\n" + " [13.00, 5.00]],\n" + " [[15.00, 7.00],\n" + " [17.00, 9.00]]]]]"; INDArray expectedArray = buildExpectedArray(8, 3, 3, 2, 2, 2); INDArray array = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); assertEquals(expectedArray, array); } @Test public void deserializeRawJson_HaveCommaInsideNumbers_ExpectCorrectDeserialization() { String serializedRawArray = "[[1.00, 1,100.00, 3.00],\n" + "[13.00, 5.00, 15,591.00],\n" + "[7,000.00, 17.00, 9.00]]"; INDArray expectedArray = Nd4j.create(new double[] {1, 1100, 3, 13, 5, 15591, 7000, 17, 9}, new int[] {3, 3}); INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); assertEquals(expectedArray, indArray); } private INDArray buildExpectedArray(int numberOfTripletRows, int... shape) { INDArray expectedArray = Nd4j.create(3 * numberOfTripletRows, 3); for (int i = 0; i < numberOfTripletRows; i++) { int index = 3 * i; expectedArray.putRow(index, Nd4j.create(new double[] {1, 11, 3})); expectedArray.putRow(index + 1, Nd4j.create(new double[] {13, 5, 15})); expectedArray.putRow(index + 2, Nd4j.create(new double[] {7, 17, 9})); } return expectedArray.reshape(shape); } }