package org.nd4j.linalg.indexing;
import com.google.common.base.Function;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import java.util.Arrays;
import static org.junit.Assert.*;
/**
 * Tests for {@link BooleanIndexing} and the comparison ops used alongside it
 * ({@link CompareAndSet}, {@link CompareAndReplace}, {@link MatchCondition}).
 *
 * <p>Covers whole-array and per-dimension {@code and}/{@code or} checks on 1D and 2D arrays,
 * conditional assignment via {@code applyWhere}/{@code replaceWhere}, and first/last index lookups.
 *
 * @author raver119@gmail.com
 */
@RunWith(Parameterized.class)
public class BooleanIndexingTest extends BaseNd4jTest {

    public BooleanIndexingTest(Nd4jBackend backend) {
        super(backend);
    }

    /*
     * 1D array checks
     */

    /** Every element of 1..5 is greater than 0.5. */
    @Test
    public void testAnd1() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertTrue(BooleanIndexing.and(array, Conditions.greaterThan(0.5f)));
    }

    /** Every element of 1..5 is less than 6. */
    @Test
    public void testAnd2() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertTrue(BooleanIndexing.and(array, Conditions.lessThan(6.0f)));
    }

    /** The element 5 is not strictly less than 5, so {@code and} must be false. */
    @Test
    public void testAnd3() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertFalse(BooleanIndexing.and(array, Conditions.lessThan(5.0f)));
    }

    /** Only the element 5 exceeds 4, so {@code and} must be false. */
    @Test
    public void testAnd4() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertFalse(BooleanIndexing.and(array, Conditions.greaterThan(4.0f)));
    }

    /** Inclusive comparison: every element equals the bound, so {@code >=} holds for all. */
    @Test
    public void testAnd5() throws Exception {
        INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f});

        // Use the same factory name as the rest of this class (greaterThanOrEqual).
        assertTrue(BooleanIndexing.and(array, Conditions.greaterThanOrEqual(1e-5f)));
    }

    /** Strict comparison: no element is strictly below the bound it equals. */
    @Test
    public void testAnd6() throws Exception {
        INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f});

        assertFalse(BooleanIndexing.and(array, Conditions.lessThan(1e-5f)));
    }

    /** Equality against a small float constant holds for every element. */
    @Test
    public void testAnd7() throws Exception {
        INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f});

        assertTrue(BooleanIndexing.and(array, Conditions.equals(1e-5f)));
    }

    /** At least one element (4 or 5) is greater than 3. */
    @Test
    public void testOr1() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertTrue(BooleanIndexing.or(array, Conditions.greaterThan(3.0f)));
    }

    /** At least one element (1 or 2) is less than 3. */
    @Test
    public void testOr2() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertTrue(BooleanIndexing.or(array, Conditions.lessThan(3.0f)));
    }

    /** No element exceeds 6, so {@code or} must be false. */
    @Test
    public void testOr3() throws Exception {
        INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});

        assertFalse(BooleanIndexing.or(array, Conditions.greaterThan(6.0f)));
    }

    /** Values below {@code Nd4j.EPS_THRESHOLD} are replaced by that threshold. */
    @Test
    public void testApplyWhere1() throws Exception {
        INDArray array = Nd4j.create(new float[] {-1f, -1f, -1f, -1f, -1f});

        BooleanIndexing.applyWhere(array, Conditions.lessThan(Nd4j.EPS_THRESHOLD), new Value(Nd4j.EPS_THRESHOLD));

        assertTrue(BooleanIndexing.and(array, Conditions.equals(Nd4j.EPS_THRESHOLD)));
    }

    /** All zeros are below 1 and therefore all become 1. */
    @Test
    public void testApplyWhere2() throws Exception {
        INDArray array = Nd4j.create(new float[] {0f, 0f, 0f, 0f, 0f});

        BooleanIndexing.applyWhere(array, Conditions.lessThan(1.0f), new Value(1.0f));

        assertTrue(BooleanIndexing.and(array, Conditions.equals(1.0f)));
    }

    /** Tiny values (1e-18) are lifted to the 1e-12 floor. */
    @Test
    public void testApplyWhere3() throws Exception {
        INDArray array = Nd4j.create(new float[] {1e-18f, 1e-18f, 1e-18f, 1e-18f, 1e-18f});

        BooleanIndexing.applyWhere(array, Conditions.lessThan(1e-12f), new Value(1e-12f));

        assertTrue(BooleanIndexing.and(array, Conditions.equals(1e-12f)));
    }

    /**
     * Two passes: first lift tiny values to 1e-12, then replace NaNs with 1e-16.
     * Afterwards no NaN remains and both replacement values are present.
     */
    @Test
    public void testApplyWhere4() throws Exception {
        INDArray array = Nd4j.create(new float[] {1e-18f, Float.NaN, 1e-18f, 1e-18f, 1e-18f});

        BooleanIndexing.applyWhere(array, Conditions.lessThan(1e-12f), new Value(1e-12f));
        BooleanIndexing.applyWhere(array, Conditions.isNan(), new Value(1e-16f));

        System.out.println("Array contains: " + Arrays.toString(array.data().asFloat()));

        assertFalse(BooleanIndexing.or(array, Conditions.isNan()));
        assertTrue(BooleanIndexing.or(array, Conditions.equals(1e-12f)));
        assertTrue(BooleanIndexing.or(array, Conditions.equals(1e-16f)));
    }

    /*
     * 2D array checks
     */

    /** A freshly created zeros matrix is all zeros. */
    @Test
    public void test2dAnd1() throws Exception {
        INDArray array = Nd4j.zeros(10, 10);

        assertTrue(BooleanIndexing.and(array, Conditions.equals(0f)));
    }

    /** Writing one non-zero through a slice view breaks the all-zeros property. */
    @Test
    public void test2dAnd2() throws Exception {
        INDArray array = Nd4j.zeros(10, 10);
        array.slice(4).putScalar(2, 1e-5f);

        System.out.println(array);

        assertFalse(BooleanIndexing.and(array, Conditions.equals(0f)));
    }

    /** With a single positive element, not every element is positive. */
    @Test
    public void test2dAnd3() throws Exception {
        INDArray array = Nd4j.zeros(10, 10);
        array.slice(4).putScalar(2, 1e-5f);

        assertFalse(BooleanIndexing.and(array, Conditions.greaterThan(0f)));
    }

    /** The single 1e-5 element is enough for {@code or} to find a value above 1e-6. */
    @Test
    public void test2dAnd4() throws Exception {
        INDArray array = Nd4j.zeros(10, 10);
        array.slice(4).putScalar(2, 1e-5f);

        assertTrue(BooleanIndexing.or(array, Conditions.greaterThan(1e-6f)));
    }

    /** Only the one small element is replaced; the remaining ones stay at 1.0. */
    @Test
    public void test2dApplyWhere1() throws Exception {
        INDArray array = Nd4j.ones(4, 4);
        array.slice(3).putScalar(2, 1e-5f);

        BooleanIndexing.applyWhere(array, Conditions.lessThan(1e-4f), new Value(1e-12f));

        assertTrue(BooleanIndexing.or(array, Conditions.equals(1e-12f)));
        assertTrue(BooleanIndexing.or(array, Conditions.equals(1.0f)));
        assertFalse(BooleanIndexing.and(array, Conditions.equals(1e-12f)));
    }

    /**
     * Writes a patch through a {@link SpecifiedIndex} on a row slice and checks that the parent
     * array changed.
     *
     * <p>The original note says this test fails because of how SpecifiedIndex is currently handled,
     * but the rest of that explanation was lost. NOTE(review): the test prints whether the
     * SpecifiedIndex sub-array is a view. That suggests the concern is view-vs-copy semantics of
     * {@code get}/{@code put} with SpecifiedIndex. TODO confirm.
     *
     * @throws Exception on any ND4J failure
     */
    @Test
    public void testSliceAssign1() throws Exception {
        INDArray array = Nd4j.zeros(4, 4);

        INDArray patch = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f});

        INDArray slice = array.slice(1);
        int[] idx = new int[] {0, 1, 3};
        INDArrayIndex[] range = new INDArrayIndex[] {new SpecifiedIndex(idx)};

        INDArray subarray = slice.get(range);

        System.out.println("Subarray: " + Arrays.toString(subarray.data().asFloat()) + " isView: " + subarray.isView());

        slice.put(range, patch);

        System.out.println("Array after being patched: " + Arrays.toString(array.data().asFloat()));

        assertFalse(BooleanIndexing.and(array, Conditions.equals(0f)));
    }

    /** Elements of array1 greater than 4 are taken from array2 at the same position. */
    @Test
    public void testConditionalAssign1() throws Exception {
        INDArray array1 = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7});
        INDArray array2 = Nd4j.create(new double[] {7, 6, 5, 4, 3, 2, 1});
        INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 3, 2, 1});

        BooleanIndexing.replaceWhere(array1, array2, Conditions.greaterThan(4));

        assertEquals(comp, array1);
    }

    /** Scalar CompareAndSet: the single 0 becomes 3. */
    @Test
    public void testCaSTransform1() throws Exception {
        INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndSet(array, 3, Conditions.equals(0)));

        assertEquals(comp, array);
    }

    /** Scalar CompareAndSet: every element below 2 becomes 3. */
    @Test
    public void testCaSTransform2() throws Exception {
        INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray comp = Nd4j.create(new double[] {3, 2, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndSet(array, 3.0, Conditions.lessThan(2)));

        assertEquals(comp, array);
    }

    /** Pairwise CompareAndSet with lessThan(5); the expected result equals the second operand. */
    @Test
    public void testCaSPairwiseTransform1() throws Exception {
        INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndSet(array, comp, Conditions.lessThan(5)));

        assertEquals(comp, array);
    }

    /** Pairwise CompareAndReplace with lessThan(1); only the 0 is replaced. */
    @Test
    public void testCaRPairwiseTransform1() throws Exception {
        INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));

        assertEquals(comp, array);
    }

    /**
     * Pairwise CompareAndSet with epsNotEquals(0). Position 3 keeps x's value 4; the expected
     * result implies the condition is evaluated against y (y[3] == 0).
     */
    @Test
    public void testCaSPairwiseTransform2() throws Exception {
        INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray y = Nd4j.create(new double[] {2, 4, 3, 0, 5});
        INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndSet(x, y, Conditions.epsNotEquals(0.0)));

        assertEquals(comp, x);
    }

    /**
     * Pairwise CompareAndReplace with epsNotEquals(0). Position 2 keeps x's value 0; the expected
     * result implies the condition is evaluated against x.
     */
    @Test
    public void testCaRPairwiseTransform2() throws Exception {
        INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
        INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));

        assertEquals(comp, x);
    }

    /**
     * NOTE: despite the "CaS" name, this test exercises {@link CompareAndReplace}, not
     * {@link CompareAndSet}. The expected value is specific to CompareAndReplace. Under the
     * semantics implied by testCaSPairwiseTransform2, CompareAndSet would leave x[1] == 2.
     */
    @Test
    public void testCaSPairwiseTransform3() throws Exception {
        INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
        INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));

        assertEquals(comp, x);
    }

    /** Pairwise CompareAndReplace with lessThan(2); x[0] and x[2] are replaced from y. */
    @Test
    public void testCaRPairwiseTransform3() throws Exception {
        INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5});
        INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
        INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5});

        Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));

        assertEquals(comp, x);
    }

    /** MatchCondition over all dimensions counts the five values below 5. */
    @Test
    public void testMatchConditionAllDimensions1() throws Exception {
        INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

        int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.lessThan(5)), Integer.MAX_VALUE)
                        .getDouble(0);

        assertEquals(5, val);
    }

    /** MatchCondition over all dimensions counts the single NaN. */
    @Test
    public void testMatchConditionAllDimensions2() throws Exception {
        INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9});

        int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.isNan()), Integer.MAX_VALUE)
                        .getDouble(0);

        assertEquals(1, val);
    }

    /** MatchCondition over all dimensions counts the single negative infinity. */
    @Test
    public void testMatchConditionAllDimensions3() throws Exception {
        INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, Double.NEGATIVE_INFINITY, 5, 6, 7, 8, 9});

        int val = (int) Nd4j.getExecutioner()
                        .exec(new MatchCondition(array, Conditions.isInfinite()), Integer.MAX_VALUE).getDouble(0);

        assertEquals(1, val);
    }

    /**
     * applyWhere with {@link AbsValueGreaterThan} and a clipping function must match a
     * hand-computed clamp of every value into [-threshold, threshold].
     */
    @Test
    public void testAbsValueGreaterThan() {
        final double threshold = 2;

        Condition absValueCondition = new AbsValueGreaterThan(threshold);
        Function<Number, Number> clipFn = new Function<Number, Number>() {
            @Override
            public Number apply(Number number) {
                System.out.println("Number: " + number.doubleValue());
                return (number.doubleValue() > threshold ? threshold : -threshold);
            }
        };

        Nd4j.getRandom().setSeed(12345);
        INDArray orig = Nd4j.rand(1, 20).muli(6).subi(3); //Random numbers: -3 to 3

        INDArray exp = orig.dup();
        INDArray after = orig.dup();

        // Reference result: clamp values that fall outside [-threshold, threshold].
        for (int i = 0; i < exp.length(); i++) {
            double d = exp.getDouble(i);
            if (d > threshold) {
                exp.putScalar(i, threshold);
            } else if (d < -threshold) {
                exp.putScalar(i, -threshold);
            }
        }

        BooleanIndexing.applyWhere(after, absValueCondition, clipFn);

        System.out.println(orig);
        System.out.println(exp);
        System.out.println(after);

        assertEquals(exp, after);
    }

    /** Per-row {@code and}: only the zeroed third row is entirely zero. */
    @Test
    public void testMatchConditionAlongDimension1() throws Exception {
        INDArray array = Nd4j.ones(3, 10);
        array.getRow(2).assign(0.0);

        boolean[] result = BooleanIndexing.and(array, Conditions.equals(0.0), 1);
        boolean[] comp = new boolean[] {false, false, true};

        System.out.println("Result: " + Arrays.toString(result));

        assertArrayEquals(comp, result);
    }

    /** Per-row {@code or}: only the third row contains values below 0.9. */
    @Test
    public void testMatchConditionAlongDimension2() throws Exception {
        INDArray array = Nd4j.ones(3, 10);
        array.getRow(2).assign(0.0).putScalar(0, 1.0);

        System.out.println("Array: " + array);

        boolean[] result = BooleanIndexing.or(array, Conditions.lessThan(0.9), 1);
        boolean[] comp = new boolean[] {false, false, true};

        System.out.println("Result: " + Arrays.toString(result));

        assertArrayEquals(comp, result);
    }

    /** Per-row {@code and}: no row is entirely negative. */
    @Test
    public void testMatchConditionAlongDimension3() throws Exception {
        INDArray array = Nd4j.ones(3, 10);
        array.getRow(2).assign(0.0).putScalar(0, 1.0);

        boolean[] result = BooleanIndexing.and(array, Conditions.lessThan(0.0), 1);
        boolean[] comp = new boolean[] {false, false, false};

        System.out.println("Result: " + Arrays.toString(result));

        assertArrayEquals(comp, result);
    }

    /**
     * Four-argument CompareAndSet. The expected result zeroes exactly the position where
     * {@code arr} (linspace -2..2) is 0.
     */
    @Test
    public void testConditionalUpdate() {
        INDArray arr = Nd4j.linspace(-2, 2, 5);
        INDArray ones = Nd4j.ones(5);
        INDArray exp = Nd4j.create(new double[] {1, 1, 0, 1, 1});

        Nd4j.getExecutioner().exec(new CompareAndSet(ones, arr, ones, Conditions.equals(0.0)));

        assertEquals(exp, ones);
    }

    /** The first value >= 3 is at index 2. */
    @Test
    public void testFirstIndex1() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
        INDArray result = BooleanIndexing.firstIndex(arr, Conditions.greaterThanOrEqual(3));

        assertEquals(2, result.getDouble(0), 0.0);
    }

    /** The first value < 3 is at index 0. */
    @Test
    public void testFirstIndex2() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
        INDArray result = BooleanIndexing.firstIndex(arr, Conditions.lessThan(3));

        assertEquals(0, result.getDouble(0), 0.0);
    }

    /** The last value >= 3 is 9, at index 8 (the trailing 0 does not match). */
    @Test
    public void testLastIndex1() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
        INDArray result = BooleanIndexing.lastIndex(arr, Conditions.greaterThanOrEqual(3));

        assertEquals(8, result.getDouble(0), 0.0);
    }

    /** Per-row first index of a value >= 2 in a 3x3 C-order matrix. */
    @Test
    public void testFirstIndex2D() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 3, 0, 1, 3, 7, 8, 9}).reshape('c', 3, 3);
        INDArray result = BooleanIndexing.firstIndex(arr, Conditions.greaterThanOrEqual(2), 1);
        INDArray exp = Nd4j.create(new double[] {1, 2, 0});

        assertEquals(exp, result);
    }

    /** Per-row last index of a value >= 2 in a 3x3 C-order matrix. */
    @Test
    public void testLastIndex2D() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 3, 0, 1, 3, 7, 8, 0}).reshape('c', 3, 3);
        INDArray result = BooleanIndexing.lastIndex(arr, Conditions.greaterThanOrEqual(2), 1);
        INDArray exp = Nd4j.create(new double[] {2, 2, 1});

        assertEquals(exp, result);
    }

    @Override
    public char ordering() {
        return 'c';
    }
}