package org.nd4j.linalg; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.fail; /** * Created by Alex on 05/08/2016. */ @RunWith(Parameterized.class) public class InputValidationTests extends BaseNd4jTest { public InputValidationTests(Nd4jBackend backend) { super(backend); } @Override public char ordering() { return 'c'; } ///////////////////////////////////////////////////////////// ///////////////////// Broadcast Tests /////////////////////// @Test public void testInvalidColVectorOp1() { INDArray first = Nd4j.create(10, 10); INDArray col = Nd4j.create(5, 1); try { first.muliColumnVector(col); fail("Should have thrown IllegalStateException"); } catch (IllegalStateException e) { //OK } } @Test public void testInvalidColVectorOp2() { INDArray first = Nd4j.create(10, 10); INDArray col = Nd4j.create(5, 1); try { first.addColumnVector(col); fail("Should have thrown IllegalStateException"); } catch (IllegalStateException e) { //OK } } @Test public void testInvalidRowVectorOp1() { INDArray first = Nd4j.create(10, 10); INDArray row = Nd4j.create(1, 5); try { first.addiRowVector(row); fail("Should have thrown IllegalStateException"); } catch (IllegalStateException e) { //OK } } @Test public void testInvalidRowVectorOp2() { INDArray first = Nd4j.create(10, 10); INDArray row = Nd4j.create(1, 5); try { first.subRowVector(row); fail("Should have thrown IllegalStateException"); } catch (IllegalStateException e) { //OK } } }