package org.nd4j.linalg.dataset;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* In-depth testing of correctness of standardization and min-max scaling is covered by other tests, since the code for
* doing that is reused in MultiNormalizerHybrid. These tests will just cover the configurability.
*/
@RunWith(Parameterized.class)
public class MultiNormalizerHybridTest extends BaseNd4jTest {
private MultiNormalizerHybrid SUT;
private MultiDataSet data;
private MultiDataSet dataCopy;
@Before
public void setUp() {
SUT = new MultiNormalizerHybrid();
data = new MultiDataSet(
new INDArray[] {Nd4j.create(new float[][] {{1, 2}, {3, 4}}),
Nd4j.create(new float[][] {{3, 4}, {5, 6}}),},
new INDArray[] {Nd4j.create(new float[][] {{10, 11}, {12, 13}}),
Nd4j.create(new float[][] {{14, 15}, {16, 17}}),});
dataCopy = data.copy();
}
public MultiNormalizerHybridTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testNoNormalizationByDefault() {
SUT.fit(data);
SUT.preProcess(data);
assertEquals(dataCopy, data);
SUT.revert(data);
assertEquals(dataCopy, data);
}
@Test
public void testGlobalNormalization() {
SUT.standardizeAllInputs().minMaxScaleAllOutputs(-10, 10).fit(data);
SUT.preProcess(data);
MultiDataSet expected = new MultiDataSet(
new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),
Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),},
new INDArray[] {Nd4j.create(new float[][] {{-10, -10}, {10, 10}}),
Nd4j.create(new float[][] {{-10, -10}, {10, 10}}),});
assertEquals(expected, data);
SUT.revert(data);
assertEquals(dataCopy, data);
}
@Test
public void testSpecificInputOutputNormalization() {
SUT.minMaxScaleAllInputs().standardizeInput(1).standardizeOutput(0).fit(data);
SUT.preProcess(data);
MultiDataSet expected = new MultiDataSet(
new INDArray[] {Nd4j.create(new float[][] {{0, 0}, {1, 1}}),
Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),},
new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),
Nd4j.create(new float[][] {{14, 15}, {16, 17}}),});
assertEquals(expected, data);
SUT.revert(data);
assertEquals(dataCopy, data);
}
@Test
public void testMasking() {
MultiDataSet timeSeries = new MultiDataSet(
new INDArray[] {Nd4j.create(new float[] {1, 2, 3, 4, 5, 0, 7, 0}).reshape(2, 2, 2),},
new INDArray[] {Nd4j.create(new float[] {0, 20, 0, 40, 50, 60, 70, 80}).reshape(2, 2, 2)},
new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})},
new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})});
MultiDataSet timeSeriesCopy = timeSeries.copy();
SUT.minMaxScaleAllInputs(-10, 10).minMaxScaleAllOutputs(-10, 10).fit(timeSeries);
SUT.preProcess(timeSeries);
MultiDataSet expected = new MultiDataSet(
new INDArray[] {Nd4j.create(new float[] {-10, -5, -10, -5, 10, 0, 10, 0}).reshape(2, 2, 2),},
new INDArray[] {Nd4j.create(new float[] {0, -10, 0, -10, 5, 10, 5, 10}).reshape(2, 2, 2),},
new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})},
new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})});
assertEquals(expected, timeSeries);
SUT.revert(timeSeries);
assertEquals(timeSeriesCopy, timeSeries);
}
@Test
public void testDataSetWithoutLabels() {
SUT.standardizeAllInputs().standardizeAllOutputs().fit(data);
data.setLabels(null);
data.setLabelsMaskArray(null);
SUT.preProcess(data);
}
@Test
public void testDataSetWithoutFeatures() {
SUT.standardizeAllInputs().standardizeAllOutputs().fit(data);
data.setFeatures(null);
data.setFeaturesMaskArrays(null);
SUT.preProcess(data);
}
@Override
public char ordering() {
return 'c';
}
}