package org.nd4j.linalg.dataset;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
import static org.junit.Assert.*;

/**
 * Most of the normalizer functionality is shared with {@link NormalizerStandardize}
 * and is covered in {@link NormalizerStandardizeTest}. This test suite only verifies that
 * {@link MultiNormalizerStandardize} correctly handles multiple inputs and multiple outputs.
 *
 * @author Ede Meijer
 */
@RunWith(Parameterized.class)
public class MultiNormalizerStandardizeTest extends BaseNd4jTest {
private static final double TOLERANCE_PERC = 0.01; // 0.01% of correct value
private static final int INPUT1_SCALE = 1, INPUT2_SCALE = 2, OUTPUT1_SCALE = 3, OUTPUT2_SCALE = 4;
private MultiNormalizerStandardize SUT;
private MultiDataSet data;
private double meanNaturalNums;
private double stdNaturalNums;
@Before
public void setUp() {
SUT = new MultiNormalizerStandardize();
SUT.fitLabel(true);
        // Prepare test data: the natural numbers 1..nSamples, with each input/output scaled by its own factor
int nSamples = 5120;
INDArray values = Nd4j.linspace(1, nSamples, nSamples).transpose();
INDArray input1 = values.mul(INPUT1_SCALE);
INDArray input2 = values.mul(INPUT2_SCALE);
INDArray output1 = values.mul(OUTPUT1_SCALE);
INDArray output2 = values.mul(OUTPUT2_SCALE);
data = new MultiDataSet(new INDArray[] {input1, input2}, new INDArray[] {output1, output2});
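        // Expected statistics of the natural numbers 1..n: mean = (n + 1) / 2,
        // population standard deviation = sqrt((n^2 - 1) / 12)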
meanNaturalNums = (nSamples + 1) / 2.0;
stdNaturalNums = Math.sqrt((nSamples * nSamples - 1) / 12.0);
}
public MultiNormalizerStandardizeTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testMultipleInputsAndOutputsWithDataSet() {
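        // Fitting on the MultiDataSet directly should produce separate statistics per input and per output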
SUT.fit(data);
assertExpectedMeanStd();
}
@Test
public void testMultipleInputsAndOutputsWithIterator() {
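        // Fitting via an iterator (mini-batch size 1) should yield the same statistics as fitting the full MultiDataSet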
MultiDataSetIterator iter = new TestMultiDataSetIterator(1, data);
SUT.fit(iter);
assertExpectedMeanStd();
}
@Test
public void testRevertFeaturesINDArray() {
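        // Reverting a single feature array should give the same result as reverting the whole MultiDataSet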
SUT.fit(data);
MultiDataSet transformed = data.copy();
SUT.preProcess(transformed);
INDArray reverted = transformed.getFeatures(0).dup();
SUT.revertFeatures(reverted, null, 0);
assertNotEquals(reverted, transformed.getFeatures(0));
SUT.revert(transformed);
assertEquals(reverted, transformed.getFeatures(0));
}
@Test
public void testRevertLabelsINDArray() {
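        // Reverting a single label array should give the same result as reverting the whole MultiDataSet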
SUT.fit(data);
MultiDataSet transformed = data.copy();
SUT.preProcess(transformed);
INDArray reverted = transformed.getLabels(0).dup();
SUT.revertLabels(reverted, null, 0);
assertNotEquals(reverted, transformed.getLabels(0));
SUT.revert(transformed);
assertEquals(reverted, transformed.getLabels(0));
}
@Test
public void testRevertMultiDataSet() {
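        // Standardizing should change the data by more than the tolerance; reverting should restore it within the tolerance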
SUT.fit(data);
MultiDataSet transformed = data.copy();
SUT.preProcess(transformed);
double diffBeforeRevert = getMaxRelativeDifference(data, transformed);
assertTrue(diffBeforeRevert > TOLERANCE_PERC);
SUT.revert(transformed);
double diffAfterRevert = getMaxRelativeDifference(data, transformed);
assertTrue(diffAfterRevert < TOLERANCE_PERC);
}
@Test
public void testFullyMaskedData() {
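        // Two single-example time series; the second has an all-zero label mask, so its label
        // must be excluded from the label statistics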
MultiDataSetIterator iter = new TestMultiDataSetIterator(1,
new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {1}).reshape(1, 1, 1)},
new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}),
new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)},
new INDArray[] {Nd4j.create(new float[] {4}).reshape(1, 1, 1)}, null,
new INDArray[] {Nd4j.create(new float[] {0}).reshape(1, 1)}));
SUT.fit(iter);
        // The label mean should be 2: the label of the second example (4) is fully masked and must be excluded
assertEquals(2f, SUT.getLabelMean(0).getFloat(0), 1e-6);
}
private double getMaxRelativeDifference(MultiDataSet a, MultiDataSet b) {
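        // Returns the largest element-wise relative difference (in percent) over all feature arrays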
double max = 0;
for (int i = 0; i < a.getFeatures().length; i++) {
INDArray inputA = a.getFeatures()[i];
INDArray inputB = b.getFeatures()[i];
INDArray delta = Transforms.abs(inputA.sub(inputB)).div(inputB);
double maxdeltaPerc = delta.max(0, 1).mul(100).getDouble(0, 0);
if (maxdeltaPerc > max) {
max = maxdeltaPerc;
}
}
return max;
}
private void assertExpectedMeanStd() {
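        // Each input/output is the sequence 1..nSamples times its scale factor, so its expected mean
        // and standard deviation are the natural-number statistics times that same factor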
assertSmallDifference(meanNaturalNums * INPUT1_SCALE, SUT.getFeatureMean(0).getDouble(0));
assertSmallDifference(stdNaturalNums * INPUT1_SCALE, SUT.getFeatureStd(0).getDouble(0));
assertSmallDifference(meanNaturalNums * INPUT2_SCALE, SUT.getFeatureMean(1).getDouble(0));
assertSmallDifference(stdNaturalNums * INPUT2_SCALE, SUT.getFeatureStd(1).getDouble(0));
assertSmallDifference(meanNaturalNums * OUTPUT1_SCALE, SUT.getLabelMean(0).getDouble(0));
assertSmallDifference(stdNaturalNums * OUTPUT1_SCALE, SUT.getLabelStd(0).getDouble(0));
assertSmallDifference(meanNaturalNums * OUTPUT2_SCALE, SUT.getLabelMean(1).getDouble(0));
assertSmallDifference(stdNaturalNums * OUTPUT2_SCALE, SUT.getLabelStd(1).getDouble(0));
}
private void assertSmallDifference(double expected, double actual) {
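        // Assert that the relative difference between expected and actual is below TOLERANCE_PERC (in percent)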
double delta = Math.abs(expected - actual);
double deltaPerc = (delta / expected) * 100;
assertTrue(deltaPerc < TOLERANCE_PERC);
}
@Override
public char ordering() {
return 'c';
}
}