package org.nd4j.linalg.indexing; import lombok.extern.slf4j.Slf4j; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com * @author Ede Meijer */ @Slf4j @RunWith(Parameterized.class) public class TransformsTest extends BaseNd4jTest { public TransformsTest(Nd4jBackend backend) { super(backend); } @Test public void testEq1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new double[] {0, 0, 1, 0}); INDArray z = x.eq(2); assertEquals(exp, z); } @Test public void testEq2() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new double[] {0, 0, 1, 0}); x.eqi(2); assertEquals(exp, x); } @Test public void testNEq1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new double[] {1, 0, 1, 0}); INDArray z = x.neq(1); assertEquals(exp, z); } @Test public void testLT1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new double[] {1, 1, 0, 1}); INDArray z = x.lt(2); assertEquals(exp, z); } @Test public void testLTE1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new double[] {1, 1, 1, 1}); x.ltei(2); assertEquals(exp, x); } @Test public void testGT1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 4}); INDArray exp = Nd4j.create(new double[] {0, 0, 1, 1}); INDArray z = x.gt(1); assertEquals(exp, z); } @Test public void testGTE1() { INDArray x = Nd4j.create(new double[] {0, 1, 2, 4}); INDArray exp = Nd4j.create(new double[] {0, 0, 1, 1}); x.gtei(2); assertEquals(exp, x); } @Test public void testScalarMinMax1() { INDArray x = Nd4j.create(new double[] {1, 3, 5, 7}); INDArray xCopy = x.dup(); INDArray exp1 = Nd4j.create(new double[] {1, 3, 5, 7}); INDArray exp2 = Nd4j.create(new double[] {1e-5, 1e-5, 1e-5, 1e-5}); INDArray z1 = Transforms.max(x, Nd4j.EPS_THRESHOLD, true); INDArray z2 = Transforms.min(x, Nd4j.EPS_THRESHOLD, true); assertEquals(exp1, z1); assertEquals(exp2, z2); // Assert that x was not modified assertEquals(x, xCopy); INDArray exp3 = Nd4j.create(new double[] {10, 10, 10, 10}); Transforms.max(x, 10, false); assertEquals(x, exp3); Transforms.min(x, Nd4j.EPS_THRESHOLD, false); assertEquals(x, exp2); } @Test public void testArrayMinMax() { INDArray x = Nd4j.create(new double[] {1, 3, 5, 7}); INDArray y = Nd4j.create(new double[] {2, 2, 6, 6}); INDArray xCopy = x.dup(); INDArray yCopy = y.dup(); INDArray expMax = Nd4j.create(new double[] {2, 3, 6, 7}); INDArray expMin = Nd4j.create(new double[] {1, 2, 5, 6}); INDArray z1 = Transforms.max(x, y, true); INDArray z2 = Transforms.min(x, y, true); assertEquals(expMax, z1); assertEquals(expMin, z2); // Assert that x was not modified assertEquals(xCopy, x); Transforms.max(x, y, false); // Assert that x was modified assertEquals(expMax, x); // Assert that y was not modified assertEquals(yCopy, y); // Reset the modified x x = xCopy.dup(); Transforms.min(x, y, false); // Assert that X was modified assertEquals(expMin, x); // Assert that y was not modified assertEquals(yCopy, y); } @Test public void testAnd1() { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); INDArray z = Transforms.and(x, y); assertEquals(x, z); } @Test public void testOr1() { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); INDArray z = Transforms.or(x, y); assertEquals(y, z); } @Test public void testXor1() { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 0, 0, 1, 0}); INDArray z = Transforms.xor(x, y); assertEquals(exp, z); } @Test public void testNot1() { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray exp = Nd4j.create(new double[] {1, 1, 0, 1, 1}); INDArray z = Transforms.not(x); assertEquals(exp, z); } @Override public char ordering() { return 'c'; } }