package org.nd4j.linalg.shape.concat;
import org.apache.commons.math3.util.Pair;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class ConcatTestsC extends BaseNd4jTest {
public ConcatTestsC(Nd4jBackend backend) {
super(backend);
}
@Test
public void testConcatVertically() {
INDArray rowVector = Nd4j.ones(5);
INDArray other = Nd4j.ones(5);
INDArray concat = Nd4j.vstack(other, rowVector);
assertEquals(rowVector.rows() * 2, concat.rows());
assertEquals(rowVector.columns(), concat.columns());
INDArray arr2 = Nd4j.create(5, 5);
INDArray slice1 = arr2.slice(0);
INDArray slice2 = arr2.slice(1);
INDArray arr3 = Nd4j.create(2, 5);
INDArray vstack = Nd4j.vstack(slice1, slice2);
assertEquals(arr3, vstack);
INDArray col1 = arr2.getColumn(0);
INDArray col2 = arr2.getColumn(1);
INDArray vstacked = Nd4j.vstack(col1, col2);
assertEquals(Nd4j.create(10, 1), vstacked);
}
@Test
public void testConcatScalars() {
INDArray first = Nd4j.arange(0, 1).reshape(1, 1);
INDArray second = Nd4j.arange(0, 1).reshape(1, 1);
INDArray firstRet = Nd4j.concat(0, first, second);
assertTrue(firstRet.isColumnVector());
INDArray secondRet = Nd4j.concat(1, first, second);
assertTrue(secondRet.isRowVector());
}
@Test
public void testConcatScalars1() {
INDArray first = Nd4j.scalar(1);
INDArray second = Nd4j.scalar(2);
INDArray third = Nd4j.scalar(3);
INDArray result = Nd4j.concat(0, first, second, third);
assertEquals(1f, result.getFloat(0), 0.01f);
assertEquals(2f, result.getFloat(1), 0.01f);
assertEquals(3f, result.getFloat(2), 0.01f);
}
@Test
public void testConcatVectors1() {
INDArray first = Nd4j.ones(10);
INDArray second = Nd4j.ones(10);
INDArray third = Nd4j.ones(10);
INDArray result = Nd4j.concat(0, first, second, third);
assertEquals(3, result.rows());
assertEquals(10, result.columns());
System.out.println(result);
for (int x = 0; x < 30; x++) {
assertEquals(1f, result.getFloat(x), 0.001f);
}
}
@Test
public void testConcatMatrices() {
INDArray a = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArray b = a.dup();
INDArray concat1 = Nd4j.concat(1, a, b);
INDArray oneAssertion = Nd4j.create(new double[][] {{1, 2, 1, 2}, {3, 4, 3, 4}});
System.out.println("Assertion: " + Arrays.toString(oneAssertion.data().asFloat()));
System.out.println("Result: " + Arrays.toString(concat1.data().asFloat()));
assertEquals(oneAssertion, concat1);
INDArray concat = Nd4j.concat(0, a, b);
INDArray zeroAssertion = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {1, 2}, {3, 4}});
assertEquals(zeroAssertion, concat);
}
@Test
public void testAssign() {
INDArray vector = Nd4j.linspace(1, 5, 5);
vector.assign(1);
assertEquals(Nd4j.ones(5), vector);
INDArray twos = Nd4j.ones(2, 2);
INDArray rand = Nd4j.rand(2, 2);
twos.assign(rand);
assertEquals(rand, twos);
INDArray tensor = Nd4j.rand((long) 3, 3, 3, 3);
INDArray ones = Nd4j.ones(3, 3, 3);
assertTrue(Arrays.equals(tensor.shape(), ones.shape()));
ones.assign(tensor);
assertEquals(tensor, ones);
}
@Test
public void testConcatRowVectors() {
INDArray rowVector = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {1, 6});
INDArray matrix = Nd4j.create(new double[] {7, 8, 9, 10, 11, 12}, new int[] {1, 6});
INDArray assertion1 = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[] {1, 12});
INDArray assertion0 = Nd4j.create(new double[][] {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}});
INDArray concat1 = Nd4j.hstack(rowVector, matrix);
INDArray concat0 = Nd4j.vstack(rowVector, matrix);
assertEquals(assertion1, concat1);
assertEquals(assertion0, concat0);
}
@Test
public void testConcat3d() {
INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4);
INDArray second = Nd4j.linspace(24, 36, 12).reshape('c', 1, 3, 4);
INDArray third = Nd4j.linspace(36, 48, 12).reshape('c', 1, 3, 4);
//Concat, dim 0
INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4);
exp.put(new INDArrayIndex[] {NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all()}, third);
INDArray concat0 = Nd4j.concat(0, first, second, third);
assertEquals(exp, concat0);
//Concat, dim 1
second = Nd4j.linspace(24, 32, 8).reshape('c', 2, 1, 4);
for (int i = 0; i < second.tensorssAlongDimension(1); i++) {
INDArray secondTad = second.javaTensorAlongDimension(i, 1);
System.out.println(second.tensorAlongDimension(i, 1));
}
third = Nd4j.linspace(32, 48, 16).reshape('c', 2, 2, 4);
exp = Nd4j.create(2, 3 + 1 + 2, 4);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all()}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third);
INDArray concat1 = Nd4j.concat(1, first, second, third);
assertEquals(exp, concat1);
//Concat, dim 2
second = Nd4j.linspace(24, 36, 12).reshape('c', 2, 3, 2);
third = Nd4j.linspace(36, 42, 6).reshape('c', 2, 3, 1);
exp = Nd4j.create(2, 3, 4 + 2 + 1);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6)}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6)}, third);
INDArray concat2 = Nd4j.concat(2, first, second, third);
assertEquals(exp, concat2);
}
@Test(expected = IllegalArgumentException.class)
public void testConcatVector() {
System.out.println(Nd4j.concat(0, Nd4j.ones(1000000), Nd4j.create(1)));
}
@Test
@Ignore
public void testConcat3dv2() {
INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4);
INDArray second = Nd4j.linspace(24, 35, 12).reshape('c', 1, 3, 4);
INDArray third = Nd4j.linspace(36, 47, 12).reshape('c', 1, 3, 4);
//Concat, dim 0
INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4);
exp.put(new INDArrayIndex[] {NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all()}, third);
List<Pair<INDArray, String>> firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4);
List<Pair<INDArray, String>> seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4);
List<Pair<INDArray, String>> thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4);
for (Pair<INDArray, String> f : firsts) {
for (Pair<INDArray, String> s : seconds) {
for (Pair<INDArray, String> t : thirds) {
INDArray f2 = f.getFirst().assign(first);
INDArray s2 = s.getFirst().assign(second);
INDArray t2 = t.getFirst().assign(third);
INDArray concat0 = Nd4j.concat(0, f2, s2, t2);
if (!exp.equals(concat0)) {
concat0 = Nd4j.concat(0, f2, s2, t2);
}
assertEquals(exp, concat0);
}
}
}
//Concat, dim 1
second = Nd4j.linspace(24, 31, 8).reshape('c', 2, 1, 4);
third = Nd4j.linspace(32, 47, 16).reshape('c', 2, 2, 4);
exp = Nd4j.create(2, 3 + 1 + 2, 4);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all()}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third);
firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4);
seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 1, 4);
thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 2, 4);
for (Pair<INDArray, String> f : firsts) {
for (Pair<INDArray, String> s : seconds) {
for (Pair<INDArray, String> t : thirds) {
INDArray f2 = f.getFirst().assign(first);
INDArray s2 = s.getFirst().assign(second);
INDArray t2 = t.getFirst().assign(third);
INDArray concat1 = Nd4j.concat(1, f2, s2, t2);
assertEquals(exp, concat1);
}
}
}
//Concat, dim 2
second = Nd4j.linspace(24, 35, 12).reshape('c', 2, 3, 2);
third = Nd4j.linspace(36, 41, 6).reshape('c', 2, 3, 1);
exp = Nd4j.create(2, 3, 4 + 2 + 1);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, first);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6)}, second);
exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6)}, third);
firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4);
seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 2);
thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 1);
for (Pair<INDArray, String> f : firsts) {
for (Pair<INDArray, String> s : seconds) {
for (Pair<INDArray, String> t : thirds) {
INDArray f2 = f.getFirst().assign(first);
INDArray s2 = s.getFirst().assign(second);
INDArray t2 = t.getFirst().assign(third);
INDArray concat2 = Nd4j.concat(2, f2, s2, t2);
assertEquals(exp, concat2);
}
}
}
}
@Override
public char ordering() {
return 'c';
}
}