/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml.builder.impl; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelNormalizeConf; import ml.shifu.shifu.core.Normalizer; import org.dmg.pmml.*; import java.util.ArrayList; import java.util.List; /** * Created by zhanhu on 3/29/16. */ public class WoeLocalTransformCreator extends ZscoreLocalTransformCreator { public WoeLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) { super(modelConfig, columnConfigList); } public WoeLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise) { super(modelConfig, columnConfigList, isConcise); } /** * Create @DerivedField for numerical variable * * @param config * - ColumnConfig for numerical variable * @param cutoff * - cutoff of normalization * @param normType * - the normalization method that is used to generate DerivedField * @return DerivedField for variable */ @Override protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) { List<Double> binWoeList = (normType.equals(ModelNormalizeConf.NormType.WOE) ? config.getBinCountWoe() : config .getBinWeightedWoe()); List<Double> binBoundaryList = config.getBinBoundary(); List<DiscretizeBin> discretizeBinList = new ArrayList<DiscretizeBin>(); for(int i = 0; i < binBoundaryList.size(); i++) { DiscretizeBin discretizeBin = new DiscretizeBin(); Interval interval = new Interval(); if(i == 0) { if ( binBoundaryList.size() == 1 ) { interval.withClosure(Interval.Closure.OPEN_OPEN) .withLeftMargin(Double.NEGATIVE_INFINITY) .withRightMargin(Double.POSITIVE_INFINITY); } else { interval.withClosure(Interval.Closure.OPEN_OPEN).withRightMargin(binBoundaryList.get(i + 1)); } } else if(i == binBoundaryList.size() - 1) { interval.withClosure(Interval.Closure.CLOSED_OPEN).withLeftMargin(binBoundaryList.get(i)); } else { interval.withClosure(Interval.Closure.CLOSED_OPEN).withLeftMargin(binBoundaryList.get(i)) .withRightMargin(binBoundaryList.get(i + 1)); } discretizeBin.withInterval(interval).withBinValue(Double.toString(binWoeList.get(i))); discretizeBinList.add(discretizeBin); } Discretize discretize = new Discretize(); discretize .withDataType(DataType.DOUBLE) .withField(FieldName.create(config.getColumnName())) .withMapMissingTo(Normalizer.normalize(config, null, cutoff, normType).toString()) .withDefaultValue(Normalizer.normalize(config, null, cutoff, normType).toString()) .withDiscretizeBins(discretizeBinList); // derived field name is consisted of FieldName and "_zscl" List<DerivedField> derivedFields = new ArrayList<DerivedField>(); derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName( FieldName.create(genPmmlColumnName(config.getColumnName(), normType))).withExpression(discretize)); return derivedFields; } }