package org.nd4j.linalg.shape.concat; import lombok.extern.slf4j.Slf4j; import org.apache.commons.math3.util.Pair; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Adam Gibson */ @Slf4j @RunWith(Parameterized.class) public class ConcatTests extends BaseNd4jTest { public ConcatTests(Nd4jBackend backend) { super(backend); } @Test public void testConcat() { INDArray A = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2); INDArray B = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2); INDArray concat = Nd4j.concat(0, A, B); assertTrue(Arrays.equals(new int[] {5, 2, 2}, concat.shape())); } @Test public void testConcatHorizontally() { INDArray rowVector = Nd4j.ones(5); INDArray other = Nd4j.ones(5); INDArray concat = Nd4j.hstack(other, rowVector); assertEquals(rowVector.rows(), concat.rows()); assertEquals(rowVector.columns() * 2, concat.columns()); } @Test public void testVStackColumn() { INDArray linspaced = Nd4j.linspace(1, 3, 3).reshape(3, 1); INDArray stacked = linspaced.dup(); INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 1, 2, 3}, new int[] {6, 1}); INDArray test = Nd4j.vstack(linspaced, stacked); assertEquals(assertion, test); } @Test public void testConcatScalars() { INDArray first = Nd4j.arange(0, 1).reshape(1, 1); INDArray second = Nd4j.arange(0, 1).reshape(1, 1); INDArray firstRet = Nd4j.concat(0, first, second); assertTrue(firstRet.isColumnVector()); INDArray secondRet = Nd4j.concat(1, first, second); assertTrue(secondRet.isRowVector()); } @Test public void testConcatMatrices() { INDArray a = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray b = a.dup(); INDArray concat1 = Nd4j.concat(1, a, b); INDArray oneAssertion = Nd4j.create(new double[][] {{1, 3, 1, 3}, {2, 4, 2, 4}}); assertEquals(oneAssertion, concat1); INDArray concat = Nd4j.concat(0, a, b); INDArray zeroAssertion = Nd4j.create(new double[][] {{1, 3}, {2, 4}, {1, 3}, {2, 4}}); assertEquals(zeroAssertion, concat); } @Test public void testConcatRowVectors() { INDArray rowVector = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {1, 6}); INDArray matrix = Nd4j.create(new double[] {7, 8, 9, 10, 11, 12}, new int[] {1, 6}); INDArray assertion1 = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[] {1, 12}); INDArray assertion0 = Nd4j.create(new double[][] {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}); // INDArray concat1 = Nd4j.hstack(rowVector, matrix); INDArray concat0 = Nd4j.vstack(rowVector, matrix); // assertEquals(assertion1, concat1); assertEquals(assertion0, concat0); } @Test public void testConcat3d() { INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 36, 12).reshape('c', 1, 3, 4); INDArray third = Nd4j.linspace(36, 48, 12).reshape('c', 1, 3, 4); //Concat, dim 0 INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all()}, third); INDArray concat0 = Nd4j.concat(0, first, second, third); assertEquals(exp, concat0); System.out.println("1------------------------"); //Concat, dim 1 second = Nd4j.linspace(24, 32, 8).reshape('c', 2, 1, 4); third = Nd4j.linspace(32, 48, 16).reshape('c', 2, 2, 4); exp = Nd4j.create(2, 3 + 1 + 2, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third); System.out.println("2------------------------"); INDArray concat1 = Nd4j.concat(1, first, second, third); assertEquals(exp, concat1); //Concat, dim 2 second = Nd4j.linspace(24, 36, 12).reshape('c', 2, 3, 2); third = Nd4j.linspace(36, 42, 6).reshape('c', 2, 3, 1); exp = Nd4j.create(2, 3, 4 + 2 + 1); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6)}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6)}, third); INDArray concat2 = Nd4j.concat(2, first, second, third); assertEquals(exp, concat2); } @Test @Ignore public void testConcat3dv2() { INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 35, 12).reshape('c', 1, 3, 4); INDArray third = Nd4j.linspace(36, 47, 12).reshape('c', 1, 3, 4); //Concat, dim 0 INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all()}, third); List<Pair<INDArray, String>> firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4); List<Pair<INDArray, String>> seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4); List<Pair<INDArray, String>> thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4); for (Pair<INDArray, String> f : firsts) { for (Pair<INDArray, String> s : seconds) { for (Pair<INDArray, String> t : thirds) { INDArray f2 = f.getFirst().assign(first); INDArray s2 = s.getFirst().assign(second); INDArray t2 = t.getFirst().assign(third); System.out.println("-------------------------------------------"); INDArray concat0 = Nd4j.concat(0, f2, s2, t2); assertEquals(exp, concat0); } } } //Concat, dim 1 second = Nd4j.linspace(24, 31, 8).reshape('c', 2, 1, 4); third = Nd4j.linspace(32, 47, 16).reshape('c', 2, 2, 4); exp = Nd4j.create(2, 3 + 1 + 2, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third); firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4); seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 1, 4); thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 2, 4); for (Pair<INDArray, String> f : firsts) { for (Pair<INDArray, String> s : seconds) { for (Pair<INDArray, String> t : thirds) { INDArray f2 = f.getFirst().assign(first); INDArray s2 = s.getFirst().assign(second); INDArray t2 = t.getFirst().assign(third); INDArray concat1 = Nd4j.concat(1, f2, s2, t2); assertEquals(exp, concat1); } } } //Concat, dim 2 second = Nd4j.linspace(24, 35, 12).reshape('c', 2, 3, 2); third = Nd4j.linspace(36, 41, 6).reshape('c', 2, 3, 1); exp = Nd4j.create(2, 3, 4 + 2 + 1); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6)}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6)}, third); firsts = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4); seconds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 2); thirds = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 1); for (Pair<INDArray, String> f : firsts) { for (Pair<INDArray, String> s : seconds) { for (Pair<INDArray, String> t : thirds) { INDArray f2 = f.getFirst().assign(first); INDArray s2 = s.getFirst().assign(second); INDArray t2 = t.getFirst().assign(third); INDArray concat2 = Nd4j.concat(2, f2, s2, t2); assertEquals(exp, concat2); } } } } @Override public char ordering() { return 'f'; } }