/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.nd4j.linalg.dataset; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.FeatureUtil; import java.io.*; import java.util.*; import static org.junit.Assert.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @RunWith(Parameterized.class) public class DataSetTest extends BaseNd4jTest { public DataSetTest(Nd4jBackend backend) { super(backend); } @Test public void testViewIterator() { DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10); assertTrue(iter.hasNext()); int count = 0; while (iter.hasNext()) { DataSet next = iter.next(); count++; assertArrayEquals(new int[] {10, 4}, next.getFeatureMatrix().shape()); } assertFalse(iter.hasNext()); assertEquals(15, count); iter.reset(); assertTrue(iter.hasNext()); } @Test public void testSplitTestAndTrain() throws Exception { INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1); DataSet data = new DataSet(Nd4j.rand(8, 1), labels); SplitTestAndTrain train = data.splitTestAndTrain(6, new Random(1)); assertEquals(train.getTrain().getLabels().length(), 6); SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1)); assertEquals(getFailureMessage(), train.getTrain().getFeatureMatrix(), train2.getTrain().getFeatureMatrix()); DataSet x0 = new IrisDataSetIterator(150, 150).next(); SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10); assertArrayEquals(new int[] {10, 4}, testAndTrain.getTrain().getFeatureMatrix().shape()); assertEquals(x0.getFeatureMatrix().getRows(ArrayUtil.range(0, 10)), testAndTrain.getTrain().getFeatureMatrix()); assertEquals(x0.getLabels().getRows(ArrayUtil.range(0, 10)), testAndTrain.getTrain().getLabels()); } @Test public void testSplitTestAndTrainRng() throws Exception { Random rngHere; DataSet x1 = new IrisDataSetIterator(150, 150).next(); //original DataSet x2 = x1.copy(); //call split test train with rng //Manual shuffle x1.shuffle(new Random(123).nextLong()); SplitTestAndTrain testAndTrain = x1.splitTestAndTrain(10); // Pass rng with splt test train rngHere = new Random(123); SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere); assertArrayEquals(testAndTrainRng.getTrain().getFeatureMatrix().shape(), testAndTrain.getTrain().getFeatureMatrix().shape()); assertEquals(testAndTrainRng.getTrain().getFeatureMatrix(), testAndTrain.getTrain().getFeatureMatrix()); assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels()); } @Test public void testLabelCounts() { DataSet x0 = new IrisDataSetIterator(150, 150).next(); assertEquals(getFailureMessage(), 0, x0.get(0).outcome()); assertEquals(getFailureMessage(), 0, x0.get(1).outcome()); assertEquals(getFailureMessage(), 2, x0.get(149).outcome()); Map<Integer, Double> counts = x0.labelCounts(); assertEquals(getFailureMessage(), 50, counts.get(0), 1e-1); assertEquals(getFailureMessage(), 50, counts.get(1), 1e-1); assertEquals(getFailureMessage(), 50, counts.get(2), 1e-1); } @Test public void testTimeSeriesMerge() { //Basic test for time series, all of the same length + no masking arrays int numExamples = 10; int inSize = 13; int labelSize = 5; int tsLength = 15; Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, tsLength}); INDArray out = Nd4j.rand(new int[] {1, labelSize, tsLength}); list.add(new DataSet(in, out)); } DataSet merged = DataSet.merge(list); assertEquals(numExamples, merged.numExamples()); INDArray f = merged.getFeatures(); INDArray l = merged.getLabels(); assertArrayEquals(new int[] {numExamples, inSize, tsLength}, f.shape()); assertArrayEquals(new int[] {numExamples, labelSize, tsLength}, l.shape()); for (int i = 0; i < numExamples; i++) { DataSet exp = list.get(i); INDArray expIn = exp.getFeatureMatrix(); INDArray expL = exp.getLabels(); INDArray fSubset = f.get(interval(i, i + 1), all(), all()); INDArray lSubset = l.get(interval(i, i + 1), all(), all()); assertEquals(expIn, fSubset); assertEquals(expL, lSubset); } } @Test public void testTimeSeriesMergeDifferentLength() { //Test merging of time series with different lengths -> no masking arrays on the input DataSets int numExamples = 10; int inSize = 13; int labelSize = 5; int minTSLength = 10; //Lengths 10, 11, ..., 19 Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); list.add(new DataSet(in, out)); } DataSet merged = DataSet.merge(list); assertEquals(numExamples, merged.numExamples()); INDArray f = merged.getFeatures(); INDArray l = merged.getLabels(); int expectedLength = minTSLength + numExamples - 1; assertArrayEquals(new int[] {numExamples, inSize, expectedLength}, f.shape()); assertArrayEquals(new int[] {numExamples, labelSize, expectedLength}, l.shape()); assertTrue(merged.hasMaskArrays()); assertNotNull(merged.getFeaturesMaskArray()); assertNotNull(merged.getLabelsMaskArray()); INDArray featuresMask = merged.getFeaturesMaskArray(); INDArray labelsMask = merged.getLabelsMaskArray(); assertArrayEquals(new int[] {numExamples, expectedLength}, featuresMask.shape()); assertArrayEquals(new int[] {numExamples, expectedLength}, labelsMask.shape()); //Check each row individually: for (int i = 0; i < numExamples; i++) { DataSet exp = list.get(i); INDArray expIn = exp.getFeatureMatrix(); INDArray expL = exp.getLabels(); int thisRowOriginalLength = minTSLength + i; INDArray fSubset = f.get(interval(i, i + 1), all(), all()); INDArray lSubset = l.get(interval(i, i + 1), all(), all()); for (int j = 0; j < inSize; j++) { for (int k = 0; k < thisRowOriginalLength; k++) { double expected = expIn.getDouble(0, j, k); double act = fSubset.getDouble(0, j, k); if (Math.abs(expected - act) > 1e-3) { System.out.println(expIn); System.out.println(fSubset); } assertEquals(expected, act, 1e-3f); } //Padded values: should be exactly 0.0 for (int k = thisRowOriginalLength; k < expectedLength; k++) { assertEquals(0.0, fSubset.getDouble(0, j, k), 0.0); } } for (int j = 0; j < labelSize; j++) { for (int k = 0; k < thisRowOriginalLength; k++) { double expected = expL.getDouble(0, j, k); double act = lSubset.getDouble(0, j, k); assertEquals(expected, act, 1e-3f); } //Padded values: should be exactly 0.0 for (int k = thisRowOriginalLength; k < expectedLength; k++) { assertEquals(0.0, lSubset.getDouble(0, j, k), 0.0); } } //Check mask values: for (int j = 0; j < expectedLength; j++) { double expected = (j >= thisRowOriginalLength ? 0.0 : 1.0); double actFMask = featuresMask.getDouble(i, j); double actLMask = labelsMask.getDouble(i, j); if (expected != actFMask) { System.out.println(featuresMask); System.out.println(j); } assertEquals(expected, actFMask, 0.0); assertEquals(expected, actLMask, 0.0); } } } @Test public void testTimeSeriesMergeWithMasking() { //Test merging of time series with (a) different lengths, and (b) mask arrays in the input DataSets int numExamples = 10; int inSize = 13; int labelSize = 5; int minTSLength = 10; //Lengths 10, 11, ..., 19 Random r = new Random(12345); Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); INDArray inMask = Nd4j.create(1, minTSLength + i); INDArray outMask = Nd4j.create(1, minTSLength + i); for (int j = 0; j < inMask.size(1); j++) { inMask.putScalar(j, (r.nextBoolean() ? 1.0 : 0.0)); outMask.putScalar(j, (r.nextBoolean() ? 1.0 : 0.0)); } list.add(new DataSet(in, out, inMask, outMask)); } DataSet merged = DataSet.merge(list); assertEquals(numExamples, merged.numExamples()); INDArray f = merged.getFeatures(); INDArray l = merged.getLabels(); int expectedLength = minTSLength + numExamples - 1; assertArrayEquals(new int[] {numExamples, inSize, expectedLength}, f.shape()); assertArrayEquals(new int[] {numExamples, labelSize, expectedLength}, l.shape()); assertTrue(merged.hasMaskArrays()); assertNotNull(merged.getFeaturesMaskArray()); assertNotNull(merged.getLabelsMaskArray()); INDArray featuresMask = merged.getFeaturesMaskArray(); INDArray labelsMask = merged.getLabelsMaskArray(); assertArrayEquals(new int[] {numExamples, expectedLength}, featuresMask.shape()); assertArrayEquals(new int[] {numExamples, expectedLength}, labelsMask.shape()); //Check each row individually: for (int i = 0; i < numExamples; i++) { DataSet original = list.get(i); INDArray expIn = original.getFeatureMatrix(); INDArray expL = original.getLabels(); INDArray origMaskF = original.getFeaturesMaskArray(); INDArray origMaskL = original.getLabelsMaskArray(); int thisRowOriginalLength = minTSLength + i; INDArray fSubset = f.get(interval(i, i + 1), all(), all()); INDArray lSubset = l.get(interval(i, i + 1), all(), all()); for (int j = 0; j < inSize; j++) { for (int k = 0; k < thisRowOriginalLength; k++) { double expected = expIn.getDouble(0, j, k); double act = fSubset.getDouble(0, j, k); if (Math.abs(expected - act) > 1e-3) { System.out.println(expIn); System.out.println(fSubset); } assertEquals(expected, act, 1e-3f); } //Padded values: should be exactly 0.0 for (int k = thisRowOriginalLength; k < expectedLength; k++) { assertEquals(0.0, fSubset.getDouble(0, j, k), 0.0); } } for (int j = 0; j < labelSize; j++) { for (int k = 0; k < thisRowOriginalLength; k++) { double expected = expL.getDouble(0, j, k); double act = lSubset.getDouble(0, j, k); assertEquals(expected, act, 1e-3f); } //Padded values: should be exactly 0.0 for (int k = thisRowOriginalLength; k < expectedLength; k++) { assertEquals(0.0, lSubset.getDouble(0, j, k), 0.0); } } //Check mask values: for (int j = 0; j < expectedLength; j++) { double expectedF; double expectedL; if (j >= thisRowOriginalLength) { //Outside of original data bounds -> should be 0 expectedF = 0.0; expectedL = 0.0; } else { //Value should be same as original mask array value expectedF = origMaskF.getDouble(j); expectedL = origMaskL.getDouble(j); } double actFMask = featuresMask.getDouble(i, j); double actLMask = labelsMask.getDouble(i, j); assertEquals(expectedF, actFMask, 0.0); assertEquals(expectedL, actLMask, 0.0); } } } @Test public void testCnnMerge() { //Test merging of CNN data sets int nOut = 3; int width = 5; int height = 4; int depth = 3; int nExamples1 = 2; int nExamples2 = 1; int length1 = width * height * depth * nExamples1; int length2 = width * height * depth * nExamples2; INDArray first = Nd4j.linspace(1, length1, length1).reshape('c', nExamples1, depth, width, height); INDArray second = Nd4j.linspace(1, length2, length2).reshape('c', nExamples2, depth, width, height).addi(0.1); INDArray labels1 = Nd4j.linspace(1, nExamples1 * nOut, nExamples1 * nOut).reshape('c', nExamples1, nOut); INDArray labels2 = Nd4j.linspace(1, nExamples2 * nOut, nExamples2 * nOut).reshape('c', nExamples2, nOut); DataSet ds1 = new DataSet(first, labels1); DataSet ds2 = new DataSet(second, labels2); DataSet merged = DataSet.merge(Arrays.asList(ds1, ds2)); INDArray fMerged = merged.getFeatureMatrix(); INDArray lMerged = merged.getLabels(); assertArrayEquals(new int[] {nExamples1 + nExamples2, depth, width, height}, fMerged.shape()); assertArrayEquals(new int[] {nExamples1 + nExamples2, nOut}, lMerged.shape()); assertEquals(first, fMerged.get(interval(0, nExamples1), all(), all(), all())); assertEquals(second, fMerged.get(interval(nExamples1, nExamples1 + nExamples2, true), all(), all(), all())); assertEquals(labels1, lMerged.get(interval(0, nExamples1), all())); assertEquals(labels2, lMerged.get(interval(nExamples1, nExamples1 + nExamples2), all())); } @Test public void testMixedRnn2dMerging() { //RNN input with 2d label output //Basic test for time series, all of the same length + no masking arrays int numExamples = 10; int inSize = 13; int labelSize = 5; int tsLength = 15; Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, tsLength}); INDArray out = Nd4j.rand(new int[] {1, labelSize}); list.add(new DataSet(in, out)); } DataSet merged = DataSet.merge(list); assertEquals(numExamples, merged.numExamples()); INDArray f = merged.getFeatures(); INDArray l = merged.getLabels(); assertArrayEquals(new int[] {numExamples, inSize, tsLength}, f.shape()); assertArrayEquals(new int[] {numExamples, labelSize}, l.shape()); for (int i = 0; i < numExamples; i++) { DataSet exp = list.get(i); INDArray expIn = exp.getFeatureMatrix(); INDArray expL = exp.getLabels(); INDArray fSubset = f.get(interval(i, i + 1), all(), all()); INDArray lSubset = l.get(interval(i, i + 1), all()); assertEquals(expIn, fSubset); assertEquals(expL, lSubset); } } @Test public void testMergingWithPerOutputMasking(){ //Test 2d mask merging, 2d data //features INDArray f2d1 = Nd4j.create(new double[]{1, 2, 3}); INDArray f2d2 = Nd4j.create(new double[][]{{4, 5, 6}, {7, 8, 9}}); //labels INDArray l2d1 = Nd4j.create(new double[]{1.5, 2.5, 3.5}); INDArray l2d2 = Nd4j.create(new double[][]{{4.5, 5.5, 6.5}, {7.5, 8.5, 9.5}}); //feature masks INDArray fm2d1 = Nd4j.create(new double[]{0, 1, 1}); INDArray fm2d2 = Nd4j.create(new double[][]{{1, 0, 1}, {0, 1, 0}}); //label masks INDArray lm2d1 = Nd4j.create(new double[]{1, 1, 0}); INDArray lm2d2 = Nd4j.create(new double[][]{{1, 0, 0}, {0, 1, 1}}); DataSet mds2d1 = new DataSet(f2d1, l2d1, fm2d1, lm2d1); DataSet mds2d2 = new DataSet(f2d2, l2d2, fm2d2, lm2d2); DataSet merged = DataSet.merge(Arrays.asList(mds2d1, mds2d2)); INDArray expFeatures2d = Nd4j.create(new double[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); INDArray expLabels2d = Nd4j.create(new double[][]{{1.5, 2.5, 3.5}, {4.5, 5.5, 6.5}, {7.5, 8.5, 9.5}}); INDArray expFM2d = Nd4j.create(new double[][]{{0, 1, 1}, {1, 0, 1}, {0, 1, 0}}); INDArray expLM2d = Nd4j.create(new double[][]{{1, 1, 0}, {1, 0, 0}, {0, 1, 1}}); DataSet dsExp2d = new DataSet(expFeatures2d, expLabels2d, expFM2d, expLM2d); assertEquals(dsExp2d, merged); //Test 4d features, 2d labels, 2d masks INDArray f4d1 = Nd4j.create(1,3,5,5); INDArray f4d2 = Nd4j.create(2,3,5,5); DataSet ds4d1 = new DataSet(f4d1, l2d1, null, lm2d1); DataSet ds4d2 = new DataSet(f4d2, l2d2, null, lm2d2); DataSet merged4d = DataSet.merge(Arrays.asList(ds4d1, ds4d2)); assertEquals(expLabels2d, merged4d.getLabels()); assertEquals(expLM2d, merged4d.getLabelsMaskArray()); //Test 3d mask merging, 3d data INDArray f3d1 = Nd4j.create(1,3,4); INDArray f3d2 = Nd4j.create(1,3,3); INDArray l3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1,3,4),0.5)); INDArray l3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2,3,3),0.5)); INDArray lm3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1,3,4),0.5)); INDArray lm3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2,3,3),0.5)); DataSet ds3d1 = new DataSet(f3d1, l3d1, null, lm3d1); DataSet ds3d2 = new DataSet(f3d2, l3d2, null, lm3d2); INDArray expLabels3d = Nd4j.create(3,3,4); expLabels3d.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0,4)}, l3d1 ); expLabels3d.put(new INDArrayIndex[]{NDArrayIndex.interval(1,2,true), NDArrayIndex.all(), NDArrayIndex.interval(0,3)}, l3d2 ); INDArray expLM3d = Nd4j.create(3,3,4); expLM3d.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0,4)}, lm3d1 ); expLM3d.put(new INDArrayIndex[]{NDArrayIndex.interval(1,2,true), NDArrayIndex.all(), NDArrayIndex.interval(0,3)}, lm3d2 ); DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2)); assertEquals(expLabels3d, merged3d.getLabels()); assertEquals(expLM3d, merged3d.getLabelsMaskArray()); //Test 3d features, 2d masks, 2d output (for example: RNN -> global pooling w/ per-output masking) DataSet ds3d2d1 = new DataSet(f3d1, l2d1, null, lm2d1); DataSet ds3d2d2 = new DataSet(f3d2, l2d2, null, lm2d2); DataSet merged3d2d = DataSet.merge(Arrays.asList(ds3d2d1, ds3d2d2)); assertEquals(expLabels2d, merged3d2d.getLabels()); assertEquals(expLM2d, merged3d2d.getLabelsMaskArray()); } @Test public void testShuffle4d() { int nSamples = 10; int nChannels = 3; int imgRows = 4; int imgCols = 2; int nLabels = 5; int[] shape = new int[] {nSamples, nChannels, imgRows, imgCols}; int entries = nSamples * nChannels * imgRows * imgCols; int labels = nSamples * nLabels; INDArray ds_data = Nd4j.linspace(1, entries, entries).reshape(nSamples, nChannels, imgRows, imgCols); INDArray ds_labels = Nd4j.linspace(1, labels, labels).reshape(nSamples, nLabels); DataSet ds = new DataSet(ds_data, ds_labels); ds.shuffle(); for (int dim = 1; dim < 4; dim++) { //get tensor along dimension - the order in every dimension but zero should be preserved for (int tensorNum = 0; tensorNum < entries / shape[dim]; tensorNum++) { for (int i = 0, j = 1; j < shape[dim]; i++, j++) { int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); int f_element_diff = f_next_element - f_element; assertTrue(f_element_diff == ds_data.stride(dim)); } } } } @Test public void testShuffleNd() { int numDims = 7; int nLabels = 3; Random r = new Random(); int[] shape = new int[numDims]; int entries = 1; for (int i = 0; i < numDims; i++) { //randomly generating shapes bigger than 1 shape[i] = r.nextInt(4) + 2; entries *= shape[i]; } int labels = shape[0] * nLabels; INDArray ds_data = Nd4j.linspace(1, entries, entries).reshape(shape); INDArray ds_labels = Nd4j.linspace(1, labels, labels).reshape(shape[0], nLabels); DataSet ds = new DataSet(ds_data, ds_labels); ds.shuffle(); //Checking Nd dataset which is the data for (int dim = 1; dim < numDims; dim++) { //get tensor along dimension - the order in every dimension but zero should be preserved for (int tensorNum = 0; tensorNum < ds_data.tensorssAlongDimension(dim); tensorNum++) { //the difference between consecutive elements should be equal to the stride for (int i = 0, j = 1; j < shape[dim]; i++, j++) { int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); int f_element_diff = f_next_element - f_element; assertTrue(f_element_diff == ds_data.stride(dim)); } } } //Checking 2d, features int dim = 1; //get tensor along dimension - the order in every dimension but zero should be preserved for (int tensorNum = 0; tensorNum < ds_labels.tensorssAlongDimension(dim); tensorNum++) { //the difference between consecutive elements should be equal to the stride for (int i = 0, j = 1; j < nLabels; i++, j++) { int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); int l_element_diff = l_next_element - l_element; assertTrue(l_element_diff == ds_labels.stride(dim)); } } } @Test public void testShuffleMeta() { int nExamples = 20; int nColumns = 4; INDArray f = Nd4j.zeros(nExamples, nColumns); INDArray l = Nd4j.zeros(nExamples, nColumns); List<Integer> meta = new ArrayList<>(); for (int i = 0; i < nExamples; i++) { f.getRow(i).assign(i); l.getRow(i).assign(i); meta.add(i); } DataSet ds = new DataSet(f, l); ds.setExampleMetaData(meta); for (int i = 0; i < 10; i++) { ds.shuffle(); INDArray fCol = f.getColumn(0); INDArray lCol = l.getColumn(0); System.out.println(fCol + "\t" + ds.getExampleMetaData()); for (int j = 0; j < nExamples; j++) { int fVal = (int) fCol.getDouble(j); int lVal = (int) lCol.getDouble(j); int metaVal = (Integer) ds.getExampleMetaData().get(j); assertEquals(fVal, lVal); assertEquals(fVal, metaVal); } } } @Test public void testLabelNames() { List<String> names = Arrays.asList("label1", "label2", "label3", "label0"); INDArray features = Nd4j.ones(10); INDArray labels = Nd4j.linspace(0, 3, 4); org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(features, labels); ds.setLabelNames(names); assertEquals("label1", ds.getLabelName(0)); assertEquals(4, ds.getLabelNamesList().size()); assertEquals(names, ds.getLabelNames(labels)); } @Test public void testToString() { org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(); //this should not throw a null pointer System.out.println(ds); //Checking printing of masks int numExamples = 10; int inSize = 13; int labelSize = 5; int minTSLength = 10; //Lengths 10, 11, ..., 19 Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); list.add(new DataSet(in, out)); } ds = DataSet.merge(list); System.out.println(ds); } @Test public void testGetRangeMask() { org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(); //Checking printing of masks int numExamples = 10; int inSize = 13; int labelSize = 5; int minTSLength = 10; //Lengths 10, 11, ..., 19 Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); list.add(new DataSet(in, out)); } int from = 3; int to = 9; ds = DataSet.merge(list); org.nd4j.linalg.dataset.api.DataSet newDs = ds.getRange(from, to); //The feature mask does not have to be equal to the label mask, just in this ex it should be assertEquals(newDs.getLabelsMaskArray(), newDs.getFeaturesMaskArray()); //System.out.println(newDs); assertEquals(Nd4j.linspace(numExamples + from, numExamples + to - 1, to - from), newDs.getLabelsMaskArray().sum(1)); } @Test public void testAsList() { org.nd4j.linalg.dataset.api.DataSet ds; //Comparing merge with asList int numExamples = 10; int inSize = 13; int labelSize = 5; int minTSLength = 10; //Lengths 10, 11, ..., 19 Nd4j.getRandom().setSeed(12345); List<DataSet> list = new ArrayList<>(numExamples); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); list.add(new DataSet(in, out)); } //Merged dataset and dataset list ds = DataSet.merge(list); List<DataSet> dsList = ds.asList(); //Reset seed Nd4j.getRandom().setSeed(12345); for (int i = 0; i < numExamples; i++) { INDArray in = Nd4j.rand(new int[] {1, inSize, minTSLength + i}); INDArray out = Nd4j.rand(new int[] {1, labelSize, minTSLength + i}); DataSet iDataSet = new DataSet(in, out); //Checking if the features and labels are equal assertEquals(iDataSet.getFeatures(), dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i))); assertEquals(iDataSet.getLabels(), dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i))); } } @Test public void testDataSetSaveLoad() throws IOException { boolean[] b = new boolean[] {true, false}; INDArray f = Nd4j.linspace(1, 24, 24).reshape('c', 4, 3, 2); INDArray l = Nd4j.linspace(24, 48, 24).reshape('c', 4, 3, 2); INDArray fm = Nd4j.linspace(100, 108, 8).reshape('c', 4, 2); INDArray lm = Nd4j.linspace(108, 116, 8).reshape('c', 4, 2); for (boolean features : b) { for (boolean labels : b) { for (boolean labelsSameAsFeatures : b) { if (labelsSameAsFeatures && (!features || !labels)) continue; //Can't have "labels same as features" if no features, or if no labels for (boolean fMask : b) { for (boolean lMask : b) { DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null), (lMask ? lm : null)); ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(baos); ds.save(dos); byte[] asBytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); DataInputStream dis = new DataInputStream(bais); DataSet ds2 = new DataSet(); ds2.load(dis); dis.close(); assertEquals(ds, ds2); if (labelsSameAsFeatures) assertTrue(ds2.getFeatureMatrix() == ds2.getLabels()); //Expect same object } } } } } } @Test public void testDataSetSaveLoadSingle() throws IOException { INDArray f = Nd4j.linspace(1, 24, 24).reshape('c', 4, 3, 2); INDArray l = Nd4j.linspace(24, 48, 24).reshape('c', 4, 3, 2); INDArray fm = Nd4j.linspace(100, 108, 8).reshape('c', 4, 2); INDArray lm = Nd4j.linspace(108, 116, 8).reshape('c', 4, 2); boolean features = true; boolean labels = false; boolean labelsSameAsFeatures = false; boolean fMask = true; boolean lMask = true; DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null), (lMask ? lm : null)); ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(baos); ds.save(dos); dos.close(); byte[] asBytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); DataInputStream dis = new DataInputStream(bais); DataSet ds2 = new DataSet(); ds2.load(dis); dis.close(); assertEquals(ds, ds2); if (labelsSameAsFeatures) assertTrue(ds2.getFeatureMatrix() == ds2.getLabels()); //Expect same object } @Override public char ordering() { return 'f'; } }