package org.nd4j.linalg.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class TestOpExecutionerUtil extends BaseNd4jTest {
public TestOpExecutionerUtil(Nd4jBackend backend) {
super(backend);
}
@Test
public void testCanDoDirectly() {
INDArray f1_100 = Nd4j.create(new int[] {1, 100}, 'f');
INDArray f100_1 = Nd4j.create(new int[] {100, 1}, 'f');
INDArray c1_100 = Nd4j.create(new int[] {1, 100}, 'c');
INDArray c100_1 = Nd4j.create(new int[] {100, 1}, 'c');
INDArray f100_100 = Nd4j.create(new int[] {100, 100}, 'f');
INDArray c100_100 = Nd4j.create(new int[] {100, 100}, 'c');
INDArray f20_20_20 = Nd4j.create(new int[] {20, 20, 20}, 'f');
INDArray c20_20_20 = Nd4j.create(new int[] {20, 20, 20}, 'c');
//Trivial cases that can obviously be done directly
assertTrue(OpExecutionerUtil.canDoOpDirectly(f1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f1_100, f1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_1, f100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_100, f100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f20_20_20, f20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c1_100, c1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_1, c100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_100, c100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c20_20_20, c20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f1_100, c1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_1, c100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f1_100, f1_100, f1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_1, f100_1, f100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_100, f100_100, f100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f20_20_20, f20_20_20, f20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c1_100, c1_100, c1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_1, c100_1, c100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_100, c100_100, c100_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c20_20_20, c20_20_20, c20_20_20));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f1_100, c1_100, c1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c1_100, c1_100, f1_100));
assertTrue(OpExecutionerUtil.canDoOpDirectly(f100_1, c100_1, c100_1));
assertTrue(OpExecutionerUtil.canDoOpDirectly(c100_1, c100_1, f100_1));
//Cases that we don't expect to be doable directly (elements don't line up in buffer)
assertFalse(OpExecutionerUtil.canDoOpDirectly(f100_100, c100_100));
assertFalse(OpExecutionerUtil.canDoOpDirectly(f20_20_20, c20_20_20));
assertFalse(OpExecutionerUtil.canDoOpDirectly(f100_100, c100_100, f100_100));
assertFalse(OpExecutionerUtil.canDoOpDirectly(c20_20_20, f20_20_20, f20_20_20));
assertFalse(OpExecutionerUtil.canDoOpDirectly(c100_100, c100_100, f100_100));
assertFalse(OpExecutionerUtil.canDoOpDirectly(f20_20_20, c20_20_20, c20_20_20));
}
@Test
public void testChooseElementWiseTensorDimension() {
INDArray f1_100 = Nd4j.create(new int[] {1, 100}, 'f');
INDArray f3_100 = Nd4j.create(new int[] {3, 100}, 'f');
INDArray f100_1 = Nd4j.create(new int[] {100, 1}, 'f');
INDArray f100_3 = Nd4j.create(new int[] {100, 3}, 'f');
INDArray c1_100 = Nd4j.create(new int[] {1, 100}, 'c');
INDArray c3_100 = Nd4j.create(new int[] {3, 100}, 'c');
INDArray c100_1 = Nd4j.create(new int[] {100, 1}, 'c');
INDArray c100_3 = Nd4j.create(new int[] {100, 3}, 'c');
//Test selection for row vectors and NDArrays that are nearly-row vectors
//In such cases, it is obvious which the best dimension is
//However, in other cases it is not immediately clear
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_3), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f1_100, f1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f3_100, f3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_1, f100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_3, f100_3), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f1_100, f1_100, f1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f3_100, f3_100, f3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_1, f100_1, f100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(f100_3, f100_3, f100_3), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_3), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c1_100, c1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c3_100, c3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_1, c100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_3, c100_3), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c1_100, c1_100, c1_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c3_100, c3_100, c3_100), 1);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_1, c100_1, c100_1), 0);
assertEquals(OpExecutionerUtil.chooseElementWiseTensorDimension(c100_3, c100_3, c100_3), 0);
}
@Override
public char ordering() {
return 'c';
}
}