package jcuda.jcublas.ops;

import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

/**
 * Tests for the {@code Nd4j.shuffle(...)} overloads and {@code DataSet.shuffle()}.
 *
 * <p>Every test fills each slice (row, column or higher-rank slice) of an array with a single
 * distinct value, shuffles, and then checks two things: the order of the slices has changed,
 * and every slice still holds one uniform value (i.e. slices were moved as a whole). Tests that
 * shuffle several arrays together additionally check that the arrays stayed aligned with each
 * other.
 *
 * @author raver119@gmail.com
 */
public class ShufflesTests {

    /**
     * Shuffles a 10x10 matrix whose row {@code x} is filled with the value {@code x} and checks
     * that the rows were reordered while each row stayed uniform.
     */
    @Test
    public void testSimpleShuffle1() {
        INDArray array = Nd4j.zeros(10, 10);
        for (int x = 0; x < 10; x++) {
            array.getRow(x).assign(x);
        }

        System.out.println(array);

        OrderScanner2D scanner = new OrderScanner2D(array);

        assertArrayEquals(new float[]{0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

        System.out.println();

        Nd4j.shuffle(array, 1);

        System.out.println(array);

        assertTrue(scanner.compareRow(array));
    }

    /**
     * Shuffles a 10x10 matrix whose column {@code x} is filled with the value {@code x} and checks
     * that each column stayed uniform and that the first column changed.
     */
    @Test
    public void testSimpleShuffle2() {
        INDArray array = Nd4j.zeros(10, 10);
        for (int x = 0; x < 10; x++) {
            array.getColumn(x).assign(x);
        }

        System.out.println(array);

        OrderScanner2D scanner = new OrderScanner2D(array);

        // The scanner records the first element of every row; column 0 is all zeros before the shuffle.
        assertArrayEquals(new float[]{0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f}, scanner.getMap(), 0.01f);

        System.out.println();

        Nd4j.shuffle(array, 0);

        System.out.println(array);

        assertTrue(scanner.compareColumn(array));
    }

    /**
     * Shuffles a 2D features matrix and a 2D labels matrix together and checks that the rows were
     * reordered and that row {@code x} of the labels still matches row {@code x} of the features.
     */
    @Test
    public void testSymmetricShuffle1() {
        INDArray features = Nd4j.zeros(10, 10);
        INDArray labels = Nd4j.zeros(10, 3);
        for (int x = 0; x < 10; x++) {
            features.getRow(x).assign(x);
            labels.getRow(x).assign(x);
        }

        System.out.println(features);

        OrderScanner2D scanner = new OrderScanner2D(features);

        assertArrayEquals(new float[]{0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

        System.out.println();

        List<INDArray> list = new ArrayList<>();
        list.add(features);
        list.add(labels);

        Nd4j.shuffle(list, 1);

        System.out.println(features);

        System.out.println();

        System.out.println(labels);

        assertTrue(scanner.compareRow(features));

        for (int x = 0; x < 10; x++) {
            double val = features.getRow(x).getDouble(0);
            assertAllEqual(val, labels.getRow(x));
        }
    }

    /**
     * Shuffles a pair of 3D arrays together and checks that the slices were reordered and that
     * slice {@code x} of the labels still matches slice {@code x} of the features.
     */
    @Test
    public void testSymmetricShuffle2() throws Exception {
        INDArray features = Nd4j.zeros(10, 10, 20);
        INDArray labels = Nd4j.zeros(10, 10, 3);

        for (int x = 0; x < 10; x++) {
            features.slice(x).assign(x);
            labels.slice(x).assign(x);
        }

        System.out.println(features);

        OrderScanner3D scannerFeatures = new OrderScanner3D(features);
        OrderScanner3D scannerLabels = new OrderScanner3D(labels);

        System.out.println();

        List<INDArray> list = new ArrayList<>();
        list.add(features);
        list.add(labels);

        Nd4j.shuffle(list, 1, 2);

        System.out.println(features);

        System.out.println("------------------");

        System.out.println(labels);

        assertTrue(scannerFeatures.compareSlice(features));
        assertTrue(scannerLabels.compareSlice(labels));

        for (int x = 0; x < 10; x++) {
            double val = features.slice(x).getDouble(0);
            assertAllEqual(val, labels.slice(x));
        }
    }

    /**
     * Shuffles features, labels and their two masks together, passing a separate dimension list
     * for each array, and checks that all four arrays were reordered and stayed aligned.
     */
    @Test
    public void testSymmetricShuffle3() throws Exception {
        INDArray features = Nd4j.zeros(10, 10, 20);
        INDArray featuresMask = Nd4j.zeros(10, 20);
        INDArray labels = Nd4j.zeros(10, 10, 3);
        INDArray labelsMask = Nd4j.zeros(10, 3);

        for (int x = 0; x < 10; x++) {
            features.slice(x).assign(x);
            featuresMask.slice(x).assign(x);
            labels.slice(x).assign(x);
            labelsMask.slice(x).assign(x);
        }

        OrderScanner3D scannerFeatures = new OrderScanner3D(features);
        OrderScanner3D scannerLabels = new OrderScanner3D(labels);
        OrderScanner3D scannerFeaturesMask = new OrderScanner3D(featuresMask);
        OrderScanner3D scannerLabelsMask = new OrderScanner3D(labelsMask);

        List<INDArray> arrays = new ArrayList<>();
        arrays.add(features);
        arrays.add(labels);
        arrays.add(featuresMask);
        arrays.add(labelsMask);

        // One dimension list per array, in the same order as 'arrays'.
        List<int[]> dimensions = new ArrayList<>();
        dimensions.add(ArrayUtil.range(1, features.rank()));
        dimensions.add(ArrayUtil.range(1, labels.rank()));
        dimensions.add(ArrayUtil.range(1, featuresMask.rank()));
        dimensions.add(ArrayUtil.range(1, labelsMask.rank()));

        Nd4j.shuffle(arrays, new Random(), dimensions);

        assertTrue(scannerFeatures.compareSlice(features));
        assertTrue(scannerLabels.compareSlice(labels));
        assertTrue(scannerFeaturesMask.compareSlice(featuresMask));
        assertTrue(scannerLabelsMask.compareSlice(labelsMask));

        for (int x = 0; x < 10; x++) {
            double val = features.slice(x).getDouble(0);
            INDArray sliceLabels = labels.slice(x);
            INDArray sliceLabelsMask = labelsMask.slice(x);
            INDArray sliceFeaturesMask = featuresMask.slice(x);

            assertAllEqual(val, sliceLabels);
            assertAllEqual(val, sliceLabelsMask);
            assertAllEqual(val, sliceFeaturesMask);
        }
    }

    /**
     * Shuffles a {@link DataSet} made of 4D features and 2D labels and checks that both arrays
     * were reordered and that labels slice {@code x} still matches features slice {@code x}.
     */
    @Test
    public void testSymmetricShuffle4() throws Exception {
        INDArray features = Nd4j.zeros(10, 3, 4, 2);
        INDArray labels = Nd4j.zeros(10, 5);

        for (int x = 0; x < 10; x++) {
            features.slice(x).assign(x);
            labels.slice(x).assign(x);
        }

        OrderScanner3D scannerFeatures = new OrderScanner3D(features);
        OrderScanner3D scannerLabels = new OrderScanner3D(labels);

        System.out.println(features);

        System.out.println();

        DataSet ds = new DataSet(features, labels);
        ds.shuffle();

        System.out.println(features);

        System.out.println("------------------");

        assertTrue(scannerFeatures.compareSlice(features));
        assertTrue(scannerLabels.compareSlice(labels));

        for (int x = 0; x < 10; x++) {
            double val = features.slice(x).getDouble(0);
            assertAllEqual(val, labels.slice(x));
        }
    }

    /**
     * Same scenario as {@link #testSymmetricShuffle4()} but with arrays created in 'f' order.
     * Only the features/labels alignment is checked here, not that the order changed.
     */
    @Test
    public void testSymmetricShuffle4F() throws Exception {
        INDArray features = Nd4j.create(new int[] {10, 3, 4, 2}, 'f');
        INDArray labels = Nd4j.create(new int[] {10, 5}, 'f');

        for (int x = 0; x < 10; x++) {
            features.slice(x).assign(x);
            labels.slice(x).assign(x);
        }

        System.out.println("features.length: " + features.length());

        System.out.println(labels);

        System.out.println();

        DataSet ds = new DataSet(features, labels);
        ds.shuffle();

        System.out.println(labels);

        System.out.println("------------------");

        for (int x = 0; x < 10; x++) {
            double val = features.slice(x).getDouble(0);
            assertAllEqual(val, labels.slice(x));
        }
    }

    /**
     * Prints the result of {@code ArrayUtil.buildHalfVector} for a fixed seed.
     * NOTE(review): this test has no assertions; it only fails if the call throws.
     */
    @Test
    public void testHalfVectors() throws Exception {
        int[] array = ArrayUtil.buildHalfVector(new Random(12), 11);

        System.out.println("HalfVec: " + Arrays.toString(array));
    }

    /**
     * Asserts that every element of {@code array} equals {@code expected} within a 0.001 tolerance.
     *
     * @param expected value every element must hold
     * @param array    array (typically a single row or slice) to check element by element
     */
    private static void assertAllEqual(double expected, INDArray array) {
        for (int y = 0; y < array.length(); y++) {
            assertEquals(expected, array.getDouble(y), 0.001);
        }
    }

    /**
     * Records the first element of every slice along dimension 0 at construction time so that a
     * later state of the same array can be compared against it.
     */
    public static class OrderScanner3D {
        private float[] map;

        /**
         * @param data array whose current slice order is recorded
         */
        public OrderScanner3D(INDArray data) {
            map = measureState(data);
        }

        /**
         * Captures the first element of each slice along dimension 0.
         *
         * @param data array to inspect
         * @return one value per slice, in slice order
         */
        public float[] measureState(INDArray data) {
            // for 3D we save 0 element for each slice.
            float[] result = new float[data.shape()[0]];

            for (int x = 0; x < data.shape()[0]; x++) {
                result[x] = data.slice(x).getFloat(0);
            }

            return result;
        }

        /**
         * Checks that {@code data} has been reordered relative to the recorded state and that
         * every slice is still filled with a single value.
         *
         * @param data array to compare against the recorded state
         * @return {@code false} if the number of slices differs, if the slice order is unchanged,
         *         or if any slice holds an element different from its first one; {@code true} otherwise
         */
        public boolean compareSlice(INDArray data) {
            float[] newMap = measureState(data);

            if (newMap.length != map.length) {
                System.out.println("Different map lengths");
                return false;
            }

            // An identical map means the shuffle did not move anything.
            if (Arrays.equals(map, newMap)) {
                System.out.println("Maps are equal");
                return false;
            }

            for (int x = 0; x < data.shape()[0]; x++) {
                INDArray slice = data.slice(x);

                for (int y = 0; y < slice.length(); y++) {
                    if (Math.abs(slice.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) {
                        System.out.print("Different data in a row");
                        return false;
                    }
                }
            }

            return true;
        }
    }

    /**
     * Records the first element of every row of a matrix at construction time so that a later
     * state of the same matrix can be compared against it.
     */
    public static class OrderScanner2D {
        private float[] map;

        /**
         * @param data matrix whose current row order is recorded
         */
        public OrderScanner2D(INDArray data) {
            map = measureState(data);
        }

        /**
         * Captures the first element of each row.
         *
         * @param data matrix to inspect
         * @return one value per row, in row order
         */
        public float[] measureState(INDArray data) {
            float[] result = new float[data.rows()];

            for (int x = 0; x < data.rows(); x++) {
                result[x] = data.getRow(x).getFloat(0);
            }

            return result;
        }

        /**
         * Checks that the rows of {@code newData} have been reordered relative to the recorded
         * state and that every row is still filled with a single value.
         *
         * @param newData matrix to compare against the recorded state
         * @return {@code false} if the row count differs, if the row order is unchanged, or if
         *         any row holds an element different from its first one; {@code true} otherwise
         */
        public boolean compareRow(INDArray newData) {
            float[] newMap = measureState(newData);

            if (newMap.length != map.length) {
                System.out.println("Different map lengths");
                return false;
            }

            // An identical map means the shuffle did not move anything.
            if (Arrays.equals(map, newMap)) {
                System.out.println("Maps are equal");
                return false;
            }

            for (int x = 0; x < newData.rows(); x++) {
                INDArray row = newData.getRow(x);

                for (int y = 0; y < row.lengthLong(); y++) {
                    if (Math.abs(row.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) {
                        System.out.print("Different data in a row");
                        return false;
                    }
                }
            }

            return true;
        }

        /**
         * Checks that the first-element-per-row map of {@code newData} differs from the recorded
         * one and that every column is still filled with a single value.
         *
         * @param newData matrix to compare against the recorded state
         * @return {@code false} if the row count differs, if the recorded map is unchanged, or if
         *         any column holds an element different from its first one; {@code true} otherwise
         */
        public boolean compareColumn(INDArray newData) {
            float[] newMap = measureState(newData);

            if (newMap.length != map.length) {
                System.out.println("Different map lengths");
                return false;
            }

            // An identical map means the first column is unchanged.
            if (Arrays.equals(map, newMap)) {
                System.out.println("Maps are equal");
                return false;
            }

            // Iterate over columns (not rows) so that non-square matrices are handled correctly.
            for (int x = 0; x < newData.columns(); x++) {
                INDArray column = newData.getColumn(x);
                double val = column.getDouble(0);

                for (int y = 0; y < column.lengthLong(); y++) {
                    if (Math.abs(column.getFloat(y) - val) > Nd4j.EPS_THRESHOLD) {
                        System.out.print("Different data in a column: " + column.getFloat(y));
                        return false;
                    }
                }
            }

            return true;
        }

        /**
         * @return the first-element-per-row values recorded at construction time
         */
        public float[] getMap() {
            return map;
        }
    }
}