package org.nd4j.linalg.dataset; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertEquals; /** * @author Ede Meijer */ @RunWith(Parameterized.class) public class MinMaxStatsTest extends BaseNd4jTest { public MinMaxStatsTest(Nd4jBackend backend) { super(backend); } @Test public void testEnforcingNonZeroRange() { INDArray lower = Nd4j.create(new double[] {2, 3, 4, 5}); MinMaxStats stats = new MinMaxStats(lower.dup(), Nd4j.create(new double[] {8, 3, 3.9, 5 + Nd4j.EPS_THRESHOLD * 0.5})); INDArray expectedUpper = Nd4j.create( new double[] {8, 3 + Nd4j.EPS_THRESHOLD, 4 + Nd4j.EPS_THRESHOLD, 5 + Nd4j.EPS_THRESHOLD}); assertEquals(lower, stats.getLower()); assertEquals(expectedUpper, stats.getUpper()); } @Override public char ordering() { return 'c'; } }