package org.nd4j.linalg.dataset;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link DataNormalization} implementations ({@link NormalizerStandardize},
 * {@link NormalizerMinMaxScaler}):
 * <ul>
 *   <li>fitting on a whole {@link DataSet} vs fitting on a {@link DataSetIterator}
 *       over the same data produces (approximately) the same transform;</li>
 *   <li>fitting/transforming/reverting masked time-series data ignores the masked
 *       time steps and leaves them at 0.0.</li>
 * </ul>
 * Created by susaneraly on 11/13/16.
 */
@RunWith(Parameterized.class)
public class NormalizerTests extends BaseNd4jTest {

    public NormalizerTests(Nd4jBackend backend) {
        super(backend);
    }

    private NormalizerStandardize stdScaler;
    private NormalizerMinMaxScaler minMaxScaler;
    private DataSet data;
    private int batchSize;
    private int batchCount;
    private int lastBatch;
    // Maximum allowed relative deviation between the iterator-fitted and
    // dataset-fitted transforms. This is a difference in PERCENTAGE!
    private final float thresholdPerc = 2.0f;

    /**
     * Builds a random dataset whose row count (batchCount * batchSize + lastBatch)
     * deliberately does not divide evenly into batches, so the final partial batch
     * is exercised too.
     */
    @Before
    public void randomData() {
        Nd4j.getRandom().setSeed(12345);
        batchSize = 13;
        batchCount = 20;
        lastBatch = batchSize / 2;
        INDArray origFeatures = Nd4j.rand(batchCount * batchSize + lastBatch, 10);
        INDArray origLabels = Nd4j.rand(batchCount * batchSize + lastBatch, 3);
        data = new DataSet(origFeatures, origLabels);
        stdScaler = new NormalizerStandardize();
        minMaxScaler = new NormalizerMinMaxScaler();
    }

    /**
     * Checks that fitting each normalizer on an iterator vs on the full dataset
     * yields transforms within {@link #thresholdPerc} percent of each other.
     */
    @Test
    public void testPreProcessors() {
        System.out.println("Running iterator vs non-iterator std scaler..");
        double d1 = testItervsDataset(stdScaler);
        assertTrue(d1 + " < " + thresholdPerc, d1 < thresholdPerc);

        System.out.println("Running iterator vs non-iterator min max scaler..");
        double d2 = testItervsDataset(minMaxScaler);
        assertTrue(d2 + " < " + thresholdPerc, d2 < thresholdPerc);
    }

    /**
     * Fits the given normalizer twice — once on a full dataset copy and once via a
     * {@link TestDataSetIterator} over the same data — transforms both ways, and
     * returns the maximum element-wise relative difference |1 - b/a| between the
     * two transformed feature matrices.
     *
     * @param preProcessor normalizer under test (fitted in place)
     * @return worst-case relative deviation between the two transforms
     */
    public float testItervsDataset(DataNormalization preProcessor) {
        DataSet dataCopy = data.copy();
        DataSetIterator dataIter = new TestDataSetIterator(dataCopy, batchSize);

        // Path A: fit + transform on the whole dataset at once.
        preProcessor.fit(dataCopy);
        preProcessor.transform(dataCopy);
        INDArray transformA = dataCopy.getFeatures();

        // Path B: fit via the iterator, then transform batch-by-batch through the
        // iterator's preprocessor hook, stacking results back into one matrix.
        preProcessor.fit(dataIter);
        dataIter.setPreProcessor(preProcessor);
        DataSet next = dataIter.next();
        INDArray transformB = next.getFeatures();
        while (dataIter.hasNext()) {
            next = dataIter.next();
            INDArray transformb = next.getFeatures();
            transformB = Nd4j.vstack(transformB, transformb);
        }

        // max |1 - B/A| over all elements, i.e. the worst relative deviation
        return Transforms.abs(transformB.div(transformA).rsub(1)).maxNumber().floatValue();
    }

