package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import org.apache.commons.math3.util.Pair; import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; /** * @author raver119@gmail.com */ @Slf4j @RunWith(Parameterized.class) public class AveragingTests extends BaseNd4jTest { private final int THREADS = 16; private final int LENGTH = 5120000 * 4; DataBuffer.Type initialType; public AveragingTests(Nd4jBackend backend) { super(backend); this.initialType = Nd4j.dataType(); } @Before public void setUp() { DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); } @After public void shutUp() { DataTypeUtil.setDTypeForContext(initialType); } @Test public void testSingleDeviceAveraging1() throws Exception { INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0); INDArray array4 = Nd4j.valueArrayOf(LENGTH, 4.0); INDArray array5 = Nd4j.valueArrayOf(LENGTH, 5.0); INDArray array6 = Nd4j.valueArrayOf(LENGTH, 6.0); INDArray array7 = Nd4j.valueArrayOf(LENGTH, 7.0); INDArray array8 = Nd4j.valueArrayOf(LENGTH, 8.0); INDArray array9 = Nd4j.valueArrayOf(LENGTH, 9.0); INDArray array10 = Nd4j.valueArrayOf(LENGTH, 10.0); INDArray array11 = Nd4j.valueArrayOf(LENGTH, 11.0); INDArray array12 = Nd4j.valueArrayOf(LENGTH, 12.0); INDArray array13 = Nd4j.valueArrayOf(LENGTH, 13.0); INDArray array14 = Nd4j.valueArrayOf(LENGTH, 14.0); INDArray array15 = Nd4j.valueArrayOf(LENGTH, 15.0); INDArray array16 = Nd4j.valueArrayOf(LENGTH, 16.0); long time1 = System.currentTimeMillis(); INDArray arrayMean = Nd4j.averageAndPropagate(new INDArray[] {array1, array2, array3, array4, array5, array6, array7, array8, array9, array10, array11, array12, array13, array14, array15, array16}); long time2 = System.currentTimeMillis(); System.out.println("Execution time: " + (time2 - time1)); assertNotEquals(null, arrayMean); assertEquals(8.5f, arrayMean.getFloat(12), 0.1f); assertEquals(8.5f, arrayMean.getFloat(150), 0.1f); assertEquals(8.5f, arrayMean.getFloat(475), 0.1f); assertEquals(8.5f, array1.getFloat(475), 0.1f); assertEquals(8.5f, array2.getFloat(475), 0.1f); assertEquals(8.5f, array3.getFloat(475), 0.1f); assertEquals(8.5f, array5.getFloat(475), 0.1f); assertEquals(8.5f, array16.getFloat(475), 0.1f); assertEquals(8.5, arrayMean.meanNumber().doubleValue(), 0.01); assertEquals(8.5, array1.meanNumber().doubleValue(), 0.01); assertEquals(8.5, array2.meanNumber().doubleValue(), 0.01); assertEquals(arrayMean, array16); } @Test public void testSingleDeviceAveraging2() throws Exception { INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); List<INDArray> arrays = new ArrayList<>(); for (int i = 0; i < THREADS; i++) arrays.add(exp.dup()); INDArray mean = Nd4j.averageAndPropagate(arrays); assertEquals(exp, mean); for (int i = 0; i < THREADS; i++) assertEquals(exp, arrays.get(i)); } @Override public char ordering() { return 'c'; } }