/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import ml.shifu.shifu.container.obj.ColumnBinning;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType;
import ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.Arrays;
public class NormalizerTest {
@Test
public void computeZScore() {
Assert.assertEquals(0.0, Normalizer.computeZScore(2, 2, 1, 6.0));
Assert.assertEquals(6.0, Normalizer.computeZScore(12, 2, 1, 6.0));
Assert.assertEquals(-2.0, Normalizer.computeZScore(2, 4, 1, 2));
// If stdDev == 0, return 0
Assert.assertEquals(0.0, Normalizer.computeZScore(12, 2, 0, 6.0));
}
@Test
public void getZScore1() {
ColumnConfig config = new ColumnConfig();
config.setMean(2.0);
config.setStdDev(1.0);
config.setColumnType(ColumnType.N);
Assert.assertEquals(0.0, Normalizer.normalize(config, "2", 6.0));
Assert.assertEquals(0.0, Normalizer.normalize(config, "ABC", 0.1));
}
@Test
public void getZScore2() {
ColumnConfig config = new ColumnConfig();
config.setMean(2.0);
config.setStdDev(1.0);
config.setColumnType(ColumnType.N);
Assert.assertEquals(-4.0, Normalizer.normalize(config, "-3", null));
}
@Test
public void getZScore3() {
ColumnConfig config = new ColumnConfig();
config.setColumnType(ColumnType.C);
config.setMean(2.0);
config.setStdDev(1.0);
config.setBinCategory(Arrays.asList(new String[] { "1", "2", "3", "4", "ABC" }));
config.setBinPosCaseRate(Arrays.asList(new Double[] { 0.1, 2.0, 0.3, 0.1 }));
Assert.assertEquals(0.0, Normalizer.normalize(config, "2", 0.1));
Assert.assertEquals(0.0, Normalizer.normalize(config, "5", 0.1));
}
@Test
public void getZScore4() {
ColumnConfig config = new ColumnConfig();
Normalizer n = new Normalizer(config, 0.1);
config.setMean(2.0);
config.setStdDev(1.0);
config.setColumnType(ColumnType.N);
Assert.assertEquals(0.0, n.normalize("2"));
}
@Test
public void numericalNormalizeTest() {
// Input setting
ColumnConfig config = new ColumnConfig();
config.setMean(2.0);
config.setStdDev(1.0);
config.setColumnType(ColumnType.N);
ColumnBinning cbin = new ColumnBinning();
cbin.setBinCountWoe(Arrays.asList(new Double[] { 10.0, 11.0, 12.0, 13.0, 6.5 }));
cbin.setBinWeightedWoe(Arrays.asList(new Double[] { 20.0, 21.0, 22.0, 23.0, 16.5 }));
cbin.setBinBoundary(Arrays.asList(new Double[] { Double.NEGATIVE_INFINITY, 2.0, 4.0, 6.0 }));
cbin.setBinCountNeg(Arrays.asList(1, 2, 3, 4, 5));
cbin.setBinCountPos(Arrays.asList(5, 4, 3, 2, 1));
config.setColumnBinning(cbin);
// Test zscore normalization
Assert.assertEquals(Normalizer.normalize(config, "5.0", 4.0, NormType.ZSCALE), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "5.0", null, NormType.ZSCALE), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.ZSCALE), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.ZSCALE), 0.0);
// Test old zscore normalization
Assert.assertEquals(Normalizer.normalize(config, "5.0", 4.0, NormType.OLD_ZSCALE), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "5.0", null, NormType.OLD_ZSCALE), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.OLD_ZSCALE), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.OLD_ZSCALE), 0.0);
// Test woe normalization
Assert.assertEquals(Normalizer.normalize(config, "3.0", null, NormType.WEIGHT_WOE), 21.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.WEIGHT_WOE), 16.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.WEIGHT_WOE), 16.5);
Assert.assertEquals(Normalizer.normalize(config, "3.0", null, NormType.WOE), 11.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.WOE), 6.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.WOE), 6.5);
// Test hybrid normalization, for numerical use zscore.
Assert.assertEquals(Normalizer.normalize(config, "5.0", 4.0, NormType.HYBRID), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "5.0", null, NormType.HYBRID), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.HYBRID), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.HYBRID), 0.0);
// Currently WEIGHT_HYBRID and HYBRID act same for numerical value, both calculate zscore.
Assert.assertEquals(Normalizer.normalize(config, "5.0", 4.0, NormType.WEIGHT_HYBRID), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "5.0", null, NormType.WEIGHT_HYBRID), 3.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.WEIGHT_HYBRID), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.WEIGHT_HYBRID), 0.0);
// Test woe zscore normalization
// Assert.assertEquals(Normalizer.normalize(config, "3.0", 10.0, NormType.WOE_ZSCORE), 0.2);
// Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 12.0, NormType.WOE_ZSCORE), -1.6);
// Assert.assertEquals(Normalizer.normalize(config, null, 12.0, NormType.WOE_ZSCORE), -1.6);
//
// Assert.assertEquals(Normalizer.normalize(config, "3.0", 20.0, NormType.WEIGHT_WOE_ZSCORE), 0.2);
// Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
// Assert.assertEquals(Normalizer.normalize(config, null, 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
}
@Test
public void categoricalNormalizeTest() {
// Input setting
ColumnConfig config = new ColumnConfig();
config.setMean(0.2);
config.setStdDev(1.0);
config.setColumnType(ColumnType.C);
ColumnBinning cbin = new ColumnBinning();
cbin.setBinCountWoe(Arrays.asList(new Double[] { 10.0, 11.0, 12.0, 13.0, 6.5 }));
cbin.setBinWeightedWoe(Arrays.asList(new Double[] { 20.0, 21.0, 22.0, 23.0, 16.5 }));
cbin.setBinCategory(Arrays.asList(new String[] { "a", "b", "c", "d" }));
cbin.setBinPosRate(Arrays.asList(new Double[] { 0.2, 0.4, 0.8, 1.0 }));
cbin.setBinCountNeg(Arrays.asList(1, 2, 3, 4, 5));
cbin.setBinCountPos(Arrays.asList(5, 4, 3, 2, 1));
config.setColumnBinning(cbin);
// Test zscore normalization
Assert.assertEquals(Normalizer.normalize(config, "b", 4.0, NormType.ZSCALE), 0.2);
Assert.assertEquals(Normalizer.normalize(config, "b", null, NormType.ZSCALE), 0.2);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.ZSCALE), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.ZSCALE), 0.0);
// Test old zscore normalization
Assert.assertEquals(Normalizer.normalize(config, "b", 4.0, NormType.OLD_ZSCALE), 0.2);
Assert.assertEquals(Normalizer.normalize(config, "b", null, NormType.OLD_ZSCALE), 0.2);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 4.0, NormType.OLD_ZSCALE), 0.0);
Assert.assertEquals(Normalizer.normalize(config, null, 4.0, NormType.OLD_ZSCALE), 0.0);
// Test woe normalization
Assert.assertEquals(Normalizer.normalize(config, "c", null, NormType.WEIGHT_WOE), 22.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.WEIGHT_WOE), 16.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.WEIGHT_WOE), 16.5);
Assert.assertEquals(Normalizer.normalize(config, "c", null, NormType.WOE), 12.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.WOE), 6.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.WOE), 6.5);
// Test hybrid normalization, for categorical value use [weight]woe.
Assert.assertEquals(Normalizer.normalize(config, "a", null, NormType.HYBRID), 10.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.HYBRID), 6.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.HYBRID), 6.5);
Assert.assertEquals(Normalizer.normalize(config, "a", null, NormType.WEIGHT_HYBRID), 20.0);
Assert.assertEquals(Normalizer.normalize(config, "wrong_format", null, NormType.WEIGHT_HYBRID), 16.5);
Assert.assertEquals(Normalizer.normalize(config, null, null, NormType.WEIGHT_HYBRID), 16.5);
// Test woe zscore normalization
// Assert.assertEquals(Normalizer.normalize(config, "b", 12.0, NormType.WOE_ZSCORE), 0.2);
// Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 13.0, NormType.WOE_ZSCORE), -1.6);
// Assert.assertEquals(Normalizer.normalize(config, null, 13.0, NormType.WOE_ZSCORE), -1.6);
//
// Assert.assertEquals(Normalizer.normalize(config, "b", 22.0, NormType.WEIGHT_WOE_ZSCORE), 0.2);
// Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 23.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
// Assert.assertEquals(Normalizer.normalize(config, null, 23.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
}
}