package org.nd4j.linalg.dataset;
import lombok.Getter;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.preprocessor.*;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.*;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
/**
 * Round-trip tests for {@link NormalizerSerializer}: every test configures a normalizer, writes it to a
 * temporary file, restores it from that file and asserts that the restored instance equals the original.
 *
 * @author Ede Meijer
 */
@RunWith(Parameterized.class)
public class NormalizerSerializerTest extends BaseNd4jTest {
    /** Scratch file each test serializes into; recreated before every test. */
    private File tmpFile;
    /** Serializer under test, obtained from {@link NormalizerSerializer#getDefault()}. */
    private NormalizerSerializer SUT;

    public NormalizerSerializerTest(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Creates a fresh temporary file (removed on JVM exit) and fetches the default serializer.
     *
     * @throws IOException if the temporary file cannot be created
     */
    @Before
    public void setUp() throws IOException {
        tmpFile = File.createTempFile("test", "preProcessor");
        tmpFile.deleteOnExit();

        // NOTE(review): if getDefault() hands out a shared instance, the strategy registered in
        // testCustomNormalizer() may remain visible to later tests - confirm against NormalizerSerializer.
        SUT = NormalizerSerializer.getDefault();
    }

    /** Round-trips a {@link NormalizerStandardize} that only carries feature statistics. */
    @Test
    public void testNormalizerStandardizeNotFitLabels() throws Exception {
        NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}),
                        Nd4j.create(new double[] {2.5, 3.5}));

        SUT.write(original, tmpFile);
        NormalizerStandardize restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link NormalizerStandardize} with feature and label statistics and label fitting enabled. */
    @Test
    public void testNormalizerStandardizeFitLabels() throws Exception {
        NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}),
                        Nd4j.create(new double[] {2.5, 3.5}), Nd4j.create(new double[] {4.5, 5.5}),
                        Nd4j.create(new double[] {6.5, 7.5}));
        original.fitLabel(true);

        SUT.write(original, tmpFile);
        NormalizerStandardize restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link NormalizerMinMaxScaler} that only carries feature statistics. */
    @Test
    public void testNormalizerMinMaxScalerNotFitLabels() throws Exception {
        NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
        original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5}));

        SUT.write(original, tmpFile);
        NormalizerMinMaxScaler restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link NormalizerMinMaxScaler} with feature and label statistics and label fitting enabled. */
    @Test
    public void testNormalizerMinMaxScalerFitLabels() throws Exception {
        NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
        original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5}));
        original.setLabelStats(Nd4j.create(new double[] {4.5, 5.5}), Nd4j.create(new double[] {6.5, 7.5}));
        original.fitLabel(true);

        SUT.write(original, tmpFile);
        NormalizerMinMaxScaler restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link MultiNormalizerStandardize} holding statistics for two feature arrays only. */
    @Test
    public void testMultiNormalizerStandardizeNotFitLabels() throws Exception {
        MultiNormalizerStandardize original = new MultiNormalizerStandardize();
        original.setFeatureStats(asList(
                        new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}),
                                        Nd4j.create(new double[] {2.5, 3.5})),
                        new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));

        SUT.write(original, tmpFile);
        MultiNormalizerStandardize restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /**
     * Round-trips a {@link MultiNormalizerStandardize} holding statistics for two feature arrays and
     * three label arrays, with label fitting enabled.
     */
    @Test
    public void testMultiNormalizerStandardizeFitLabels() throws Exception {
        MultiNormalizerStandardize original = new MultiNormalizerStandardize();
        original.setFeatureStats(asList(
                        new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}),
                                        Nd4j.create(new double[] {2.5, 3.5})),
                        new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
        original.setLabelStats(asList(
                        new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}),
                                        Nd4j.create(new double[] {2.5, 3.5})),
                        new DistributionStats(Nd4j.create(new double[] {4.5}), Nd4j.create(new double[] {7.5})),
                        new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
        original.fitLabel(true);

        SUT.write(original, tmpFile);
        MultiNormalizerStandardize restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link MultiNormalizerMinMaxScaler} holding statistics for two feature arrays only. */
    @Test
    public void testMultiNormalizerMinMaxScalerNotFitLabels() throws Exception {
        MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
        original.setFeatureStats(asList(
                        new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
                        new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));

        SUT.write(original, tmpFile);
        MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /**
     * Round-trips a {@link MultiNormalizerMinMaxScaler} holding statistics for two feature arrays and
     * three label arrays, with label fitting enabled.
     */
    @Test
    public void testMultiNormalizerMinMaxScalerFitLabels() throws Exception {
        MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
        original.setFeatureStats(asList(
                        new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
                        new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
        original.setLabelStats(asList(
                        new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
                        new MinMaxStats(Nd4j.create(new double[] {4.5}), Nd4j.create(new double[] {7.5})),
                        new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
                                        Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
        original.fitLabel(true);

        SUT.write(original, tmpFile);
        MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Round-trips a {@link MultiNormalizerHybrid} whose input and output stats maps are both empty. */
    @Test
    public void testMultiNormalizerHybridEmpty() throws Exception {
        MultiNormalizerHybrid original = new MultiNormalizerHybrid();
        original.setInputStats(new HashMap<Integer, NormalizerStats>());
        original.setOutputStats(new HashMap<Integer, NormalizerStats>());

        SUT.write(original, tmpFile);
        MultiNormalizerHybrid restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /**
     * Round-trips a {@link MultiNormalizerHybrid} configured only through the "all inputs" / "all outputs"
     * builder calls, with statistics for two inputs and two outputs.
     */
    @Test
    public void testMultiNormalizerHybridGlobalStats() throws Exception {
        MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs();

        // Each stats entry needs its own index: putting both under key 0 would silently drop the first one
        // and leave the test exercising a single-entry map.
        Map<Integer, NormalizerStats> inputStats = new HashMap<>();
        inputStats.put(0, new MinMaxStats(Nd4j.create(new float[] {1, 2}), Nd4j.create(new float[] {3, 4})));
        inputStats.put(1, new MinMaxStats(Nd4j.create(new float[] {5, 6}), Nd4j.create(new float[] {7, 8})));

        Map<Integer, NormalizerStats> outputStats = new HashMap<>();
        outputStats.put(0, new DistributionStats(Nd4j.create(new float[] {9, 10}), Nd4j.create(new float[] {11, 12})));
        outputStats.put(1, new DistributionStats(Nd4j.create(new float[] {13, 14}), Nd4j.create(new float[] {15, 16})));

        original.setInputStats(inputStats);
        original.setOutputStats(outputStats);

        SUT.write(original, tmpFile);
        MultiNormalizerHybrid restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /**
     * Round-trips a {@link MultiNormalizerHybrid} that combines "all inputs" / "all outputs" builder calls
     * with per-index overrides (input 0 and output 1), with matching statistics for indices 0 and 1.
     */
    @Test
    public void testMultiNormalizerHybridGlobalAndSpecificStats() throws Exception {
        MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5)
                        .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1);

        Map<Integer, NormalizerStats> inputStats = new HashMap<>();
        inputStats.put(0, new MinMaxStats(Nd4j.create(new float[] {1, 2}), Nd4j.create(new float[] {3, 4})));
        inputStats.put(1, new DistributionStats(Nd4j.create(new float[] {5, 6}), Nd4j.create(new float[] {7, 8})));

        Map<Integer, NormalizerStats> outputStats = new HashMap<>();
        outputStats.put(0, new MinMaxStats(Nd4j.create(new float[] {9, 10}), Nd4j.create(new float[] {11, 12})));
        outputStats.put(1, new DistributionStats(Nd4j.create(new float[] {13, 14}), Nd4j.create(new float[] {15, 16})));

        original.setInputStats(inputStats);
        original.setOutputStats(outputStats);

        SUT.write(original, tmpFile);
        MultiNormalizerHybrid restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
    }

    /** Expects a {@link RuntimeException} when writing a custom normalizer without adding a strategy for it. */
    @Test(expected = RuntimeException.class)
    public void testCustomNormalizerWithoutRegisteredStrategy() throws Exception {
        SUT.write(new MyNormalizer(123), tmpFile);
    }

    /** Round-trips a custom normalizer after registering {@link MyNormalizerSerializerStrategy}. */
    @Test
    public void testCustomNormalizer() throws Exception {
        MyNormalizer original = new MyNormalizer(42);

        SUT.addStrategy(new MyNormalizerSerializerStrategy());

        SUT.write(original, tmpFile);
        MyNormalizer restored = SUT.restore(tmpFile);

        assertEquals(original, restored);
        // MyNormalizer declares no equals() of its own, so the check above cannot cover 'foo';
        // compare it explicitly since it is the only value the custom strategy persists.
        assertEquals(original.getFoo(), restored.getFoo());
    }

    /**
     * Minimal custom normalizer used to exercise custom serializer strategies. It reports
     * {@link NormalizerType#CUSTOM} and carries a single int payload.
     */
    public static class MyNormalizer extends AbstractDataSetNormalizer<MinMaxStats> {
        /** Payload written and read back by {@link MyNormalizerSerializerStrategy}. */
        @Getter
        private final int foo;

        /**
         * @param foo payload value to carry
         */
        public MyNormalizer(int foo) {
            super(new MinMaxStrategy());
            this.foo = foo;
            // Fixed one-element feature stats, identical for every instance.
            setFeatureStats(new MinMaxStats(Nd4j.zeros(1), Nd4j.ones(1)));
        }

        @Override
        public NormalizerType getType() {
            return NormalizerType.CUSTOM;
        }

        @Override
        protected NormalizerStats.Builder newBuilder() {
            return new MinMaxStats.Builder();
        }
    }

    /** Serializer strategy for {@link MyNormalizer}; persists only the {@code foo} value as a single int. */
    public static class MyNormalizerSerializerStrategy extends CustomSerializerStrategy<MyNormalizer> {
        @Override
        public Class<MyNormalizer> getSupportedClass() {
            return MyNormalizer.class;
        }

        /**
         * Writes the normalizer's {@code foo} value to the stream.
         *
         * @throws IOException if writing to the stream fails
         */
        @Override
        public void write(MyNormalizer normalizer, OutputStream stream) throws IOException {
            new DataOutputStream(stream).writeInt(normalizer.getFoo());
        }

        /**
         * Reads one int from the stream and builds a new {@link MyNormalizer} around it.
         *
         * @throws IOException if reading from the stream fails
         */
        @Override
        public MyNormalizer restore(InputStream stream) throws IOException {
            return new MyNormalizer(new DataInputStream(stream).readInt());
        }
    }

    @Override
    public char ordering() {
        return 'c';
    }
}