package org.nd4j.linalg.dataset;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
/**
 * Tests for {@link MultiDataSet}: merging single- and multi-array data sets (2d, 4d and time series,
 * with and without mask arrays), splitting via {@code asList()}, {@code toString()}, and binary
 * save/load round trips.
 *
 * <p>Parameterized over an {@link Nd4jBackend}; the parameter list is presumably supplied by
 * {@link BaseNd4jTest} (no {@code @Parameters} method is declared here).
 */
@RunWith(Parameterized.class)
public class MultiDataSetTest extends BaseNd4jTest {

    public MultiDataSetTest(Nd4jBackend backend) {
        super(backend);
    }

    @Test
    public void testMerging2d() {
        //Simple test: single input/output arrays; 5 MultiDataSets to merge
        int nCols = 3;
        int nRows = 5;
        INDArray expIn = Nd4j.linspace(0, nCols * nRows - 1, nCols * nRows).reshape(nRows, nCols);
        INDArray expOut = Nd4j.linspace(100, 100 + nCols * nRows - 1, nCols * nRows).reshape(nRows, nCols);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        for (int i = 0; i < nRows; i++) {
            list.add(new MultiDataSet(expIn.getRow(i).dup(), expOut.getRow(i).dup()));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(1, merged.getFeatures().length);
        assertEquals(1, merged.getLabels().length);
        assertEquals(expIn, merged.getFeatures(0));
        assertEquals(expOut, merged.getLabels(0));
    }

    @Test
    public void testMerging2dMultipleInOut() {
        //Test merging: Multiple input/output arrays; 4 MultiDataSets to merge (the first holds 2 rows)
        int nRows = 5;
        int nColsIn0 = 3;
        int nColsIn1 = 4;
        int nColsOut0 = 5;
        int nColsOut1 = 6;
        INDArray expIn0 = Nd4j.linspace(0, nRows * nColsIn0 - 1, nRows * nColsIn0).reshape(nRows, nColsIn0);
        INDArray expIn1 = Nd4j.linspace(0, nRows * nColsIn1 - 1, nRows * nColsIn1).reshape(nRows, nColsIn1);
        INDArray expOut0 = Nd4j.linspace(0, nRows * nColsOut0 - 1, nRows * nColsOut0).reshape(nRows, nColsOut0);
        INDArray expOut1 = Nd4j.linspace(0, nRows * nColsOut1 - 1, nRows * nColsOut1).reshape(nRows, nColsOut1);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        //First MultiDataSet: have 2 rows (0 and 1), not just 1
        list.add(new MultiDataSet(new INDArray[] {rows(expIn0, 0, 1), rows(expIn1, 0, 1)},
                        new INDArray[] {rows(expOut0, 0, 1), rows(expOut1, 0, 1)}));
        //Remaining MultiDataSets: one row each
        for (int i = 2; i < nRows; i++) {
            INDArray in0 = expIn0.getRow(i).dup();
            INDArray in1 = expIn1.getRow(i).dup();
            INDArray out0 = expOut0.getRow(i).dup();
            INDArray out1 = expOut1.getRow(i).dup();
            list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1}));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(2, merged.getFeatures().length);
        assertEquals(2, merged.getLabels().length);
        assertEquals(expIn0, merged.getFeatures(0));
        assertEquals(expIn1, merged.getFeatures(1));
        assertEquals(expOut0, merged.getLabels(0));
        assertEquals(expOut1, merged.getLabels(1));
    }

    @Test
    public void testMerging2dMultipleInOut2() {
        //Test merging: Multiple input/output arrays; 9 MultiDataSets to merge (the first holds 2 rows)
        int nRows = 10;
        int nColsIn0 = 3;
        int nColsIn1 = 4;
        int nColsIn2 = 5;
        int nColsOut0 = 6;
        int nColsOut1 = 7;
        int nColsOut2 = 8;
        INDArray expIn0 = Nd4j.linspace(0, nRows * nColsIn0 - 1, nRows * nColsIn0).reshape(nRows, nColsIn0);
        INDArray expIn1 = Nd4j.linspace(0, nRows * nColsIn1 - 1, nRows * nColsIn1).reshape(nRows, nColsIn1);
        INDArray expIn2 = Nd4j.linspace(0, nRows * nColsIn2 - 1, nRows * nColsIn2).reshape(nRows, nColsIn2);
        INDArray expOut0 = Nd4j.linspace(0, nRows * nColsOut0 - 1, nRows * nColsOut0).reshape(nRows, nColsOut0);
        INDArray expOut1 = Nd4j.linspace(0, nRows * nColsOut1 - 1, nRows * nColsOut1).reshape(nRows, nColsOut1);
        INDArray expOut2 = Nd4j.linspace(0, nRows * nColsOut2 - 1, nRows * nColsOut2).reshape(nRows, nColsOut2);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        //First MultiDataSet: have 2 rows (0 and 1), not just 1
        list.add(new MultiDataSet(
                        new INDArray[] {rows(expIn0, 0, 1), rows(expIn1, 0, 1), rows(expIn2, 0, 1)},
                        new INDArray[] {rows(expOut0, 0, 1), rows(expOut1, 0, 1), rows(expOut2, 0, 1)}));
        //Remaining MultiDataSets: one row each
        for (int i = 2; i < nRows; i++) {
            INDArray in0 = expIn0.getRow(i).dup();
            INDArray in1 = expIn1.getRow(i).dup();
            INDArray in2 = expIn2.getRow(i).dup();
            INDArray out0 = expOut0.getRow(i).dup();
            INDArray out1 = expOut1.getRow(i).dup();
            INDArray out2 = expOut2.getRow(i).dup();
            list.add(new MultiDataSet(new INDArray[] {in0, in1, in2}, new INDArray[] {out0, out1, out2}));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(3, merged.getFeatures().length);
        assertEquals(3, merged.getLabels().length);
        assertEquals(expIn0, merged.getFeatures(0));
        assertEquals(expIn1, merged.getFeatures(1));
        assertEquals(expIn2, merged.getFeatures(2));
        assertEquals(expOut0, merged.getLabels(0));
        assertEquals(expOut1, merged.getLabels(1));
        assertEquals(expOut2, merged.getLabels(2));
    }

    @Test
    public void testMerging2dMultipleInOut3() {
        //Test merging: fewer rows than output arrays...
        int nRows = 2;
        int nColsIn0 = 3;
        int nColsIn1 = 4;
        int nColsIn2 = 5;
        int nColsOut0 = 6;
        int nColsOut1 = 7;
        int nColsOut2 = 8;
        INDArray expIn0 = Nd4j.linspace(0, nRows * nColsIn0 - 1, nRows * nColsIn0).reshape(nRows, nColsIn0);
        INDArray expIn1 = Nd4j.linspace(0, nRows * nColsIn1 - 1, nRows * nColsIn1).reshape(nRows, nColsIn1);
        INDArray expIn2 = Nd4j.linspace(0, nRows * nColsIn2 - 1, nRows * nColsIn2).reshape(nRows, nColsIn2);
        INDArray expOut0 = Nd4j.linspace(0, nRows * nColsOut0 - 1, nRows * nColsOut0).reshape(nRows, nColsOut0);
        INDArray expOut1 = Nd4j.linspace(0, nRows * nColsOut1 - 1, nRows * nColsOut1).reshape(nRows, nColsOut1);
        INDArray expOut2 = Nd4j.linspace(0, nRows * nColsOut2 - 1, nRows * nColsOut2).reshape(nRows, nColsOut2);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        for (int i = 0; i < nRows; i++) {
            INDArray in0 = expIn0.getRow(i).dup();
            INDArray in1 = expIn1.getRow(i).dup();
            INDArray in2 = expIn2.getRow(i).dup();
            INDArray out0 = expOut0.getRow(i).dup();
            INDArray out1 = expOut1.getRow(i).dup();
            INDArray out2 = expOut2.getRow(i).dup();
            list.add(new MultiDataSet(new INDArray[] {in0, in1, in2}, new INDArray[] {out0, out1, out2}));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(3, merged.getFeatures().length);
        assertEquals(3, merged.getLabels().length);
        assertEquals(expIn0, merged.getFeatures(0));
        assertEquals(expIn1, merged.getFeatures(1));
        assertEquals(expIn2, merged.getFeatures(2));
        assertEquals(expOut0, merged.getLabels(0));
        assertEquals(expOut1, merged.getLabels(1));
        assertEquals(expOut2, merged.getLabels(2));
    }

    @Test
    public void testMerging4dMultipleInOut() {
        //4d inputs, 2d outputs; 4 MultiDataSets to merge (the first holds 2 rows)
        int nRows = 5;
        int depthIn0 = 3;
        int widthIn0 = 4;
        int heightIn0 = 5;
        int depthIn1 = 4;
        int widthIn1 = 5;
        int heightIn1 = 6;
        int nColsOut0 = 5;
        int nColsOut1 = 6;
        int lengthIn0 = nRows * depthIn0 * widthIn0 * heightIn0;
        int lengthIn1 = nRows * depthIn1 * widthIn1 * heightIn1;
        INDArray expIn0 = Nd4j.linspace(0, lengthIn0 - 1, lengthIn0).reshape(nRows, depthIn0, widthIn0, heightIn0);
        INDArray expIn1 = Nd4j.linspace(0, lengthIn1 - 1, lengthIn1).reshape(nRows, depthIn1, widthIn1, heightIn1);
        INDArray expOut0 = Nd4j.linspace(0, nRows * nColsOut0 - 1, nRows * nColsOut0).reshape(nRows, nColsOut0);
        INDArray expOut1 = Nd4j.linspace(0, nRows * nColsOut1 - 1, nRows * nColsOut1).reshape(nRows, nColsOut1);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        //First MultiDataSet: have 2 rows (0 and 1), not just 1
        list.add(new MultiDataSet(new INDArray[] {rows(expIn0, 0, 1), rows(expIn1, 0, 1)},
                        new INDArray[] {rows(expOut0, 0, 1), rows(expOut1, 0, 1)}));
        //Remaining MultiDataSets: one example each
        for (int i = 2; i < nRows; i++) {
            INDArray in0 = rows(expIn0, i, i);
            INDArray in1 = rows(expIn1, i, i);
            INDArray out0 = expOut0.getRow(i).dup();
            INDArray out1 = expOut1.getRow(i).dup();
            list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1}));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(2, merged.getFeatures().length);
        assertEquals(2, merged.getLabels().length);
        assertEquals(expIn0, merged.getFeatures(0));
        assertEquals(expIn1, merged.getFeatures(1));
        assertEquals(expOut0, merged.getLabels(0));
        assertEquals(expOut1, merged.getLabels(1));
    }

    @Test
    public void testMergingTimeSeriesEqualLength() {
        //Equal-length time series, no masks; 4 MultiDataSets to merge (the first holds 2 rows)
        int tsLength = 8;
        int nRows = 5;
        int nColsIn0 = 3;
        int nColsIn1 = 4;
        int nColsOut0 = 5;
        int nColsOut1 = 6;
        int n0 = nRows * nColsIn0 * tsLength;
        int n1 = nRows * nColsIn1 * tsLength;
        int nOut0 = nRows * nColsOut0 * tsLength;
        int nOut1 = nRows * nColsOut1 * tsLength;
        INDArray expIn0 = Nd4j.linspace(0, n0 - 1, n0).reshape(nRows, nColsIn0, tsLength);
        INDArray expIn1 = Nd4j.linspace(0, n1 - 1, n1).reshape(nRows, nColsIn1, tsLength);
        INDArray expOut0 = Nd4j.linspace(0, nOut0 - 1, nOut0).reshape(nRows, nColsOut0, tsLength);
        INDArray expOut1 = Nd4j.linspace(0, nOut1 - 1, nOut1).reshape(nRows, nColsOut1, tsLength);

        List<MultiDataSet> list = new ArrayList<>(nRows);
        //First MultiDataSet: have 2 rows (0 and 1), not just 1
        list.add(new MultiDataSet(new INDArray[] {rows(expIn0, 0, 1), rows(expIn1, 0, 1)},
                        new INDArray[] {rows(expOut0, 0, 1), rows(expOut1, 0, 1)}));
        //Remaining MultiDataSets: one example each
        for (int i = 2; i < nRows; i++) {
            list.add(new MultiDataSet(new INDArray[] {rows(expIn0, i, i), rows(expIn1, i, i)},
                            new INDArray[] {rows(expOut0, i, i), rows(expOut1, i, i)}));
        }

        MultiDataSet merged = MultiDataSet.merge(list);
        assertEquals(2, merged.getFeatures().length);
        assertEquals(2, merged.getLabels().length);
        assertEquals(expIn0, merged.getFeatures(0));
        assertEquals(expIn1, merged.getFeatures(1));
        assertEquals(expOut0, merged.getLabels(0));
        assertEquals(expOut1, merged.getLabels(1));
    }

    @Test
    public void testMergingTimeSeriesWithMasking() {
        //Mask arrays, and different lengths
        MaskedTimeSeriesFixture fixture = maskedTimeSeriesFixture();

        MultiDataSet merged = MultiDataSet.merge(fixture.toMerge);
        assertEquals(2, merged.getFeatures().length);
        assertEquals(2, merged.getLabels().length);
        assertEquals(2, merged.getFeaturesMaskArrays().length);
        assertEquals(2, merged.getLabelsMaskArrays().length);
        for (int j = 0; j < 2; j++) {
            assertEquals(fixture.expectedFeatures[j], merged.getFeatures(j));
            assertEquals(fixture.expectedLabels[j], merged.getLabels(j));
            assertEquals(fixture.expectedFeaturesMasks[j], merged.getFeaturesMaskArray(j));
            assertEquals(fixture.expectedLabelsMasks[j], merged.getLabelsMaskArray(j));
        }
    }

    @Test
    public void testMergingWithPerOutputMasking() {
        //Test 2d mask merging, 2d data
        //features
        INDArray f2d1 = Nd4j.create(new double[] {1, 2, 3});
        INDArray f2d2 = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
        //labels
        INDArray l2d1 = Nd4j.create(new double[] {1.5, 2.5, 3.5});
        INDArray l2d2 = Nd4j.create(new double[][] {{4.5, 5.5, 6.5}, {7.5, 8.5, 9.5}});
        //feature masks
        INDArray fm2d1 = Nd4j.create(new double[] {0, 1, 1});
        INDArray fm2d2 = Nd4j.create(new double[][] {{1, 0, 1}, {0, 1, 0}});
        //label masks
        INDArray lm2d1 = Nd4j.create(new double[] {1, 1, 0});
        INDArray lm2d2 = Nd4j.create(new double[][] {{1, 0, 0}, {0, 1, 1}});

        MultiDataSet mds2d1 = new MultiDataSet(f2d1, l2d1, fm2d1, lm2d1);
        MultiDataSet mds2d2 = new MultiDataSet(f2d2, l2d2, fm2d2, lm2d2);
        MultiDataSet merged = MultiDataSet.merge(Arrays.asList(mds2d1, mds2d2));

        INDArray expFeatures2d = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
        INDArray expLabels2d = Nd4j.create(new double[][] {{1.5, 2.5, 3.5}, {4.5, 5.5, 6.5}, {7.5, 8.5, 9.5}});
        INDArray expFM2d = Nd4j.create(new double[][] {{0, 1, 1}, {1, 0, 1}, {0, 1, 0}});
        INDArray expLM2d = Nd4j.create(new double[][] {{1, 1, 0}, {1, 0, 0}, {0, 1, 1}});
        MultiDataSet mdsExp2d = new MultiDataSet(expFeatures2d, expLabels2d, expFM2d, expLM2d);
        assertEquals(mdsExp2d, merged);

        //Test 4d features, 2d labels, 2d masks
        INDArray f4d1 = Nd4j.create(1, 3, 5, 5);
        INDArray f4d2 = Nd4j.create(2, 3, 5, 5);
        MultiDataSet mds4d1 = new MultiDataSet(f4d1, l2d1, null, lm2d1);
        MultiDataSet mds4d2 = new MultiDataSet(f4d2, l2d2, null, lm2d2);
        MultiDataSet merged4d = MultiDataSet.merge(Arrays.asList(mds4d1, mds4d2));
        assertEquals(expLabels2d, merged4d.getLabels(0));
        assertEquals(expLM2d, merged4d.getLabelsMaskArray(0));

        //Test 3d mask merging, 3d data
        //Features must have the same number of examples as the labels they are paired with:
        //1 example in the first MultiDataSet, 2 in the second (previously f3d2 had only 1)
        INDArray f3d1 = Nd4j.create(1, 3, 4);
        INDArray f3d2 = Nd4j.create(2, 3, 3);
        INDArray l3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
        INDArray l3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
        INDArray lm3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
        INDArray lm3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
        MultiDataSet mds3d1 = new MultiDataSet(f3d1, l3d1, null, lm3d1);
        MultiDataSet mds3d2 = new MultiDataSet(f3d2, l3d2, null, lm3d2);

        //Expected: shorter series zero-padded at the end up to the longest length (4)
        INDArray expLabels3d = Nd4j.create(3, 3, 4);
        expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(),
                        NDArrayIndex.interval(0, 4)}, l3d1);
        expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
                        NDArrayIndex.interval(0, 3)}, l3d2);
        INDArray expLM3d = Nd4j.create(3, 3, 4);
        expLM3d.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(),
                        NDArrayIndex.interval(0, 4)}, lm3d1);
        expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
                        NDArrayIndex.interval(0, 3)}, lm3d2);

        MultiDataSet merged3d = MultiDataSet.merge(Arrays.asList(mds3d1, mds3d2));
        assertEquals(expLabels3d, merged3d.getLabels(0));
        assertEquals(expLM3d, merged3d.getLabelsMaskArray(0));

        //Test 3d features, 2d masks, 2d output (for example: RNN -> global pooling w/ per-output masking)
        MultiDataSet mds3d2d1 = new MultiDataSet(f3d1, l2d1, null, lm2d1);
        MultiDataSet mds3d2d2 = new MultiDataSet(f3d2, l2d2, null, lm2d2);
        MultiDataSet merged3d2d = MultiDataSet.merge(Arrays.asList(mds3d2d1, mds3d2d2));
        assertEquals(expLabels2d, merged3d2d.getLabels(0));
        assertEquals(expLM2d, merged3d2d.getLabelsMaskArray(0));
    }

    @Test
    public void testSplit() {
        INDArray[] features = new INDArray[3];
        features[0] = Nd4j.linspace(1, 30, 30).reshape('c', 3, 10);
        features[1] = Nd4j.linspace(1, 300, 300).reshape('c', 3, 10, 10);
        features[2] = Nd4j.linspace(1, 3 * 5 * 10 * 10, 3 * 5 * 10 * 10).reshape('c', 3, 5, 10, 10);
        INDArray[] labels = new INDArray[3];
        labels[0] = Nd4j.linspace(1, 30, 30).reshape('c', 3, 10).addi(0.5);
        labels[1] = Nd4j.linspace(1, 300, 300).reshape('c', 3, 10, 10).addi(0.3);
        labels[2] = Nd4j.linspace(1, 3 * 5 * 10 * 10, 3 * 5 * 10 * 10).reshape('c', 3, 5, 10, 10).addi(0.1);
        //Only the second feature/label array has a mask; the others stay null
        INDArray[] fMask = new INDArray[3];
        fMask[1] = Nd4j.linspace(1, 30, 30).reshape('f', 3, 10);
        INDArray[] lMask = new INDArray[3];
        lMask[1] = Nd4j.linspace(1, 30, 30).reshape('f', 3, 10).addi(0.5);

        MultiDataSet mds = new MultiDataSet(features, labels, fMask, lMask);
        List<org.nd4j.linalg.dataset.api.MultiDataSet> list = mds.asList();
        assertEquals(3, list.size());
        for (int i = 0; i < 3; i++) {
            MultiDataSet m = (MultiDataSet) list.get(i);
            assertEquals(2, m.getFeatures(0).rank());
            assertEquals(3, m.getFeatures(1).rank());
            assertEquals(4, m.getFeatures(2).rank());
            assertArrayEquals(new int[] {1, 10}, m.getFeatures(0).shape());
            assertArrayEquals(new int[] {1, 10, 10}, m.getFeatures(1).shape());
            assertArrayEquals(new int[] {1, 5, 10, 10}, m.getFeatures(2).shape());
            assertEquals(features[0].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getFeatures(0));
            assertEquals(features[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()),
                            m.getFeatures(1));
            assertEquals(features[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.all()), m.getFeatures(2));

            assertEquals(2, m.getLabels(0).rank());
            assertEquals(3, m.getLabels(1).rank());
            assertEquals(4, m.getLabels(2).rank());
            assertArrayEquals(new int[] {1, 10}, m.getLabels(0).shape());
            assertArrayEquals(new int[] {1, 10, 10}, m.getLabels(1).shape());
            assertArrayEquals(new int[] {1, 5, 10, 10}, m.getLabels(2).shape());
            assertEquals(labels[0].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getLabels(0));
            assertEquals(labels[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()),
                            m.getLabels(1));
            assertEquals(labels[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.all()), m.getLabels(2));

            assertNull(m.getFeaturesMaskArray(0));
            assertEquals(fMask[1].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getFeaturesMaskArray(1));
            assertNull(m.getLabelsMaskArray(0));
            assertEquals(lMask[1].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getLabelsMaskArray(1));
        }
    }

    @Test
    public void testToString() {
        //Mask arrays, and different lengths: toString must render the merged data without failing
        MultiDataSet merged = MultiDataSet.merge(maskedTimeSeriesFixture().toMerge);
        String str = merged.toString();
        assertNotNull(str);
        assertFalse(str.isEmpty());
        System.out.println(str);
    }

    @Test
    public void multiDataSetSaveLoadTest() throws IOException {
        int max = 3;
        Nd4j.getRandom().setSeed(12345);
        //Every combination of 0..max feature arrays and 0..max label arrays; zero arrays -> null
        for (int numF = 0; numF <= max; numF++) {
            for (int numL = 0; numL <= max; numL++) {
                INDArray[] f = (numF > 0 ? new INDArray[numF] : null);
                INDArray[] l = (numL > 0 ? new INDArray[numL] : null);
                INDArray[] fm = (numF > 0 ? new INDArray[numF] : null);
                INDArray[] lm = (numL > 0 ? new INDArray[numL] : null);
                if (numF > 0) {
                    for (int i = 0; i < f.length; i++) {
                        f[i] = Nd4j.rand(new int[] {3, 4, 5});
                    }
                    //At most 2 masks are set; any further mask entries stay null
                    for (int i = 0; i < Math.min(fm.length, 2); i++) {
                        fm[i] = Nd4j.rand(new int[] {3, 5});
                    }
                }
                if (numL > 0) {
                    for (int i = 0; i < l.length; i++) {
                        l[i] = Nd4j.rand(new int[] {2, 3, 4});
                    }
                    for (int i = 0; i < Math.min(lm.length, 2); i++) {
                        lm[i] = Nd4j.rand(new int[] {2, 4});
                    }
                }
                MultiDataSet mds = new MultiDataSet(f, l, fm, lm);

                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                DataOutputStream dos = new DataOutputStream(baos);
                mds.save(dos);
                byte[] asBytes = baos.toByteArray();

                ByteArrayInputStream bais = new ByteArrayInputStream(asBytes);
                DataInputStream dis = new DataInputStream(bais);
                MultiDataSet mds2 = new MultiDataSet();
                mds2.load(dis);

                assertEquals(mds, mds2);
            }
        }
    }

    /**
     * Returns a copy of the slice {@code first..last} (both inclusive) along dimension 0 of
     * {@code arr}, taking every other dimension whole. Dimension 0 is selected with an inclusive
     * interval (not a point index), matching how the tests above build their per-example inputs.
     */
    private static INDArray rows(INDArray arr, int first, int last) {
        INDArrayIndex[] indices = new INDArrayIndex[arr.rank()];
        indices[0] = NDArrayIndex.interval(first, last, true);
        //A fresh all() per dimension: the index objects are not shared between dimensions
        for (int d = 1; d < indices.length; d++) {
            indices[d] = NDArrayIndex.all();
        }
        return arr.get(indices).dup();
    }

    /**
     * Single-example MultiDataSets with variable-length masked time series, plus the arrays expected
     * from merging them. Shared by {@link #testMergingTimeSeriesWithMasking()} and
     * {@link #testToString()}.
     */
    private static final class MaskedTimeSeriesFixture {
        /** MultiDataSets to merge, in order. */
        final List<MultiDataSet> toMerge;
        /** Expected merged inputs 0 and 1. */
        final INDArray[] expectedFeatures;
        /** Expected merged outputs 0 and 1. */
        final INDArray[] expectedLabels;
        /** Expected merged input masks 0 and 1. */
        final INDArray[] expectedFeaturesMasks;
        /** Expected merged output masks 0 and 1. */
        final INDArray[] expectedLabelsMasks;

        MaskedTimeSeriesFixture(List<MultiDataSet> toMerge, INDArray[] expectedFeatures, INDArray[] expectedLabels,
                        INDArray[] expectedFeaturesMasks, INDArray[] expectedLabelsMasks) {
            this.toMerge = toMerge;
            this.expectedFeatures = expectedFeatures;
            this.expectedLabels = expectedLabels;
            this.expectedFeaturesMasks = expectedFeaturesMasks;
            this.expectedLabelsMasks = expectedLabelsMasks;
        }
    }

    /**
     * Builds five single-example MultiDataSets, each with two time series inputs and two time series
     * outputs. Example {@code i} is {@code i} steps shorter than the full length of each array.
     * Input 0 and output 0 have a null mask. Input 1 and output 1 have random 0/1 masks drawn from a
     * {@code Random} seeded with 12345, so every call produces the same data.
     *
     * <p>The expected merged data places each example at the start of its row, with zeros after it.
     * The expected masks are ones over each example's steps for the unmasked arrays, and that
     * example's own random mask for the masked arrays; padding steps are zero.
     */
    private static MaskedTimeSeriesFixture maskedTimeSeriesFixture() {
        int tsLengthIn0 = 8;
        int tsLengthIn1 = 9;
        int tsLengthOut0 = 10;
        int tsLengthOut1 = 11;
        int nRows = 5;
        int nColsIn0 = 3;
        int nColsIn1 = 4;
        int nColsOut0 = 5;
        int nColsOut1 = 6;

        INDArray expectedIn0 = Nd4j.zeros(nRows, nColsIn0, tsLengthIn0);
        INDArray expectedIn1 = Nd4j.zeros(nRows, nColsIn1, tsLengthIn1);
        INDArray expectedOut0 = Nd4j.zeros(nRows, nColsOut0, tsLengthOut0);
        INDArray expectedOut1 = Nd4j.zeros(nRows, nColsOut1, tsLengthOut1);
        INDArray expectedMaskIn0 = Nd4j.zeros(nRows, tsLengthIn0);
        INDArray expectedMaskIn1 = Nd4j.zeros(nRows, tsLengthIn1);
        INDArray expectedMaskOut0 = Nd4j.zeros(nRows, tsLengthOut0);
        INDArray expectedMaskOut1 = Nd4j.zeros(nRows, tsLengthOut1);

        Random r = new Random(12345);
        List<MultiDataSet> list = new ArrayList<>(nRows);
        for (int i = 0; i < nRows; i++) {
            int thisRowIn0Length = tsLengthIn0 - i;
            int thisRowIn1Length = tsLengthIn1 - i;
            int thisRowOut0Length = tsLengthOut0 - i;
            int thisRowOut1Length = tsLengthOut1 - i;

            int in0NumElem = thisRowIn0Length * nColsIn0;
            INDArray in0 = Nd4j.linspace(0, in0NumElem - 1, in0NumElem).reshape(1, nColsIn0, thisRowIn0Length);
            int in1NumElem = thisRowIn1Length * nColsIn1;
            INDArray in1 = Nd4j.linspace(0, in1NumElem - 1, in1NumElem).reshape(1, nColsIn1, thisRowIn1Length);
            int out0NumElem = thisRowOut0Length * nColsOut0;
            INDArray out0 = Nd4j.linspace(0, out0NumElem - 1, out0NumElem).reshape(1, nColsOut0, thisRowOut0Length);
            int out1NumElem = thisRowOut1Length * nColsOut1;
            INDArray out1 = Nd4j.linspace(0, out1NumElem - 1, out1NumElem).reshape(1, nColsOut1, thisRowOut1Length);

            //Random draws happen in this order (input 1 mask, then output 1 mask) for every example
            INDArray maskIn0 = null;
            INDArray maskIn1 = Nd4j.zeros(1, thisRowIn1Length);
            for (int j = 0; j < thisRowIn1Length; j++) {
                if (r.nextBoolean()) {
                    maskIn1.putScalar(j, 1.0);
                }
            }
            INDArray maskOut0 = null;
            INDArray maskOut1 = Nd4j.zeros(1, thisRowOut1Length);
            for (int j = 0; j < thisRowOut1Length; j++) {
                if (r.nextBoolean()) {
                    maskOut1.putScalar(j, 1.0);
                }
            }

            expectedIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, thisRowIn0Length)}, in0);
            expectedIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, thisRowIn1Length)}, in1);
            expectedOut0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, thisRowOut0Length)}, out0);
            expectedOut1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, thisRowOut1Length)}, out1);
            //Unmasked arrays: expected mask is all ones over this example's valid steps
            expectedMaskIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn0Length)},
                            Nd4j.ones(1, thisRowIn0Length));
            expectedMaskIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn1Length)},
                            maskIn1);
            expectedMaskOut0.put(
                            new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut0Length)},
                            Nd4j.ones(1, thisRowOut0Length));
            expectedMaskOut1.put(
                            new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut1Length)},
                            maskOut1);

            list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1},
                            new INDArray[] {maskIn0, maskIn1}, new INDArray[] {maskOut0, maskOut1}));
        }

        return new MaskedTimeSeriesFixture(list,
                        new INDArray[] {expectedIn0, expectedIn1},
                        new INDArray[] {expectedOut0, expectedOut1},
                        new INDArray[] {expectedMaskIn0, expectedMaskIn1},
                        new INDArray[] {expectedMaskOut0, expectedMaskOut1});
    }

    @Override
    public char ordering() {
        return 'c';
    }
}