/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml.builder.impl; import java.util.ArrayList; import java.util.List; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelNormalizeConf; import ml.shifu.shifu.core.Normalizer; import org.dmg.pmml.DataType; import org.dmg.pmml.DerivedField; import org.dmg.pmml.FieldName; import org.dmg.pmml.LinearNorm; import org.dmg.pmml.NormContinuous; import org.dmg.pmml.OpType; import org.dmg.pmml.OutlierTreatmentMethodType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Created by zhanhu on 5/20/16. */ public class WoeZscoreLocalTransformCreator extends WoeLocalTransformCreator { @SuppressWarnings("unused") private static final Logger LOG = LoggerFactory.getLogger(WoeZscoreLocalTransformCreator.class); private boolean isWeightedNorm; public WoeZscoreLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isWeightedNorm) { super(modelConfig, columnConfigList); this.isWeightedNorm = isWeightedNorm; } public WoeZscoreLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isWeightedNorm) { super(modelConfig, columnConfigList, isConcise); this.isWeightedNorm = isWeightedNorm; } /** * Create @DerivedField for numerical variable * * @param config - ColumnConfig for numerical variable * @param cutoff - cutoff of normalization * @return DerivedField for variable */ @Override protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) { List<DerivedField> derivedFields = new ArrayList<DerivedField>(); DerivedField derivedField = super.createNumericalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE).get(0); derivedFields.add(derivedField); double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm); // added capping logic to linearNorm LinearNorm from = new LinearNorm().withOrig(meanAndStdDev[0] - meanAndStdDev[1] * cutoff).withNorm(-cutoff); LinearNorm to = new LinearNorm().withOrig(meanAndStdDev[0] + meanAndStdDev[1] * cutoff).withNorm(cutoff); NormContinuous normContinuous = new NormContinuous(FieldName.create(derivedField.getName().getValue())) .withLinearNorms(from, to).withMapMissingTo(0.0) .withOutliers(OutlierTreatmentMethodType.AS_EXTREME_VALUES); // derived field name is consisted of FieldName and "_zscl" derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE) .withName(FieldName.create(genPmmlColumnName(config.getColumnName(), normType))) .withExpression(normContinuous)); return derivedFields; } /** * Create @DerivedField for categorical variable * * @param config - ColumnConfig for categorical variable * @param cutoff - cutoff for normalization * @return DerivedField for variable */ protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) { List<DerivedField> derivedFields = new ArrayList<DerivedField>(); DerivedField derivedField = super.createCategoricalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE).get(0); derivedFields.add(derivedField); double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm); // added capping logic to linearNorm LinearNorm from = new LinearNorm().withOrig(meanAndStdDev[0] - meanAndStdDev[1] * cutoff).withNorm(-cutoff); LinearNorm to = new LinearNorm().withOrig(meanAndStdDev[0] + meanAndStdDev[1] * cutoff).withNorm(cutoff); NormContinuous normContinuous = new NormContinuous(FieldName.create(derivedField.getName().getValue())) .withLinearNorms(from, to).withMapMissingTo(0.0) .withOutliers(OutlierTreatmentMethodType.AS_EXTREME_VALUES); // derived field name is consisted of FieldName and "_zscl" derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE) .withName(FieldName.create(genPmmlColumnName(config.getColumnName(), normType))) .withExpression(normContinuous)); return derivedFields; } }