/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core; import java.util.List; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelNormalizeConf; import ml.shifu.shifu.util.CommonUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Normalizer * <p> * formula: * <p> * <code>norm_result = (value - means) / stdev</code> The stdDevCutOff should be setting, by default it's 4 * <p> * The <code>value</code> should less than mean + stdDevCutOff * stdev * <p> * and larger than mean - stdDevCutOff * stdev */ public class Normalizer { private static Logger log = LoggerFactory.getLogger(Normalizer.class); public static final double STD_DEV_CUTOFF = 4.0d; public enum NormalizeMethod { ZScore, MaxMin; } private ColumnConfig config; private Double stdDevCutOff = 4.0; private NormalizeMethod method; /** * Create @Normalizer, according ColumnConfig * NormalizeMethod method will be NormalizeMethod.ZScore * stdDevCutOff will be STD_DEV_CUTOFF * * @param config * ColumnConfig to create normalizer */ public Normalizer(ColumnConfig config) { this(config, NormalizeMethod.ZScore, STD_DEV_CUTOFF); } /** * Create @Normalizer, according ColumnConfig and NormalizeMethod * stdDevCutOff will be STD_DEV_CUTOFF * * @param config * ColumnConfig to create normalizer * @param method * NormalizMethod to use */ public Normalizer(ColumnConfig config, NormalizeMethod method) { this(config, method, STD_DEV_CUTOFF); } /** * Create @Normalizer, according ColumnConfig and NormalizeMethod * NormalizeMethod method will be NormalizeMethod.ZScore * * @param config * ColumnConfig to create normalizer * @param cutoff * stand_dev_cutoff to use */ public Normalizer(ColumnConfig config, Double cutoff) { this(config, NormalizeMethod.ZScore, STD_DEV_CUTOFF); } /** * Create @Normalizer, according ColumnConfig and NormalizeMethod * NormalizeMethod method will be NormalizeMethod.ZScore * * @param config * ColumnConfig to create normalizer * @param method * NormalizMethod to use * @param cutoff * stand_dev_cutoff to use */ public Normalizer(ColumnConfig config, NormalizeMethod method, Double cutoff) { this.config = config; this.method = method; this.stdDevCutOff = cutoff; } /** * Normalize the input data for column * * @param raw * the raw value * @return normalized value */ public Double normalize(String raw) { return normalize(config, raw, method, stdDevCutOff); } /** * Normalize the raw file, according the ColumnConfig info * * @param config * ColumnConfig to normalize data * @param raw * raw input data * @return normalized value */ public static Double normalize(ColumnConfig config, String raw) { return normalize(config, raw, NormalizeMethod.ZScore); } /** * Normalize the raw file, according the ColumnConfig info and normalized method * * @param config * ColumnConfig to normalize data * @param raw * raw input data * @param method * the method used to do normalization * @return normalized value */ public static Double normalize(ColumnConfig config, String raw, NormalizeMethod method) { return normalize(config, raw, method, STD_DEV_CUTOFF); } /** * Normalize the raw file, according the ColumnConfig info and standard deviation cutoff * * @param config * ColumnConfig to normalize data * @param raw * raw input data * @param stdDevCutoff * the standard deviation cutoff to use * @return normalized value */ public static Double normalize(ColumnConfig config, String raw, double stdDevCutoff) { return normalize(config, raw, NormalizeMethod.ZScore, stdDevCutoff); } /** * Normalize the raw file, according the ColumnConfig info, normalized method and standard deviation cutoff * * @param config * ColumnConfig to normalize data * @param raw * raw input data * @param method * the method used to do normalization * @param stdDevCutoff * the standard deviation cutoff to use * @return normalized value */ public static Double normalize(ColumnConfig config, String raw, NormalizeMethod method, double stdDevCutoff) { if(method == null) { method = NormalizeMethod.ZScore; } switch(method) { case ZScore: return zScoreNormalize(config, raw, stdDevCutoff); case MaxMin: return getMaxMinScore(config, raw); default: return 0.0; } } /** * Compute the normalized data for @NormalizeMethod.MaxMin * * @param config * ColumnConfig info * @param raw * input column value * @return normalized value for MaxMin method */ private static Double getMaxMinScore(ColumnConfig config, String raw) { if(config.isCategorical()) { // TODO, doesn't support } else { Double value = Double.parseDouble(raw); return (value - config.getColumnStats().getMin()) / (config.getColumnStats().getMax() - config.getColumnStats().getMin()); } return null; } /** * Normalize the raw data, according the ColumnConfig infomation and normalization type. * Currently, the cutoff value doesn't affect the computation of WOE or WEIGHT_WOE type. * * <p> * Noticed: currently OLD_ZSCALE and ZSCALE is implemented with the same process method. * </p> * * @param config * ColumnConfig to normalize data * @param raw * raw input data * @param cutoff * standard deviation cut off * @param type * normalization type of ModelNormalizeConf.NormType * @return normalized value. If normType parameter is invalid, then the ZSCALE will be used as default. */ public static Double normalize(ColumnConfig config, String raw, Double cutoff, ModelNormalizeConf.NormType type) { switch(type) { case WOE: return woeNormalize(config, raw, false); case WEIGHT_WOE: return woeNormalize(config, raw, true); case HYBRID: return hybridNormalize(config, raw, cutoff, false); case WEIGHT_HYBRID: return hybridNormalize(config, raw, cutoff, true); case WOE_ZSCORE: case WOE_ZSCALE: return woeZScoreNormalize(config, raw, cutoff, false); case WEIGHT_WOE_ZSCORE: case WEIGHT_WOE_ZSCALE: return woeZScoreNormalize(config, raw, cutoff, true); case OLD_ZSCALE: case OLD_ZSCORE: case ZSCALE: case ZSCORE: default: return zScoreNormalize(config, raw, cutoff); } } /** * Compute the normalized data for @NormalizeMethod.Zscore * * @param config * ColumnConfig info * @param raw * input column value * @param cutoff * standard deviation cut off * @return normalized value for ZScore method. */ private static Double zScoreNormalize(ColumnConfig config, String raw, Double cutoff) { double stdDevCutOff = checkCutOff(cutoff); double value = parseRawValue(config, raw); return computeZScore(value, config.getMean(), config.getStdDev(), stdDevCutOff); } /** * Parse raw value based on ColumnConfig. * * @param config * ColumnConfig info * @param raw * input column value * @return parsed raw value. For categorical type, return BinPosRate. For numerical type, return * corresponding double value. For missing data, return default value using * {@link Normalizer#defaultMissingValue}. */ private static double parseRawValue(ColumnConfig config, String raw) { double value = 0.0; if(config.isCategorical()) { int index = CommonUtils.getBinNum(config, raw); if(index == -1) { value = defaultMissingValue(config); } else { Double binPosRate = config.getBinPosRate().get(index); value = binPosRate == null ? defaultMissingValue(config) : binPosRate.doubleValue(); } } else { try { value = Double.parseDouble(raw); } catch (Exception e) { log.debug("Not decimal format " + raw + ", using default!"); value = defaultMissingValue(config); } } return value; } /** * Get the default value for missing data. * * @param config * ColumnConfig info * @return default value for missing data. Now simply return Mean value. If mean is null then return 0. */ public static double defaultMissingValue(ColumnConfig config) { // TODO return 0 for mean == null is correct or reasonable? return config.getMean() == null ? 0 : config.getMean().doubleValue(); } /** * Compute the normalized data for Woe Score. * * @param config * ColumnConfig info * @param raw * input column value * @param isWeightedNorm * if use weighted woe * @return normalized value for Woe method. For missing value, we return the value in last bin. Since the last * bin refers to the missing value bin. */ private static Double woeNormalize(ColumnConfig config, String raw, boolean isWeightedNorm) { List<Double> woeBins = isWeightedNorm ? config.getBinWeightedWoe() : config.getBinCountWoe(); int binIndex = CommonUtils.getBinNum(config, raw); if(binIndex == -1) { // The last bin in woeBins is the miss value bin. return woeBins.get(woeBins.size() - 1); } else { return woeBins.get(binIndex); } } /** * Compute the normalized value for woe zscore normalize.Take woe as variable value and using zscore normalizing * to compute zscore of woe. * * @param config * ColumnConfig info * @param raw * input column value * @param cutoff * standard deviation cut off * @param isWeightedNorm * if use weighted woe * @return normalized value for woe zscore method. */ private static Double woeZScoreNormalize(ColumnConfig config, String raw, Double cutoff, boolean isWeightedNorm) { double stdDevCutOff = checkCutOff(cutoff); double woe = woeNormalize(config, raw, isWeightedNorm); double[] meanAndStdDev = calculateWoeMeanAndStdDev(config, isWeightedNorm); return computeZScore(woe, meanAndStdDev[0], meanAndStdDev[1], stdDevCutOff); } /** * Compute the normalized data for hbrid normalize. Use zscore noramlize for numerical data. Use woe normalize * for categorical data while use weight woe normalize when isWeightedNorm is true. * * @param config * ColumnConfig info * @param raw * input column value * @param cutoff * standard deviation cut off * @param isWeightedNorm * if use weighted woe * @return normalized value for hybrid method. */ private static Double hybridNormalize(ColumnConfig config, String raw, Double cutoff, boolean isWeightedNorm) { Double normValue; if(config.isNumerical()) { // For numerical data, use zscore. normValue = zScoreNormalize(config, raw, cutoff); } else { // For categorical data, use woe. normValue = woeNormalize(config, raw, isWeightedNorm); } return normValue; } /** * Check specified standard deviation cutoff and return the correct value. * * @param cutoff * specified standard deviation cutoff * @return If cutoff is valid then return it, else return {@link Normalizer#STD_DEV_CUTOFF} */ private static double checkCutOff(Double cutoff) { double stdDevCutOff; if(cutoff != null && !cutoff.isInfinite() && !cutoff.isNaN()) { stdDevCutOff = cutoff; } else { stdDevCutOff = STD_DEV_CUTOFF; } return stdDevCutOff; } /** * Calculate woe mean and woe standard deviation. * * @param config * ColumnConfig info * @param isWeightedNorm * if use weighted woe * @return an double array contains woe mean and woe standard deviation as order {mean, stdDev} */ public static double[] calculateWoeMeanAndStdDev(ColumnConfig config, boolean isWeightedNorm) { List<Double> woeList = isWeightedNorm ? config.getBinWeightedWoe() : config.getBinCountWoe(); if(woeList == null || woeList.size() < 2) { throw new IllegalArgumentException("Woe list is null or too short(size < 2)"); } List<Integer> negCountList = config.getBinCountNeg(); List<Integer> posCountList = config.getBinCountPos(); // calculate woe mean and standard deviation int size = woeList.size(); double sum = 0.0; double squaredSum = 0.0; long totalCount = 0; for(int i = 0; i < size; i++) { int count = negCountList.get(i) + posCountList.get(i); totalCount += count; double x = woeList.get(i); sum += x * count; squaredSum += x * x * count; } double woeMean = sum / totalCount; double woeStdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / totalCount) / (totalCount - 1))); return new double[] { woeMean, woeStdDev }; } /** * Compute the zscore, by original value, mean, standard deviation and standard deviation cutoff * * @param var * original value * @param mean * mean value * @param stdDev * standard deviation * @param stdDevCutOff * standard deviation cutoff * @return zscore */ public static double computeZScore(double var, double mean, double stdDev, double stdDevCutOff) { double maxCutOff = mean + stdDevCutOff * stdDev; if(var > maxCutOff) { var = maxCutOff; } double minCutOff = mean - stdDevCutOff * stdDev; if(var < minCutOff) { var = minCutOff; } if(stdDev > 0.00001) { return (var - mean) / stdDev; } else { return 0.0; } } }