package org.nd4j.linalg.util;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
 * Tests for {@link ArrayUtil}: flattening of nested primitive arrays, shape inference of nested
 * Java arrays, and {@code argMinOfMax}.
 */
public class TestArrayUtils extends BaseNd4jTest {

    public TestArrayUtils(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Checks that {@code ArrayUtil.flattenDoubleArray} handles an empty array, and that it flattens
     * 1- to 4-dimensional double arrays with the last index varying fastest (the order in which
     * the expected arrays below are filled).
     */
    @Test
    public void testFlattenDoubleArray() {
        assertArrayEquals(new double[0], ArrayUtil.flattenDoubleArray(new double[0]), 0.0);
        Random r = new Random(12345L);

        // 1-D: flattening should be the identity.
        double[] d1 = new double[10];
        for (int i = 0; i < d1.length; i++) {
            d1[i] = r.nextDouble();
        }
        assertArrayEquals(d1, ArrayUtil.flattenDoubleArray(d1), 0.0);

        // 2-D: cross-check against the dedicated 2-D flatten helper.
        double[][] d2 = new double[5][10];
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 10; j++) {
                d2[i][j] = r.nextDouble();
            }
        }
        assertArrayEquals(ArrayUtil.flatten(d2), ArrayUtil.flattenDoubleArray(d2), 0.0);

        // 3-D: build the expected flat array in the same nested-loop order.
        double[][][] d3 = new double[5][10][15];
        double[] exp3 = new double[5 * 10 * 15];
        int c = 0;
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 10; j++) {
                for (int k = 0; k < 15; k++) {
                    double d = r.nextDouble();
                    exp3[c++] = d;
                    d3[i][j][k] = d;
                }
            }
        }
        assertArrayEquals(exp3, ArrayUtil.flattenDoubleArray(d3), 0.0);

        // 4-D: same approach as 3-D.
        double[][][][] d4 = new double[3][5][7][9];
        double[] exp4 = new double[3 * 5 * 7 * 9];
        c = 0;
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 5; j++) {
                for (int k = 0; k < 7; k++) {
                    for (int l = 0; l < 9; l++) {
                        double d = r.nextDouble();
                        exp4[c++] = d;
                        d4[i][j][k][l] = d;
                    }
                }
            }
        }
        assertArrayEquals(exp4, ArrayUtil.flattenDoubleArray(d4), 0.0);
    }

    /**
     * Float counterpart of {@link #testFlattenDoubleArray()}. It exercises
     * {@code ArrayUtil.flattenFloatArray} on empty and 1- to 4-dimensional float arrays.
     */
    @Test
    public void testFlattenFloatArray() {
        assertArrayEquals(new float[0], ArrayUtil.flattenFloatArray(new float[0]), 0.0f);
        Random r = new Random(12345L);

        // 1-D: flattening should be the identity.
        float[] f1 = new float[10];
        for (int i = 0; i < f1.length; i++) {
            f1[i] = r.nextFloat();
        }
        assertArrayEquals(f1, ArrayUtil.flattenFloatArray(f1), 0.0f);

        // 2-D: cross-check against the dedicated 2-D flatten helper.
        float[][] f2 = new float[5][10];
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 10; j++) {
                f2[i][j] = r.nextFloat();
            }
        }
        assertArrayEquals(ArrayUtil.flatten(f2), ArrayUtil.flattenFloatArray(f2), 0.0f);

        // 3-D: build the expected flat array in the same nested-loop order.
        float[][][] f3 = new float[5][10][15];
        float[] exp3 = new float[5 * 10 * 15];
        int c = 0;
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 10; j++) {
                for (int k = 0; k < 15; k++) {
                    float d = r.nextFloat();
                    exp3[c++] = d;
                    f3[i][j][k] = d;
                }
            }
        }
        assertArrayEquals(exp3, ArrayUtil.flattenFloatArray(f3), 0.0f);

        // 4-D: same approach as 3-D.
        float[][][][] f4 = new float[3][5][7][9];
        float[] exp4 = new float[3 * 5 * 7 * 9];
        c = 0;
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 5; j++) {
                for (int k = 0; k < 7; k++) {
                    for (int l = 0; l < 9; l++) {
                        float d = r.nextFloat();
                        exp4[c++] = d;
                        f4[i][j][k][l] = d;
                    }
                }
            }
        }
        assertArrayEquals(exp4, ArrayUtil.flattenFloatArray(f4), 0.0f);
    }

    /**
     * Checks that {@code ArrayUtil.arrayShape} reports the length of each nesting level. This
     * includes zero-length dimensions and arrays of primitive, char and reference element types.
     */
    @Test
    public void testArrayShape() {
        // JUnit convention: expected value first, actual value second.
        assertArrayEquals(new int[] {0}, ArrayUtil.arrayShape(new int[0]));
        assertArrayEquals(new int[] {5, 7, 9}, ArrayUtil.arrayShape(new int[5][7][9]));
        assertArrayEquals(new int[] {2, 3, 4, 5, 6}, ArrayUtil.arrayShape(new Object[2][3][4][5][6]));
        assertArrayEquals(new int[] {9, 7, 5, 3}, ArrayUtil.arrayShape(new double[9][7][5][3]));
        assertArrayEquals(new int[] {1, 1, 1, 0}, ArrayUtil.arrayShape(new double[1][1][1][0]));
        assertArrayEquals(new int[] {3, 2, 1}, ArrayUtil.arrayShape(new char[3][2][1]));
        assertArrayEquals(new int[] {3, 2, 1}, ArrayUtil.arrayShape(new String[3][2][1]));
    }

    /**
     * Checks {@code ArrayUtil.argMinOfMax} with two and three input arrays.
     *
     * <p>NOTE(review): the expected indices match "index of the smallest element-wise maximum".
     * For example, max(first, second) = {4, 6, 3, 4} has its minimum at index 2. That reading is
     * inferred from the method name and these values; confirm against ArrayUtil's docs.
     */
    @Test
    public void testArgMinOfMaxMethods() {
        int[] first = {1, 5, 2, 4};
        int[] second = {4, 6, 3, 2};
        assertEquals(2, ArrayUtil.argMinOfMax(first, second));
        int[] third = {7, 3, 8, 10};
        assertEquals(1, ArrayUtil.argMinOfMax(first, second, third));
    }

    @Override
    public char ordering() {
        return 'c';
    }
}