    /**
     * Verifies that normalization of masked time-series data matches normalization
     * fitted on only the unmasked subsets, and that masked time steps remain 0.0
     * after both transform and revert.
     */
    @Test
    public void testMasking() {
        Nd4j.getRandom().setSeed(235);

        DataNormalization[] normalizers =
                        new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()};
        DataNormalization[] normalizersNoMask =
                        new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()};
        DataNormalization[] normalizersByRow =
                        new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()};

        for (int i = 0; i < normalizers.length; i++) {
            //First: check that normalization is the same with/without masking arrays
            DataNormalization norm = normalizers[i];
            DataNormalization normFitSubset = normalizersNoMask[i];
            DataNormalization normByRow = normalizersByRow[i];

            System.out.println(norm.getClass());

            // Time series of shape [2 examples, 3 features, 5 time steps]; the last
            // two steps of example 1 are zeroed and will be masked out below.
            INDArray arr = Nd4j.rand('c', new int[] {2, 3, 5}).muli(100).addi(100);
            arr.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.interval(3, 5)).assign(0);

            INDArray arrCopy = arr.dup();
            // arrPt1: example 0 in full; arrPt2: example 1 restricted to its 3 unmasked steps
            INDArray arrPt1 = arr.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(), NDArrayIndex.all()).dup();
            INDArray arrPt2 = arr.get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, 3)).dup();

            // Mask: all 5 steps valid for example 0, first 3 for example 1
            INDArray mask = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {1, 1, 1, 0, 0}});

            DataSet ds = new DataSet(arr, null, mask, null);
            DataSet dsCopy1 = new DataSet(arr.dup(), null, mask, null);
            DataSet dsCopy2 = new DataSet(arr.dup(), null, mask, null);

            norm.fit(ds);

            //Check that values aren't modified by fit op
            assertEquals(arrCopy, arr);

            // Reference A: fit on the unmasked subsets as separate 1-example time series
            List<DataSet> toFitTimeSeries1Ex = new ArrayList<>();
            toFitTimeSeries1Ex.add(new DataSet(arrPt1, arrPt1));
            toFitTimeSeries1Ex.add(new DataSet(arrPt2, arrPt2));
            normFitSubset.fit(new TestDataSetIterator(toFitTimeSeries1Ex, 1));

            // Reference B: fit on the unmasked time steps flattened into 2d row vectors
            List<DataSet> toFitRows = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                INDArray row = arr.get(NDArrayIndex.point(0), NDArrayIndex.all(),
                                NDArrayIndex.interval(j, j, true)).transpose();
                assertTrue(row.isRowVector());
                toFitRows.add(new DataSet(row, row));
            }
            for (int j = 0; j < 3; j++) {
                INDArray row = arr.get(NDArrayIndex.point(1), NDArrayIndex.all(),
                                NDArrayIndex.interval(j, j, true)).transpose();
                assertTrue(row.isRowVector());
                toFitRows.add(new DataSet(row, row));
            }
            normByRow.fit(new TestDataSetIterator(toFitRows, 1));

            norm.transform(ds);
            normFitSubset.transform(dsCopy1);
            normByRow.transform(dsCopy2);

            // Mask-aware fit must agree exactly with both mask-free references
            assertEquals(ds.getFeatures(), dsCopy1.getFeatures());
            assertEquals(ds.getLabels(), dsCopy1.getLabels());
            assertEquals(ds.getFeaturesMaskArray(), dsCopy1.getFeaturesMaskArray());
            assertEquals(ds.getLabelsMaskArray(), dsCopy1.getLabelsMaskArray());
            assertEquals(ds, dsCopy1);
            assertEquals(ds, dsCopy2);

            //Second: ensure time steps post normalization (and post revert) are 0.0
            INDArray shouldBe0_1 = ds.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));
            INDArray shouldBe0_2 = dsCopy1.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));
            INDArray shouldBe0_3 = dsCopy2.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));

            INDArray zeros = Nd4j.zeros(shouldBe0_1.shape());

            for (int j = 0; j < 2; j++) {
                System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
                                NDArrayIndex.all()));
                System.out.println();
            }

            assertEquals(zeros, shouldBe0_1);
            assertEquals(zeros, shouldBe0_2);
            assertEquals(zeros, shouldBe0_3);

            //Check same thing after reverting:
            norm.revert(ds);
            normFitSubset.revert(dsCopy1);
            normByRow.revert(dsCopy2);
            shouldBe0_1 = ds.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));
            shouldBe0_2 = dsCopy1.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));
            shouldBe0_3 = dsCopy2.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(),
                            NDArrayIndex.interval(3, 5));

            assertEquals(zeros, shouldBe0_1);
            assertEquals(zeros, shouldBe0_2);
            assertEquals(zeros, shouldBe0_3);
        }
    }

    @Override
    public char ordering() {
        return 'c';
    }
}