/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml.builder.impl; import java.util.ArrayList; import java.util.List; import java.util.Set; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelNormalizeConf; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork; import ml.shifu.shifu.core.pmml.builder.creator.AbstractPmmlElementCreator; import ml.shifu.shifu.util.CommonUtils; import org.dmg.pmml.DataType; import org.dmg.pmml.DerivedField; import org.dmg.pmml.FieldColumnPair; import org.dmg.pmml.FieldName; import org.dmg.pmml.InlineTable; import org.dmg.pmml.LinearNorm; import org.dmg.pmml.LocalTransformations; import org.dmg.pmml.MapValues; import org.dmg.pmml.NormContinuous; import org.dmg.pmml.OpType; import org.dmg.pmml.OutlierTreatmentMethodType; import org.dmg.pmml.Row; import org.encog.ml.BasicML; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Document; import org.w3c.dom.Element; /** * Created by zhanhu on 3/29/16. */ public class ZscoreLocalTransformCreator extends AbstractPmmlElementCreator<LocalTransformations> { private static final Logger LOG = LoggerFactory.getLogger(ZscoreLocalTransformCreator.class); protected static final String NAME_SPACE_URI = "http://www.dmg.org/PMML-4_2"; protected static final String ELEMENT_OUT = "out"; protected static final String ELEMENT_ORIGIN = "origin"; public ZscoreLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) { super(modelConfig, columnConfigList); } public ZscoreLocalTransformCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise) { super(modelConfig, columnConfigList, isConcise); } @Override public LocalTransformations build(BasicML basicML) { LocalTransformations localTransformations = new LocalTransformations(); if(basicML instanceof BasicFloatNetwork) { BasicFloatNetwork bfn = (BasicFloatNetwork) basicML; Set<Integer> featureSet = bfn.getFeatureSet(); for(ColumnConfig config: columnConfigList) { if(config.isFinalSelect() && featureSet.contains(config.getColumnName())) { double cutoff = modelConfig.getNormalizeStdDevCutOff(); localTransformations.withDerivedFields(config.isCategorical() ? createCategoricalDerivedField( config, cutoff, modelConfig.getNormalizeType()) : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType())); } } } else { for(ColumnConfig config: columnConfigList) { if(config.isFinalSelect()) { double cutoff = modelConfig.getNormalizeStdDevCutOff(); localTransformations.withDerivedFields(config.isCategorical() ? createCategoricalDerivedField( config, cutoff, modelConfig.getNormalizeType()) : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType())); } } } return localTransformations; } /** * Create DerivedField for categorical variable * * @param config * - ColumnConfig for categorical variable * @param cutoff * - cutoff for normalization * @param normType * - the normalization method that is used to generate DerivedField * @return DerivedField for variable */ protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) { Document document = null; try { document = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument(); } catch (ParserConfigurationException e) { LOG.error("Fail to create document node.", e); throw new RuntimeException("Fail to create document node.", e); } String defaultValue = Normalizer.normalize(config, "doesn't exist at all...by paypal", cutoff, normType) .toString(); String missingValue = Normalizer.normalize(config, null, cutoff, normType).toString(); InlineTable inlineTable = new InlineTable(); for(int i = 0; i < config.getBinCategory().size(); i++) { List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i)); for(String cval: catVals) { String dval = Normalizer.normalize(config, cval, cutoff, normType).toString(); Element out = document.createElementNS(NAME_SPACE_URI, ELEMENT_OUT); out.setTextContent(dval); Element origin = document.createElementNS(NAME_SPACE_URI, ELEMENT_ORIGIN); origin.setTextContent(cval); inlineTable.withRows(new Row().withContent(origin).withContent(out)); } } MapValues mapValues = new MapValues("out").withDataType(DataType.DOUBLE).withDefaultValue(defaultValue) .withFieldColumnPairs(new FieldColumnPair(new FieldName(config.getColumnName()), ELEMENT_ORIGIN)) .withInlineTable(inlineTable).withMapMissingTo(missingValue); List<DerivedField> derivedFields = new ArrayList<DerivedField>(); derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName( FieldName.create(genPmmlColumnName(config.getColumnName(), normType))).withExpression(mapValues)); return derivedFields; } /** * Create @DerivedField for numerical variable * * @param config * - ColumnConfig for numerical variable * @param cutoff * - cutoff of normalization * @param normType * - the normalization method that is used to generate DerivedField * @return DerivedField for variable */ protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) { // added capping logic to linearNorm LinearNorm from = new LinearNorm().withOrig(config.getMean() - config.getStdDev() * cutoff).withNorm(-cutoff); LinearNorm to = new LinearNorm().withOrig(config.getMean() + config.getStdDev() * cutoff).withNorm(cutoff); NormContinuous normContinuous = new NormContinuous(FieldName.create(config.getColumnName())) .withLinearNorms(from, to).withMapMissingTo(0.0) .withOutliers(OutlierTreatmentMethodType.AS_EXTREME_VALUES); // derived field name is consisted of FieldName and "_zscl" List<DerivedField> derivedFields = new ArrayList<DerivedField>(); derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName( FieldName.create(genPmmlColumnName(config.getColumnName(), normType))).withExpression(normContinuous)); return derivedFields; } /** * Convert column name into PMML format(with normalization) * * @param columnName * the column name * @param normType * the norm type * @return - PMML standard column name */ public static String genPmmlColumnName(String columnName, ModelNormalizeConf.NormType normType) { return columnName + "_" + normType.name().toLowerCase(); } }