/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * */ package org.nd4j.linalg.ops; import lombok.extern.slf4j.Slf4j; import org.apache.commons.math3.util.Pair; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.IndexAccumulation; import org.nd4j.linalg.api.ops.TadCollapseAccumulation; import org.nd4j.linalg.api.ops.exception.IllegalOpException; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.accum.*; import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance; import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; import org.nd4j.linalg.api.ops.impl.transforms.*; import 
org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.util.ArrayUtil; import java.util.*; import static org.junit.Assert.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; /** * Created by agibsonccc on 2/22/15. */ @Slf4j @RunWith(Parameterized.class) public class OpExecutionerTestsC extends BaseNd4jTest { public OpExecutionerTestsC(Nd4jBackend backend) { super(backend); } @Test public void testBroadcastMultiDim() { INDArray data = Nd4j.linspace(1, 30, 30).reshape(2, 3, 5); System.out.println(data); INDArray mask = Nd4j.create(new double[][] {{1.00, 1.00, 1.00, 1.00, 1.00}, {1.00, 1.00, 1.00, 0.00, 0.00}}); Nd4j.getExecutioner().exec(new BroadcastMulOp(data, mask, data, 0, 2)); INDArray assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 0.0, 0.0, 21.0, 22.0, 23.0, 0.0, 0.0, 26.0, 27.0, 28.0, 0.0, 0.0}).reshape(2, 3, 5); assertEquals(assertion, data); } @Test public void testCosineSimilarity() { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); double sim = Transforms.cosineSim(vec1, vec2); assertEquals(getFailureMessage(), 1, sim, 1e-1); } @Test public void testLog() { INDArray log = Nd4j.linspace(1, 6, 6); INDArray transformed = Transforms.log(log); INDArray assertion = Nd4j.create(new double[] {0., 0.69314718, 1.09861229, 1.38629436, 1.60943791, 1.79175947}); assertEquals(assertion, transformed); } @Test public void testNorm1AlongDimension() { INDArray arr = Nd4j.linspace(1, 8, 8).reshape(2, 4); 
INDArray arrNorm1 = arr.norm2(1); INDArray assertion = Nd4j.create(new double[] {5.47722558, 13.19090596}); assertEquals(assertion, arrNorm1); } @Test public void testEuclideanDistance() { INDArray arr = Nd4j.create(new double[] {55, 55}); INDArray arr2 = Nd4j.create(new double[] {60, 60}); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).currentResult() .doubleValue(); assertEquals(getFailureMessage(), 7.0710678118654755, result, 1e-1); } @Test public void testScalarMaxOp() { INDArray scalarMax = Nd4j.linspace(1, 6, 6).negi(); INDArray postMax = Nd4j.ones(6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); assertEquals(getFailureMessage(), scalarMax, postMax); } @Test public void testSetRange() { INDArray linspace = Nd4j.linspace(1, 4, 4); Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); for (int i = 0; i < linspace.length(); i++) { double val = linspace.getDouble(i); assertTrue(getFailureMessage(), val >= 0 && val <= 1); } INDArray linspace2 = Nd4j.linspace(1, 4, 4); Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4)); for (int i = 0; i < linspace2.length(); i++) { double val = linspace2.getDouble(i); assertTrue(getFailureMessage(), val >= 2 && val <= 4); } } @Test public void testNormMax() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).currentResult().doubleValue(); assertEquals(getFailureMessage(), 4, normMax, 1e-1); } @Test public void testNorm2() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).currentResult().doubleValue(); assertEquals(getFailureMessage(), 5.4772255750516612, norm2, 1e-1); } @Test public void testAdd() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(x, xDup, x)); assertEquals(getFailureMessage(), 
solution, x); } @Test public void testMul() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 1.0); opExecutioner.exec(new MulOp(x, xDup, x)); assertEquals(solution, x); } @Test public void testExecutioner() throws IllegalOpException { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(x, xDup, x)); assertEquals(getFailureMessage(), solution, x); Sum acc = new Sum(x.dup()); opExecutioner.exec(acc); assertEquals(getFailureMessage(), 10.0, acc.currentResult().doubleValue(), 1e-1); Prod prod = new Prod(x.dup()); opExecutioner.exec(prod); assertEquals(getFailureMessage(), 32.0, prod.currentResult().doubleValue(), 1e-1); } @Test public void testMaxMin() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5); Max max = new Max(x); opExecutioner.exec(max); assertEquals(5, max.currentResult().doubleValue(), 1e-1); Min min = new Min(x); opExecutioner.exec(min); assertEquals(1, min.currentResult().doubleValue(), 1e-1); } @Test public void testProd() { INDArray linspace = Nd4j.linspace(1, 6, 6); Prod prod = new Prod(linspace); double prod2 = Nd4j.getExecutioner().execAndReturn(prod).currentResult().doubleValue(); assertEquals(720, prod2, 1e-1); } @Test public void testSum() { INDArray linspace = Nd4j.linspace(1, 6, 6); Sum sum = new Sum(linspace); double sum2 = Nd4j.getExecutioner().execAndReturn(sum).getFinalResult().doubleValue(); assertEquals(21, sum2, 1e-1); INDArray matrixSums = linspace.reshape(2, 3); INDArray rowSums = matrixSums.sum(1); assertEquals(Nd4j.create(new double[] {6, 15}), rowSums); } @Test public void testDescriptiveStatsDouble() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5); Mean mean = new Mean(x); opExecutioner.exec(mean); assertEquals(3.0, 
mean.currentResult().doubleValue(), 1e-1); Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); assertEquals(getFailureMessage(), 2.5, variance.currentResult().doubleValue(), 1e-1); } @Test public void testDescriptiveStats() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5); Mean mean = new Mean(x); opExecutioner.exec(mean); assertEquals(getFailureMessage(), 3.0, mean.currentResult().doubleValue(), 1e-1); Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); assertEquals(getFailureMessage(), 2.5, variance.currentResult().doubleValue(), 1e-1); } @Test public void testRowSoftmax() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6); SoftMax softMax = new SoftMax(arr); opExecutioner.exec(softMax); assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); } @Test public void testAddiRowVector() { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray arr2 = Nd4j.linspace(1, 3, 3); INDArray assertion = Nd4j.create(new double[] {2, 4, 6, 5, 7, 9}).reshape(2, 3); INDArray test = arr.addRowVector(arr2); assertEquals(assertion, test); } @Test public void testTad() { INDArray arr = Nd4j.linspace(1, 12, 12).reshape(2, 3, 2); for (int i = 0; i < arr.tensorssAlongDimension(0); i++) { System.out.println(arr.tensorAlongDimension(i, 0)); } } @Test public void testPow() { INDArray oneThroughSix = Nd4j.linspace(1, 6, 6); Pow pow = new Pow(oneThroughSix, 2); Nd4j.getExecutioner().exec(pow); INDArray answer = Nd4j.create(new float[] {1, 4, 9, 16, 25, 36}); assertEquals(getFailureMessage(), answer, pow.z()); } @Test public void testComparisonOps() { INDArray linspace = Nd4j.linspace(1, 6, 6); INDArray ones = Nd4j.ones(6); INDArray zeros = Nd4j.zeros(6); assertEquals(ones, Nd4j.getExecutioner().execAndReturn(new ScalarGreaterThan(linspace, 0))); assertEquals(zeros, Nd4j.getExecutioner().execAndReturn(new 
ScalarGreaterThan(linspace, 7))); assertEquals(zeros, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(linspace, 0))); assertEquals(ones, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(linspace, 7))); } @Test public void testScalarArithmetic() { INDArray linspace = Nd4j.linspace(1, 6, 6); INDArray plusOne = Nd4j.linspace(2, 7, 6); Nd4j.getExecutioner().exec(new ScalarAdd(linspace, 1)); assertEquals(plusOne, linspace); } @Test public void testDimensionMax() { INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3); int axis = 0; INDArray row = linspace.slice(axis); Max max = new Max(row); double max2 = Nd4j.getExecutioner().execAndReturn(max).currentResult().doubleValue(); assertEquals(3.0, max2, 1e-1); Min min = new Min(row); double min2 = Nd4j.getExecutioner().execAndReturn(min).currentResult().doubleValue(); assertEquals(1.0, min2, 1e-1); Max matrixMax = new Max(linspace); INDArray exec2 = Nd4j.getExecutioner().exec(matrixMax, 1); assertEquals(Nd4j.create(new double[] {3, 6}), exec2); } @Test public void testStridedLog() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray slice = arr.slice(0); Log exp = new Log(slice); opExecutioner.exec(exp); INDArray assertion = Nd4j.create(Nd4j.createBuffer(new double[] {0.0, 0.6931471824645996, 1.0986123085021973})); assertEquals(getFailureMessage(), assertion, slice); } @Test public void testStridedExp() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray slice = arr.slice(0); float[] expected = new float[slice.length()]; for (int i = 0; i < slice.length(); i++) expected[i] = (float) Math.exp(slice.getDouble(i)); Exp exp = new Exp(slice); opExecutioner.exec(exp); assertEquals(getFailureMessage(), Nd4j.create(Nd4j.createBuffer(expected)), slice); } @Test public void testSoftMax() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6); SoftMax 
softMax = new SoftMax(arr); opExecutioner.exec(softMax); assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3); SoftMax softmax = new SoftMax(linspace.dup()); Nd4j.getExecutioner().exec(softmax); assertEquals(linspace.rows(), softmax.z().sumNumber().doubleValue(), 1e-1); } @Test public void testDimensionSoftMax() { INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3); SoftMax max = new SoftMax(linspace); Nd4j.getExecutioner().exec(max, 1); linspace.assign(max.z()); assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1); } @Test public void testColumnMean() { INDArray twoByThree = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray columnMean = twoByThree.mean(0); INDArray assertion = Nd4j.create(new float[] {2, 3}); assertEquals(assertion, columnMean); } @Test public void testColumnVar() { INDArray twoByThree = Nd4j.linspace(1, 600, 600).reshape(150, 4); INDArray columnStd = twoByThree.var(0); INDArray assertion = Nd4j.create(new float[] {30200f, 30200f, 30200f, 30200f}); assertEquals(assertion, columnStd); } @Test public void testColumnStd() { Nd4j.MAX_ELEMENTS_PER_SLICE = Integer.MAX_VALUE; Nd4j.MAX_SLICES_TO_PRINT = Integer.MAX_VALUE; INDArray twoByThree = Nd4j.linspace(1, 600, 600).reshape(150, 4); INDArray columnStd = twoByThree.std(0); INDArray assertion = Nd4j.create(new float[] {173.78147196982766f, 173.78147196982766f, 173.78147196982766f, 173.78147196982766f}); assertEquals(assertion, columnStd); } @Test public void testDim1() { INDArray sum = Nd4j.linspace(1, 2, 2).reshape(2, 1); INDArray same = sum.dup(); assertEquals(same.sum(1), sum); } @Test public void testIMax() { INDArray arr = Nd4j.linspace(1, 10, 10); IMax imax = new IMax(arr); assertEquals(9, ((IndexAccumulation) Nd4j.getExecutioner().exec(imax)).getFinalResult()); arr.muli(-1); imax = new IMax(arr); int maxIdx = ((IndexAccumulation) 
Nd4j.getExecutioner().exec(imax)).getFinalResult(); assertEquals(0, maxIdx); } @Test public void testIMin() { INDArray arr = Nd4j.linspace(1, 10, 10); IMin imin = new IMin(arr); assertEquals(0, ((IndexAccumulation) Nd4j.getExecutioner().exec(imin)).getFinalResult()); arr.muli(-1); imin = new IMin(arr); int minIdx = ((IndexAccumulation) Nd4j.getExecutioner().exec(imin)).getFinalResult(); assertEquals(9, minIdx); } @Test public void testMeanSumSimple() { System.out.println("3d"); INDArray arr = Nd4j.ones(1, 4, 4); assertEquals(Nd4j.ones(1), arr.mean(1, 2)); assertEquals(Nd4j.ones(1).muli(16), arr.sum(1, 2)); System.out.println("4d"); INDArray arr4 = Nd4j.ones(1, 1, 4, 4); INDArray arr4m = arr4.mean(2, 3); INDArray arr4s = arr4.sum(2, 3); for (int i = 0; i < arr4m.length(); i++) assertEquals(arr4m.getDouble(i), 1, 1e-1); for (int i = 0; i < arr4s.length(); i++) assertEquals(arr4s.getDouble(i), 16, 1e-1); System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); INDArray arr5s = arr5.sum(2, 3); for (int i = 0; i < arr5s.length(); i++) assertEquals(arr5s.getDouble(i), 16, 1e-1); INDArray arr5m = arr5.mean(2, 3); for (int i = 0; i < arr5m.length(); i++) assertEquals(1, arr5m.getDouble(i), 1e-1); System.out.println("6d"); INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6m = arr6.mean(2, 3); for (int i = 0; i < arr6m.length(); i++) assertEquals(arr6m.getDouble(i), 1, 1e-1); INDArray arr6s = arr6.sum(2, 3); for (int i = 0; i < arr6s.length(); i++) assertEquals(arr6s.getDouble(i), 16, 1e-1); } @Test public void testSum6d() { INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6s = arr6.sum(2, 3); for (int i = 0; i < arr6s.length(); i++) assertEquals(16, arr6s.getDouble(i), 1e-1); } @Test public void testMean() { int[] shape = new int[] {1, 2, 2, 2, 2, 2}; int len = ArrayUtil.prod(shape); INDArray val = Nd4j.linspace(1, len, len).reshape('c', shape); /** * Failure comes from the lack of a jump * when doing tad offset in c++ * * We need to jump from the 
last element rather than the * first for the next element. * * This happens when the index for a tad is >= the * stride[0] * * When the index is >= a stride[0] then you take * the offset at the end of the tad and use that + * (possibly the last stride?) * to get to the next offset. * * In order to get to the last element for a jump, just iterate * over the tad (coordinate wise) to get the coordinate pair + * offset at which to do compute. * * Another possible solution is to create an initialize pointer * method that will just set up the tad pointer directly. * Right now it is a simplistic base pointer + offset that * we could turn in to an init method instead. * This would allow use to use coordinate based techniques * on the pointer directly. The proposal here * would then be turning tad offset given an index * in to a pointer initialization method which * will auto insert the pointer at the right index. */ INDArray sum = val.sum(2, 3); double[] assertionData = new double[] {28.0, 32.0, 36.0, 40.0, 92.0, 96.0, 100.0, 104.0}; INDArray avgExpected = Nd4j.create(assertionData).reshape(1, 2, 2, 2); assertEquals(avgExpected, sum); } @Test public void testSum5d() throws Exception { System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); INDArray arr5s = arr5.sum(2, 3); Thread.sleep(1000); System.out.println("5d length: " + arr5s.length()); for (int i = 0; i < arr5s.length(); i++) assertEquals(16, arr5s.getDouble(i), 1e-1); INDArray arrF = Nd4j.ones(1, 1, 4, 4, 4); System.out.println("A: " + arrF); } @Test public void testOneMinus() { INDArray in = Nd4j.linspace(1, 3, 3); INDArray out = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("timesoneminus", in)); //Expect: 0, -2, -6 -> from 1*(1-1), 2*(1-2), 3*(1-3). 
Getting: [0,0,0] INDArray exp = Nd4j.create(new double[] {0, -2.0, -6.0}); assertEquals(out, exp); } @Test public void testReductionIndex() { Map<Integer, Integer> assertionMap = new HashMap<>(); assertionMap.put(0, 0); assertionMap.put(1, 0); assertionMap.put(2, 0); assertionMap.put(3, 1); assertionMap.put(4, 1); assertionMap.put(5, 1); assertionMap.put(6, 2); assertionMap.put(7, 2); assertionMap.put(8, 2); assertionMap.put(9, 3); assertionMap.put(10, 3); assertionMap.put(11, 3); assertionMap.put(12, 3); assertEquals(3, TadCollapseAccumulation.tadsPerReduceIndex(4, 12)); for (int i = 0; i < 12; i++) { int val = assertionMap.get(i); assertEquals(val, TadCollapseAccumulation.reductionIndexForTad(i, 4, 12)); } } @Test public void testSubColumnVector() { INDArray vec = Nd4j.linspace(1, 18, 18); INDArray matrix = vec.dup().reshape(3, 6); INDArray vector = Nd4j.create(new double[] {6, 12, 18}).reshape(3, 1); INDArray assertion = Nd4j.create(new double[] {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0}, new int[] {3, 6}); INDArray test = matrix.subColumnVector(vector); assertEquals(assertion, test); } @Test public void testLogSoftmaxVector() { INDArray temp = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}); INDArray logsoftmax = Nd4j.getExecutioner().execAndReturn(new LogSoftMax(temp.dup())); INDArray assertion = Nd4j.create(new double[] {-3.4401898, -2.4401898, -1.4401897, -0.44018975}); assertEquals(assertion, logsoftmax); } @Test public void testSumDifferentOrder() { INDArray toAssign = Nd4j.linspace(0, 3, 4).reshape(2, 2); INDArray cOrder = Nd4j.create(new int[] {2, 2}, 'c').assign(toAssign); INDArray fOrder = Nd4j.create(new int[] {2, 2}, 'f').assign(toAssign); System.out.println(cOrder); System.out.println(cOrder.sum(0)); //[2,4] -> correct System.out.println(fOrder.sum(0)); //[2,3] -> incorrect assertEquals(cOrder, fOrder); assertEquals(cOrder.sum(0), fOrder.sum(0)); } @Test public void testLogSoftmax() { 
INDArray test = Nd4j.create(new double[] {-0.115370326, -0.12137828, -0.120233774, -0.12121266, -0.11363905, -0.101017155, -0.11571029, -0.116997495, -0.123033985, -0.1222254, -0.11120513, -0.11710341, -0.12319958, -0.124424405, -0.105285235, -0.08768927, -0.10296882, -0.11346505, -0.10607526, -0.10681274, -0.11604863, -0.1070115, -0.114202365, -0.11168295, -0.11615404, -0.120522454, -0.11282451, -0.11514864, -0.11681116, -0.11987897, -0.12054029, -0.112625614, -0.10337835, -0.098809384, -0.1222254, -0.11966098, -0.11500366, -0.1222254, -0.122691356, -0.1168594, -0.11369472, -0.11666928, -0.12075868, -0.10658686, -0.10251844, -0.119958505, -0.10873747, -0.12036781, -0.11125211, -0.118474, 0.07354958, 0.06268418, 0.08751996, 0.05259535, 0.07969022, 0.062334962, 0.07089297, -0.006484107, 0.0702586, 0.03601057, 0.03228142, 0.051330067, 0.048092633, 0.0753836, 0.0026741663, 0.060346458, 0.064265735, 0.03208362, 0.07322607, 0.034286126, 0.08459597, 0.040570714, 0.08494339, 0.06835921, 0.055334114, 0.06346921, 0.08284429, 0.09769646, 0.07128828, 0.0012985547, 0.033257447, 0.024084045, 0.03130147, 0.09381818, 0.062283173, 0.049273495, 0.0789609, 0.06648661, 0.030163772, 0.047266945, 0.05704684, 0.06862679, 0.04134995, 0.0029913357, 0.050757334, 0.031863946, 0.043180045, 0.053592253, -0.02633951, 0.04229047, 0.12401424, 0.1025523, 0.11914653, 0.10838079, 0.119204566, 0.120582364, 0.079642124, 0.1136303, 0.103594445, 0.12434465, 0.10481718, 0.10615024, 0.1161067, 0.101516, 0.11543929, 0.11498181, 0.1083647, 0.12498043, 0.117732316, 0.080594465, 0.12140614, 0.10168964, 0.11630502, 0.097365364, 0.11659742, 0.11525785, 0.095346555, 0.095523514, 0.1145297, 0.10820676, 0.113681756, 0.12088448, 0.11661095, 0.09196416, 0.09367608, 0.12396194, 0.11715822, 0.10781161, 0.09206241, 0.11529953, 0.12193694, 0.11471913, 0.1025523, 0.12246918, 0.12278436, 0.11647938, 0.09907566, 0.10939402, 0.11121245, 0.09931412, -0.2015398, -0.19392101, -0.19934568, -0.19083071, -0.20022182, 
-0.18812077, -0.19819336, -0.19751601, -0.18787658, -0.1910854, -0.19982933, -0.19259657, -0.1910668, -0.19623408, -0.20643783, -0.17979786, -0.20085241, -0.20226628, -0.1943775, -0.19513902, -0.1944603, -0.19675966, -0.20814213, -0.19372807, -0.18230462, -0.18796724, -0.19594413, -0.19937015, -0.20221426, -0.1900377, -0.18905015, -0.20246184, -0.18973471, -0.1917036, -0.1910854, -0.2045007, -0.20772256, -0.1910854, -0.19349803, -0.19836159, -0.20438254, -0.16650572, -0.19694945, -0.19511227, -0.18056169, -0.19521528, -0.19218414, -0.19556037, -0.1989097, -0.19989866, 0.110895164, 0.09209204, 0.13636513, 0.09708423, 0.12663901, 0.11280878, 0.10437618, 0.008251642, 0.11656475, 0.062448665, 0.07663319, 0.076713376, 0.09773914, 0.1284772, 0.0019391886, 0.08873351, 0.10645666, 0.06874694, 0.12830636, 0.069761865, 0.12597786, 0.064558044, 0.14945637, 0.12600589, 0.08889626, 0.096229844, 0.13689923, 0.15111938, 0.11476847, 0.012906413, 0.06886689, 0.05653629, 0.056540295, 0.1647724, 0.1054803, 0.06795046, 0.12039944, 0.11954296, 0.052694272, 0.085520394, 0.110611565, 0.11398453, 0.07550961, 0.023511963, 0.090924345, 0.0600122, 0.07526812, 0.088270955, -0.03518031, 0.073293336, 0.17944553, 0.16982275, 0.1886539, 0.18693338, 0.18788463, 0.2058602, 0.13861835, 0.20437749, 0.18895163, 0.16544276, 0.149991, 0.17463979, 0.17583887, 0.16696452, 0.16749835, 0.1592365, 0.17954215, 0.1818188, 0.21207899, 0.15266286, 0.17395115, 0.15906107, 0.21057771, 0.15467106, 0.17414747, 0.19151127, 0.14792846, 0.14762704, 0.1860418, 0.18808068, 0.19654934, 0.17514904, 0.18510495, 0.16045007, 0.18320344, 0.18669076, 0.16069236, 0.17718756, 0.14080223, 0.1681495, 0.17300002, 0.1528326, 0.16982275, 0.1817097, 0.16696694, 0.16177535, 0.1604718, 0.16464049, 0.15210003, 0.16091338, 0.19544502, 0.1334315, 0.16168839, 0.11322618, 0.19517533, 0.18929626, 0.17545204, 0.1665815, 0.09131178, 0.11004268, 0.20550796, 0.13831247, 0.10610545, 0.12289211, 0.27147663, 0.20504008, 0.2518754, 0.20981932, 
0.20138234, 0.19962592, 0.15790789, 0.20949593, 0.23528637, 0.18096939, 0.08758456, 0.10911943, 0.18139273, 0.18525626, 0.19391479, 0.11438076, 0.1093913, 0.22006766, 0.18334126, 0.21811387, 0.11004268, 0.19371085, 0.23279056, 0.11004268, 0.11990581, 0.17242423, 0.21975593, 0.046734467, 0.1444371, 0.20759591, 0.13962208, 0.14867997, 0.17288592, 0.14028637, 0.19978605, 0.1737019, -0.038705423, -0.03880039, -0.060744748, 0.005578369, -0.026154364, -0.09166601, -0.061155446, 0.008943805, -0.04777039, -0.012912485, -0.010861377, -0.01913654, -0.0061141956, -0.09119834, 0.034481876, -0.008210908, -0.09062711, -0.0464008, -0.0038113478, -0.006515413, -0.06737334, 0.022068182, -0.078238964, -0.10467487, -0.012385059, -0.008899481, -0.0507185, -0.0612416, -0.05302817, 0.03657996, 0.0040081483, 0.0017336496, 0.00966107, -0.13457696, -0.106228024, -0.05810899, -0.042826205, -0.004804179, -0.054947495, -0.0023088162, -0.083174944, -0.0812491, 0.0012216767, 0.017188948, -0.0416347, -0.0750825, -0.052436177, -0.028371494, 0.07799446, -0.02655019, -0.04801802, -0.11302035, -0.114139326, -0.17401277, -0.11443192, -0.19375448, -0.08697115, -0.22462566, -0.18594599, 0.029962104, -0.03072077, -0.10795037, -0.0687454, -0.08853653, -0.02800453, -0.0044006817, -0.14119355, -0.057319514, -0.23839943, -0.09940908, -0.03132951, -0.07696326, -0.23962279, -0.05578459, -0.073864885, -0.16175121, -0.046830498, -0.071334355, -0.12525235, -0.1762308, -0.17853433, -0.05481769, -0.10788009, -0.12848935, -0.21946594, -0.07054761, -0.0043790466, -0.1421547, -0.062456187, -0.038439218, -0.01970637, 0.04187341, -0.11302035, -0.06571084, 0.012916437, 0.008474918, -0.058553338, -0.05822342, -0.0072570713, -0.117029555}, new int[] {150, 3}, 'c'); INDArray assertion = Nd4j.create(new double[] {-1.0949919, -1.1009998, -1.0998554, -1.1079034, -1.1003298, -1.0877079, -1.0957471, -1.0970343, -1.1030709, -1.1040032, -1.0929829, -1.0988811, -1.1042137, -1.1054386, -1.0862994, -1.0849832, -1.1002628, -1.110759, 
-1.0950522, -1.0957897, -1.1050256, -1.0946627, -1.1018535, -1.0993341, -1.098271, -1.1026394, -1.0949415, -1.0964833, -1.0981458, -1.1012137, -1.1069958, -1.0990812, -1.0898339, -1.0839114, -1.1073275, -1.104763, -1.0936487, -1.1008704, -1.1013364, -1.0997316, -1.0965669, -1.0995414, -1.1094468, -1.0952749, -1.0912066, -1.1022308, -1.0910097, -1.10264, -1.1618325, -1.1690543, -0.97703075, -1.1036359, -1.0788001, -1.1137247, -1.0899199, -1.1072751, -1.0987172, -1.13885, -1.0621073, -1.0963553, -1.1102668, -1.0912181, -1.0944556, -1.0698514, -1.1425608, -1.0848886, -1.0910273, -1.1232094, -1.0820669, -1.1177288, -1.0674189, -1.1114442, -1.083288, -1.0998721, -1.1128973, -1.1165779, -1.0972028, -1.0823506, -1.063015, -1.1330047, -1.1010458, -1.1247563, -1.1175389, -1.0550222, -1.0999088, -1.1129185, -1.0832311, -1.0802083, -1.1165311, -1.0994279, -1.0973024, -1.0857224, -1.1129993, -1.124351, -1.076585, -1.0954784, -1.0795343, -1.0691221, -1.1490538, -1.1465356, -1.0648118, -1.0862738, -1.0950559, -1.1058216, -1.0949979, -1.0828075, -1.1237478, -1.0897596, -1.1059818, -1.0852317, -1.1047591, -1.100405, -1.0904485, -1.1050392, -1.0961069, -1.0965644, -1.1031815, -1.0815891, -1.0888373, -1.125975, -1.0903746, -1.1100911, -1.0954757, -1.1110255, -1.0917934, -1.093133, -1.1051062, -1.1049292, -1.0859231, -1.1046766, -1.0992017, -1.0919989, -1.082815, -1.1074618, -1.10575, -1.0909829, -1.0977867, -1.1071333, -1.116398, -1.0931609, -1.0865234, -1.0971736, -1.1093404, -1.0894235, -1.0886579, -1.0949628, -1.1123666, -1.095872, -1.0940536, -1.1059519, -1.1018884, -1.0942696, -1.0996943, -1.0963987, -1.1057898, -1.0936887, -1.102288, -1.1016107, -1.0919713, -1.0952013, -1.1039451, -1.0967125, -1.0917866, -1.0969539, -1.1071577, -1.0841576, -1.1052121, -1.106626, -1.098331, -1.0990925, -1.0984138, -1.095848, -1.1072304, -1.0928164, -1.0921938, -1.0978565, -1.1058333, -1.1007886, -1.1036327, -1.0914562, -1.0939325, -1.1073442, -1.0946171, -1.0945718, -1.0939536, -1.107369, 
-1.1089264, -1.0922892, -1.0947019, -1.1073625, -1.1133835, -1.0755067, -1.1047142, -1.102877, -1.0883265, -1.0995088, -1.0964776, -1.0998539, -1.2125868, -1.2135757, -0.9027819, -1.115231, -1.0709579, -1.1102388, -1.0866234, -1.1004536, -1.1088862, -1.1537597, -1.0454466, -1.0995628, -1.1057239, -1.1056436, -1.0846179, -1.0445701, -1.1711081, -1.0843138, -1.0936275, -1.1313372, -1.0717777, -1.1160054, -1.0597894, -1.1212093, -1.0709189, -1.0943694, -1.131479, -1.1307347, -1.0900652, -1.0758451, -1.0502236, -1.1520857, -1.0961251, -1.1360092, -1.1360053, -1.0277731, -1.091318, -1.1288478, -1.0763988, -1.065361, -1.1322097, -1.0993836, -1.0881867, -1.0848137, -1.1232886, -1.133629, -1.0662166, -1.0971287, -1.0676445, -1.0546416, -1.1780928, -1.1673087, -1.0611565, -1.0707793, -1.0977826, -1.0995032, -1.0985519, -1.0761919, -1.1434338, -1.0776746, -1.0779177, -1.1014266, -1.1168783, -1.0964613, -1.0952622, -1.1041365, -1.0999078, -1.1081696, -1.0878639, -1.0992746, -1.0690144, -1.1284306, -1.1060928, -1.1209829, -1.0694662, -1.1174977, -1.0980213, -1.0806575, -1.1113796, -1.111681, -1.0732663, -1.0971633, -1.0886947, -1.110095, -1.0898226, -1.1144775, -1.0917242, -1.0868361, -1.1128345, -1.0963393, -1.1185608, -1.0912135, -1.086363, -1.1139716, -1.0969814, -1.0850945, -1.0947206, -1.0999122, -1.1012157, -1.0932035, -1.105744, -1.0969306, -1.0670104, -1.1290239, -1.100767, -1.1519758, -1.0700266, -1.0759057, -1.0683149, -1.0771854, -1.1524552, -1.1406635, -1.0451982, -1.1123937, -1.1621376, -1.1453509, -0.99676645, -1.1160396, -1.0692043, -1.1112604, -1.0837362, -1.0854926, -1.1272106, -1.0979462, -1.0721557, -1.1264727, -1.1378707, -1.1163357, -1.0440625, -1.0785028, -1.0698442, -1.1493783, -1.1612072, -1.0505308, -1.0872571, -1.0555155, -1.1635867, -1.0799185, -1.0216377, -1.1443856, -1.1345224, -1.0751246, -1.0277929, -1.2008144, -1.1185431, -1.0553844, -1.1233582, -1.1039788, -1.0797728, -1.1123724, -1.0159799, -1.0420641, -1.2544713, -1.1064723, -1.1284167, 
-1.0620935, -1.0654664, -1.1309781, -1.1004674, -1.0726943, -1.1294085, -1.0945506, -1.0974507, -1.1057259, -1.0927036, -1.1695204, -1.0438402, -1.086533, -1.1429209, -1.0986946, -1.0561051, -1.0885462, -1.149404, -1.0599625, -1.112509, -1.1389449, -1.046655, -1.0674819, -1.1093009, -1.119824, -1.1481767, -1.0585686, -1.0911404, -1.0579745, -1.050047, -1.194285, -1.136149, -1.08803, -1.0727472, -1.0830219, -1.1331651, -1.0805265, -1.1281672, -1.1262413, -1.0437706, -1.0489775, -1.1078012, -1.141249, -1.1517346, -1.1276698, -1.0213039, -1.0633042, -1.084772, -1.1497743, -1.0789506, -1.1388241, -1.0792432, -1.125674, -1.0188907, -1.1565453, -1.2263924, -1.0104843, -1.0711672, -1.1182799, -1.079075, -1.0988661, -1.0705098, -1.046906, -1.1836989, -1.0271709, -1.2082508, -1.0692605, -1.017894, -1.0635278, -1.2261873, -1.0583237, -1.0764041, -1.1642903, -1.0648377, -1.0893415, -1.1432595, -1.140007, -1.1423105, -1.0185939, -1.0557104, -1.0763197, -1.1672963, -1.09838, -1.0322114, -1.1699871, -1.1210208, -1.0970039, -1.078271, -1.0132385, -1.1681323, -1.1208228, -1.0738388, -1.0782803, -1.1453086, -1.0970035, -1.0460371, -1.1558095}, new int[] {150, 3}, 'c'); Nd4j.getExecutioner().exec(new LogSoftMax(test)); assertEquals(assertion, test); }

    /**
     * Applies SoftMax in place to a 3x6 matrix whose rows are 1..6, 7..12 and 13..18.
     * The expected values are identical for every row: softmax is invariant to adding a
     * constant, and each row is the previous row shifted by 6.
     */
    @Test
    public void testSoftmax() {
        INDArray vec = Nd4j.linspace(1, 18, 18);
        INDArray matrix = vec.dup().reshape(3, 6);
        Nd4j.getExecutioner().exec(new SoftMax(matrix));
        INDArray assertion = Nd4j.create(
                        new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,
                                        0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,
                                        0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913},
                        new int[] {3, 6}, 'c');
        assertEquals(assertion, matrix);
    }

    /**
     * Checks that the full-array standard deviation of a 1x3 row vector matches std along
     * dimension 1, and that both match a hard-coded reference value.
     */
    @Test
    public void testStdev() {
        INDArray arr = Nd4j.create(new float[] {0.9296161f, 0.31637555f, 0.1839188f}, new int[] {1, 3}, ordering());
        double stdev = arr.stdNumber().doubleValue();
        double stdev2 = arr.std(1).getDouble(0);
        assertEquals(stdev, stdev2, 1e-3);

        // Hard-coded reference value for stdNumber() on this input.
        double exp = 0.37003588676452637;
        assertEquals(exp, stdev, 1e-7f);
    }

    /**
     * Checks that the full-array variance of a 1x3 row vector matches variance along
     * dimension 1, and that both match a hard-coded reference value.
     */
    @Test
    public void testVariance() {
        INDArray arr = Nd4j.create(new float[] {0.9296161f, 0.31637555f, 0.1839188f}, new int[] {1, 3}, ordering());
        double var = arr.varNumber().doubleValue();
        double var2 = arr.var(1).getDouble(0);
        // NOTE(review): 1e-1 is very loose for a value of ~0.137 (testStdev uses 1e-3);
        // consider tightening once varNumber()/var(dim) are confirmed to use the same bias correction.
        assertEquals(var, var2, 1e-1);

        // Hard-coded reference value for varNumber() on this input.
        double exp = 0.1369265615940094;
        assertEquals(exp, var, 1e-7f);
    }

    /**
     * Exercises eps(x), which (per the sums asserted below) yields 1 for elements approximately
     * equal to x and 0 otherwise.
     */
    @Test
    public void testEpsOps() {
        INDArray ones = Nd4j.ones(6);
        // Differs from 1.0 only by ~1e-15, so every element should compare as equal.
        double tiny = 1.000000000000001;
        assertEquals(6.0, ones.eps(tiny).sumNumber().doubleValue(), 0.0);

        INDArray consec = Nd4j.linspace(1, 6, 6);
        // Exactly one element (the value 5) matches in 1..6 ...
        assertEquals(1.0, consec.eps(5).sumNumber().doubleValue(), 0.0);
        // ... and in 0..5, where the match is the last element.
        assertEquals(1.0, consec.sub(1).eps(5).sumNumber().doubleValue(), 0.0);
        assertEquals(1.0, consec.sub(1).eps(5).getDouble(0, 5), 0.0);
    }

    /**
     * Variance over dimensions {1,2,3} of a [100,5,10,10] array must equal variance over
     * dimension 1 of the same data viewed as [100,500], and the population variance must
     * also match a manual two-pass-free computation. Runs in DOUBLE precision; the previous
     * global data type is restored even if an assertion fails.
     */
    @Test
    public void testVarianceSingleVsMultipleDimensions() {
        // this test should always run in double
        DataBuffer.Type type = Nd4j.dataType();
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        try {
            Nd4j.getRandom().setSeed(12345);

            //Generate C order random numbers. Strides: [500,100,10,1]
            INDArray fourd = Nd4j.rand('c', new int[] {100, 5, 10, 10}).muli(10);
            INDArray twod = Shape.newShapeNoCopy(fourd, new int[] {100, 5 * 10 * 10}, false);

            //Population variance. These two should be identical
            INDArray var4 = fourd.var(false, 1, 2, 3);
            INDArray var2 = twod.var(false, 1);

            //Manual calculation of population variance, not bias corrected
            //https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Na.C3.AFve_algorithm
            double[] sums = new double[100];
            double[] sumSquares = new double[100];
            NdIndexIterator iter = new NdIndexIterator(fourd.shape());
            while (iter.hasNext()) {
                int[] next = iter.next();
                double d = fourd.getDouble(next);
                // Accumulate per leading-dimension slice.
                sums[next[0]] += d;
                sumSquares[next[0]] += d * d;
            }

            double[] manualVariance = new double[100];
            int valuesPerRow = (fourd.length() / sums.length);
            for (int i = 0; i < sums.length; i++) {
                manualVariance[i] = (sumSquares[i] - (sums[i] * sums[i]) / valuesPerRow) / valuesPerRow;
            }

            // Bias-corrected variants: the 4d and 2d reductions must agree here too.
            INDArray var4bias = fourd.var(true, 1, 2, 3);
            INDArray var2bias = twod.var(true, 1);

            assertArrayEquals(var2.data().asDouble(), var4.data().asDouble(), 1e-5);
            assertArrayEquals(manualVariance, var2.data().asDouble(), 1e-5);
            assertArrayEquals(var2bias.data().asDouble(), var4bias.data().asDouble(), 1e-5);
        } finally {
            // Always restore the global dtype so a failure here cannot leak DOUBLE into other tests.
            DataTypeUtil.setDTypeForContext(type);
        }
    }

    /**
     * 100000 evenly spaced values over [1, 1000] binned into 20 bins should give 5000 per bin;
     * the input must be left untouched and the output buffer must be written.
     */
    @Test
    public void testHistogram1() throws Exception {
        INDArray x = Nd4j.linspace(1, 1000, 100000);
        INDArray z = Nd4j.zeros(20);

        INDArray xDup = x.dup();
        INDArray zDup = z.dup();

        INDArray zExp = Nd4j.create(20).assign(5000);

        Histogram histogram = new Histogram(x, z);

        Nd4j.getExecutioner().exec(histogram);

        assertEquals(xDup, x);
        assertNotEquals(zDup, z);

        log.info("bins: {}", z);

        assertEquals(zExp, z);
    }

    /**
     * Three clusters of values (0, 5, 10) into 10 bins: expects 3 counts in bins 0, 5 and 9,
     * with the histogram op allocating its own output array.
     */
    @Test
    public void testHistogram2() throws Exception {
        INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f});

        INDArray xDup = x.dup();

        INDArray zExp = Nd4j.zeros(10).putScalar(0, 3f).putScalar(5, 3f).putScalar(9, 3f);

        Histogram histogram = new Histogram(x, 10);

        Nd4j.getExecutioner().exec(histogram);

        INDArray z = histogram.z();

        assertEquals(xDup, x);

        log.info("bins: {}", z);

        assertEquals(zExp, z);
    }

    /**
     * Euclidean and Manhattan distance reduced over dimensions {1,2,3} of a minibatch of two
     * identical [2,2,2] examples, in both 'c' and 'f' ordering. Also checks that the Java and
     * TAD-manager tensor-along-dimension offsets agree. Runs in DOUBLE precision; the previous
     * global data type is restored even if an assertion fails.
     */
    @Test
    public void testEuclideanManhattanDistanceAlongDimension_Rank4() {
        DataBuffer.Type initialType = Nd4j.dataType();
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        try {
            Nd4j.getRandom().setSeed(12345);
            INDArray firstOneExample = Nd4j.linspace(1, 8, 8).reshape('c', new int[] {1, 2, 2, 2});
            INDArray secondOneExample = firstOneExample.add(1);

            // Reference distances computed directly from the flat buffers.
            double[] d1 = firstOneExample.data().asDouble();
            double[] d2 = secondOneExample.data().asDouble();
            double sumSquaredDiff = 0.0;
            double expManhattanDistance = 0.0;
            for (int i = 0; i < d1.length; i++) {
                double diff = d1[i] - d2[i];
                sumSquaredDiff += diff * diff;
                expManhattanDistance += Math.abs(diff);
            }
            double expectedEuclidean = Math.sqrt(sumSquaredDiff);
            log.info("Expected, Euclidean: {}", expectedEuclidean);
            log.info("Expected, Manhattan: {}", expManhattanDistance);

            // Stack the same example twice along dimension 0.
            int mb = 2;
            INDArray firstOrig = Nd4j.create(mb, 2, 2, 2);
            INDArray secondOrig = Nd4j.create(mb, 2, 2, 2);
            for (int i = 0; i < mb; i++) {
                firstOrig.put(new INDArrayIndex[] {point(i), all(), all(), all()}, firstOneExample);
                secondOrig.put(new INDArrayIndex[] {point(i), all(), all(), all()}, secondOneExample);
            }

            for (char order : new char[] {'c', 'f'}) {
                INDArray first = firstOrig.dup(order);
                INDArray second = secondOrig.dup(order);

                assertEquals(firstOrig, first);
                assertEquals(secondOrig, second);

                INDArray out = Nd4j.getExecutioner().exec(new EuclideanDistance(first, second), 1, 2, 3);

                for (int i = 0; i < first.tensorssAlongDimension(1, 2, 3); i++) {
                    assertEquals(first.javaTensorAlongDimension(i, 1, 2, 3).shapeInfoDataBuffer(),
                                    first.tensorAlongDimension(i, 1, 2, 3).shapeInfoDataBuffer());
                }

                Pair<DataBuffer, DataBuffer> firstTadInfo =
                                Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(first, 1, 2, 3);
                Pair<DataBuffer, DataBuffer> secondTadInfo =
                                Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(second, 1, 2, 3);

                // The TAD manager's offset buffer must match the Java-side TAD offsets.
                for (int i = 0; i < first.tensorssAlongDimension(1, 2, 3); i++) {
                    assertEquals(first.javaTensorAlongDimension(i, 1, 2, 3).offset(),
                                    firstTadInfo.getSecond().getInt(i));
                    assertEquals(second.javaTensorAlongDimension(i, 1, 2, 3).offset(),
                                    secondTadInfo.getSecond().getInt(i));
                }

                INDArray outManhattan = Nd4j.getExecutioner().exec(new ManhattanDistance(first, second), 1, 2, 3);

                log.info("Order: {}", order);
                log.info("Euclidean: {} {}", Arrays.toString(out.getRow(0).dup().data().asDouble()),
                                Arrays.toString(out.getRow(1).dup().data().asDouble()));
                log.info("Manhattan: {} {}", Arrays.toString(outManhattan.getRow(0).dup().data().asDouble()),
                                Arrays.toString(outManhattan.getRow(1).dup().data().asDouble()));

                // Both minibatch entries are identical, so both distance rows must match.
                assertEquals(out.getRow(0), out.getRow(1));
                assertEquals(outManhattan.getRow(0), outManhattan.getRow(1));

                assertEquals(expManhattanDistance, outManhattan.getRow(0).getDouble(0), 1e-5);
                assertEquals(expectedEuclidean, out.getRow(0).getDouble(0), 1e-5);
            }
        } finally {
            // Always restore the global dtype so a failure here cannot leak DOUBLE into other tests.
            DataTypeUtil.setDTypeForContext(initialType);
        }
    }

    /** Piling ten 10x10 matrices filled with 0..9 yields a rank-3 array whose i-th slice holds i. */
    @Test
    public void testPile1() throws Exception {
        List<INDArray> arrays = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            arrays.add(Nd4j.create(10, 10).assign(i));
        }

        INDArray pile = Nd4j.pile(arrays);

        assertEquals(3, pile.rank());
        for (int i = 0; i < 10; i++) {
            assertEquals((float) i, pile.tensorAlongDimension(i, 1, 2).getDouble(0), 0.01);
        }
    }

    /** Piling ten 10x10x10 arrays filled with 0..9 yields a rank-4 array whose i-th slice holds i. */
    @Test
    public void testPile2() throws Exception {
        List<INDArray> arrays = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            arrays.add(Nd4j.create(10, 10, 10).assign(i));
        }

        INDArray pile = Nd4j.pile(arrays);

        assertEquals(4, pile.rank());
        for (int i = 0; i < 10; i++) {
            assertEquals((float) i, pile.tensorAlongDimension(i, 1, 2, 3).getDouble(0), 0.01);
        }
    }

    /** Per-slice sum and mean of a [32,100,100] array where slice i is filled with 100 + i. */
    @Test
    public void testMean1() throws Exception {
        INDArray array = Nd4j.create(32, 100, 100);
        for (int i = 0; i < 32; i++) {
            array.tensorAlongDimension(i, 1, 2).assign((float) 100 + i);
        }

        for (int i = 0; i < 32; i++) {
            INDArray tensor = array.tensorAlongDimension(i, 1, 2);
            assertEquals((float) (100 + i) * (100 * 100), tensor.sumNumber().floatValue(), 0.001f);
            assertEquals((float) 100 + i, tensor.meanNumber().floatValue(), 0.001f);
        }
    }

    /** mean(1, 2) of a [32,100,100] array where slice i is filled with 100 + i yields 100 + i per slice. */
    @Test
    public void testMean2() throws Exception {
        INDArray array = Nd4j.create(32, 100, 100);
        for (int i = 0; i < 32; i++) {
            array.tensorAlongDimension(i, 1, 2).assign((float) 100 + i);
        }

        INDArray mean = array.mean(1, 2);
        for (int i = 0; i < 32; i++) {
            assertEquals((float) 100 + i, mean.getFloat(i), 0.001f);
        }
    }

    /**
     * Row-wise max over a large [1769472, 9] random matrix. Previously this only ran the op;
     * it now checks the result size and that maxima of random values stay within [0, 1].
     */
    @Test
    public void testNorm2_1() throws Exception {
        INDArray array = Nd4j.rand(1769472, 9);

        INDArray max = array.max(1);

        assertEquals(1769472, max.length());
        assertTrue(max.minNumber().doubleValue() >= 0.0);
        assertTrue(max.maxNumber().doubleValue() <= 1.0);
    }

    /**
     * Full-array norm2 of a large rank-5 random array. Every element of Nd4j.rand lies in
     * [0, 1], so the norm must be positive and at most sqrt(length); this also rejects NaN.
     */
    @Test
    public void testNorm2_2() throws Exception {
        INDArray array = Nd4j.rand(127, 164, 100, 1, 1);

        double norm2 = array.norm2Number().doubleValue();

        assertTrue("norm2 out of range: " + norm2, norm2 > 0.0 && norm2 <= Math.sqrt(array.length()));
    }

    /**
     * The TAD along dimensions {1,2} of a [32,5,10] array should be contiguous (element-wise
     * stride 1) — assumes the default 'c' ordering declared by ordering().
     */
    @Test
    public void testTadEws() throws Exception {
        INDArray array = Nd4j.create(32, 5, 10);
        assertEquals(1, array.tensorAlongDimension(0, 1, 2).elementWiseStride());
    }

    /** Tearing a pile of ten 10x10 matrices filled with 0..9 returns the original slices in order. */
    @Test
    public void testTear1() {
        List<INDArray> arrays = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            arrays.add(Nd4j.create(10, 10).assign(i));
        }

        INDArray pile = Nd4j.pile(arrays);

        INDArray[] tears = Nd4j.tear(pile, 1, 2);

        for (int i = 0; i < 10; i++) {
            assertEquals((float) i, tears[i].meanNumber().floatValue(), 0.01f);
        }
    }

    @Override
    public char ordering() {
        return 'c';
    }
}