/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml.builder.impl;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.core.pmml.builder.creator.AbstractPmmlElementCreator;
import org.apache.commons.lang.StringUtils;
import org.dmg.pmml.*;
import org.encog.ml.BasicML;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Created by zhanhu on 3/29/16.
*/
public class ModelStatsCreator extends AbstractPmmlElementCreator<ModelStats> {
private static final double EPS = 1e-10;
public ModelStatsCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) {
super(modelConfig, columnConfigList);
}
public ModelStatsCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise) {
super(modelConfig, columnConfigList, isConcise);
}
@Override
public ModelStats build(BasicML basicML) {
ModelStats modelStats = new ModelStats();
if(basicML instanceof BasicFloatNetwork) {
BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
Set<Integer> featureSet = bfn.getFeatureSet();
for(ColumnConfig columnConfig: columnConfigList) {
if(columnConfig.isFinalSelect() && featureSet.contains(columnConfig.getColumnNum())) {
UnivariateStats univariateStats = new UnivariateStats();
univariateStats.setField(FieldName.create(columnConfig.getColumnName()));
if(columnConfig.isCategorical()) {
DiscrStats discrStats = new DiscrStats();
Array countArray = createCountArray(columnConfig);
discrStats.withArrays(countArray);
if(!isConcise) {
List<Extension> extensions = createExtensions(columnConfig);
discrStats.withExtensions(extensions);
}
univariateStats.setDiscrStats(discrStats);
} else { // numerical column
univariateStats.setNumericInfo(createNumericInfo(columnConfig));
if(!isConcise) {
univariateStats.setContStats(createConStats(columnConfig));
}
}
modelStats.withUnivariateStats(univariateStats);
}
}
} else {
for(ColumnConfig columnConfig: columnConfigList) {
if(columnConfig.isFinalSelect()) {
UnivariateStats univariateStats = new UnivariateStats();
univariateStats.setField(FieldName.create(columnConfig.getColumnName()));
if(columnConfig.isCategorical()) {
DiscrStats discrStats = new DiscrStats();
Array countArray = createCountArray(columnConfig);
discrStats.withArrays(countArray);
if(!isConcise) {
List<Extension> extensions = createExtensions(columnConfig);
discrStats.withExtensions(extensions);
}
univariateStats.setDiscrStats(discrStats);
} else { // numerical column
univariateStats.setNumericInfo(createNumericInfo(columnConfig));
if(!isConcise) {
univariateStats.setContStats(createConStats(columnConfig));
}
}
modelStats.withUnivariateStats(univariateStats);
}
}
}
return modelStats;
}
/**
* Create @Array for numerical variable
*
* @param columnConfig
* - ColumnConfig for numerical variable
* @return Array for numerical variable ( positive count + negative count )
*/
private Array createCountArray(ColumnConfig columnConfig) {
Array countAllArray = new Array();
List<Integer> binCountAll = new ArrayList<Integer>(columnConfig.getBinCountPos().size());
for(int i = 0; i < binCountAll.size(); i++) {
binCountAll.add(columnConfig.getBinCountPos().get(i) + columnConfig.getBinCountNeg().get(i));
}
countAllArray.setType(Array.Type.INT);
countAllArray.setN(binCountAll.size());
countAllArray.setValue(StringUtils.join(binCountAll, ' '));
return countAllArray;
}
/**
* Create common extension list from ColumnConfig
*
* @param columnConfig
* - ColumnConfig to create extension
* @return extension list
*/
private List<Extension> createExtensions(ColumnConfig columnConfig) {
Map<String, String> extensionMap = new HashMap<String, String>();
extensionMap.put("BinCountPos", columnConfig.getBinCountPos().toString());
extensionMap.put("BinCountNeg", columnConfig.getBinCountNeg().toString());
extensionMap.put("BinWeightedCountPos", columnConfig.getBinWeightedPos().toString());
extensionMap.put("BinWeightedCountNeg", columnConfig.getBinWeightedNeg().toString());
extensionMap.put("BinPosRate", columnConfig.getBinPosRate().toString());
return createExtensions(extensionMap);
}
/**
* Create extension list from HashMap
*
* @param extensionMap
* the <String,String> map to create extension list
* @return extension list
*/
private List<Extension> createExtensions(Map<String, String> extensionMap) {
List<Extension> extensions = new ArrayList<Extension>();
for(Map.Entry<String, String> entry: extensionMap.entrySet()) {
String key = entry.getKey();
Extension extension = new Extension();
extension.setName(key);
extension.setValue(entry.getValue());
extensions.add(extension);
}
return extensions;
}
/**
* Create @NumericInfo for numerical variable
*
* @param columnConfig
* - ColumnConfig for numerical variable
* @return NumericInfo for variable
*/
private NumericInfo createNumericInfo(ColumnConfig columnConfig) {
NumericInfo numericInfo = new NumericInfo();
numericInfo.setMaximum(columnConfig.getColumnStats().getMax());
numericInfo.setMinimum(columnConfig.getColumnStats().getMin());
numericInfo.setMean(columnConfig.getMean());
numericInfo.setMedian(columnConfig.getMedian());
numericInfo.setStandardDeviation(columnConfig.getStdDev());
return numericInfo;
}
/**
* Create @ConStats for numerical variable
*
* @param columnConfig
* - ColumnConfig to generate ConStats
* @return ConStats for variable
*/
private ContStats createConStats(ColumnConfig columnConfig) {
ContStats conStats = new ContStats();
List<Interval> intervals = new ArrayList<Interval>();
for(int i = 0; i < columnConfig.getBinBoundary().size(); i++) {
Interval interval = new Interval();
interval.setClosure(Interval.Closure.OPEN_CLOSED);
interval.setLeftMargin(columnConfig.getBinBoundary().get(i));
if(i == columnConfig.getBinBoundary().size() - 1) {
interval.setRightMargin(Double.POSITIVE_INFINITY);
} else {
interval.setRightMargin(columnConfig.getBinBoundary().get(i + 1));
}
intervals.add(interval);
}
conStats.withIntervals(intervals);
Map<String, String> extensionMap = new HashMap<String, String>();
extensionMap.put("BinCountPos", columnConfig.getBinCountPos().toString());
extensionMap.put("BinCountNeg", columnConfig.getBinCountNeg().toString());
extensionMap.put("BinWeightedCountPos", columnConfig.getBinWeightedPos().toString());
extensionMap.put("BinWeightedCountNeg", columnConfig.getBinWeightedNeg().toString());
extensionMap.put("BinPosRate", columnConfig.getBinPosRate().toString());
extensionMap.put("BinWOE", calculateWoe(columnConfig.getBinCountPos(), columnConfig.getBinCountNeg())
.toString());
extensionMap.put("KS", Double.toString(columnConfig.getKs()));
extensionMap.put("IV", Double.toString(columnConfig.getIv()));
conStats.withExtensions(createExtensions(extensionMap));
return conStats;
}
/**
* Generate Woe data from positive and negative counts
*
* @param binCountPos
* - positive count list
* @param binCountNeg
* - negative count list
* @return Woe value list
*/
private List<Double> calculateWoe(List<Integer> binCountPos, List<Integer> binCountNeg) {
List<Double> woe = new ArrayList<Double>();
double sumPos = 0.0;
double sumNeg = 0.0;
for(int i = 0; i < binCountPos.size(); i++) {
sumPos += binCountPos.get(i);
sumNeg += binCountNeg.get(i);
}
for(int i = 0; i < binCountPos.size(); i++) {
woe.add(Math.log((binCountPos.get(i) / sumPos + EPS) / (binCountNeg.get(i) / sumNeg + EPS)));
}
return woe;
}
}