/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.processor;
import ml.shifu.shifu.column.NSColumnUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.ColumnStatsCalculator;
import ml.shifu.shifu.core.binning.CateDynamicBinning;
import ml.shifu.shifu.core.binning.CategoricalBinInfo;
import ml.shifu.shifu.core.pmml.PMMLTranslator;
import ml.shifu.shifu.core.pmml.PMMLUtils;
import ml.shifu.shifu.core.pmml.builder.PMMLConstructorFactory;
import ml.shifu.shifu.core.validator.ModelInspector.ModelStep;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.HDFSUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.fs.Path;
import org.dmg.pmml.PMML;
import org.encog.ml.BasicML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.*;
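/**
* Processor for the model export step. Depending on the requested export type it converts the
* trained models into PMML files, dumps per-column statistics into a local comma-separated file,
* or generates WOE mapping expressions for categorical columns.
*
* @author zhanhu
*/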
public class ExportModelProcessor extends BasicModelProcessor implements Processor {
/** Export type: convert every trained model into a PMML file. */
public static final String PMML = "pmml";
/** Export type: dump the statistics of every column into a local file. */
public static final String COLUMN_STATS = "columnstats";
/** Export type: generate WOE mapping expressions for categorical columns. */
public static final String WOE_MAPPING = "woemapping";
/** Parameter key whose {@link Boolean} value is handed to the PMML translator factory. */
public static final String IS_CONCISE = "IS_CONCISE";
/** Parameter key whose {@link String} value is a comma-separated list of requested column names. */
public static final String REQUEST_VARS = "REQUEST_VARS";
/** Parameter key whose {@link String} value is the expected bin number for categorical re-binning. */
public static final String EXPECTED_BIN_NUM = "EXPECTED_BIN_NUM";
/** Class logger. */
private static final Logger log = LoggerFactory.getLogger(ExportModelProcessor.class);
/** Requested export type; matched case-insensitively and replaced by {@link #PMML} when blank. */
private String type;
/** Optional export parameters, keyed by the parameter-key constants of this class; may be null or empty. */
private Map<String, Object> params;
/**
 * Creates an export processor.
 *
 * @param type
 *            the export type ({@link #PMML}, {@link #COLUMN_STATS} or {@link #WOE_MAPPING}); a blank
 *            value is treated as {@link #PMML} when the processor runs
 * @param params
 *            optional export parameters keyed by {@link #IS_CONCISE}, {@link #REQUEST_VARS} and
 *            {@link #EXPECTED_BIN_NUM}; may be null or empty
 */
public ExportModelProcessor(String type, Map<String, Object> params) {
    this.params = params;
    this.type = type;
}
/**
 * Runs the export step for the configured export type.
 * <ul>
 * <li>{@link #PMML}: loads the local models and saves each one as
 * {@code pmmls/<modelSetName><index>.pmml}</li>
 * <li>{@link #COLUMN_STATS}: writes the statistics of every column to the local column stats file</li>
 * <li>{@link #WOE_MAPPING}: writes the WOE mapping of the selected categorical columns to
 * {@code woemapping.txt}</li>
 * </ul>
 *
 * @return 0 on success, -1 when the export type is not supported
 * @throws Exception
 *             if set-up, model loading/conversion or file output fails
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);

    // The output directory is created up front, regardless of the export type.
    FileUtils.forceMkdir(new File("pmmls"));

    if(StringUtils.isBlank(type)) {
        type = PMML;
    }

    int status = 0;
    if(PMML.equalsIgnoreCase(type)) {
        log.info("Convert models into {} format", type);

        List<BasicML> models = CommonUtils.loadBasicModels(pathFinder.getModelsPath(SourceType.LOCAL),
                ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise());

        int modelIndex = 0;
        for(BasicML model: models) {
            // One PMML file per model, suffixed with the model's position in the list.
            String pmmlPath = "pmmls" + File.separator + modelConfig.getModelSetName() + modelIndex + ".pmml";
            log.info("\t start to generate " + pmmlPath);
            PMMLUtils.savePMML(translator.build(model), pmmlPath);
            modelIndex++;
        }
    } else if(COLUMN_STATS.equalsIgnoreCase(type)) {
        saveColumnStatus();
    } else if(WOE_MAPPING.equalsIgnoreCase(type)) {
        List<String> requestedVars = getRequestVars();
        // No requested variables means every categorical column is exported.
        boolean exportAll = CollectionUtils.isEmpty(requestedVars);

        // First pass: pick the categorical columns to export.
        List<ColumnConfig> selectedColumns = new ArrayList<ColumnConfig>();
        for(ColumnConfig candidate: this.columnConfigList) {
            if(candidate.isCategorical() && (exportAll || isRequestColumn(requestedVars, candidate))) {
                selectedColumns.add(candidate);
            }
        }

        // Second pass: generate one mapping per selected column and write them all at once.
        if(!selectedColumns.isEmpty()) {
            List<String> mappingTexts = new ArrayList<String>();
            for(ColumnConfig selected: selectedColumns) {
                mappingTexts.add(rebinAndExportWoeMapping(selected));
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(mappingTexts, ",\n"));
        }
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }

    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
/**
 * Prints the bin statistics of a categorical column to standard output and returns its WOE mapping
 * expression. When an expected bin number is configured and the column has more categories than
 * that, the categories are first merged by {@link CateDynamicBinning} and the column metrics are
 * recomputed from the merged counts; otherwise the statistics stored in the column config are used.
 *
 * @param columnConfig
 *            the categorical column to export
 * @return the WOE mapping expression for the column
 * @throws IOException
 *             declared for callers; propagated from the helpers if they raise it
 */
private String rebinAndExportWoeMapping(ColumnConfig columnConfig) throws IOException {
    final int expectBinNum = getExpectBinNum();

    // No re-binning requested or needed: report the statistics already stored in the column config.
    if(expectBinNum <= 0 || columnConfig.getBinCategory().size() <= expectBinNum) {
        List<String> categories = columnConfig.getBinCategory();
        System.out.println(columnConfig.getColumnName() + ":");
        for(int bin = 0; bin < categories.size(); bin++) {
            System.out.println("\t[" + categories.get(bin)
                    + "] | posCount:" + columnConfig.getBinCountPos().get(bin)
                    + " | negCount:" + columnConfig.getBinCountNeg().get(bin)
                    + " | posRate:" + columnConfig.getBinPosRate().get(bin)
                    + " | woe:" + columnConfig.getBinCountWoe().get(bin));
        }
        System.out.println("\t" + columnConfig.getColumnName() + " IV:" + columnConfig.getIv());
        System.out.println("\t" + columnConfig.getColumnName() + " KS:" + columnConfig.getKs());
        System.out.println("\t" + columnConfig.getColumnName() + " WOE:" + columnConfig.getColumnStats().getWoe());
        return generateWoeMapping(columnConfig, expectBinNum);
    }

    // Re-bin: sort the per-category bins and merge them down towards the expected bin number.
    List<CategoricalBinInfo> mergedBins = genCategoricalBinInfos(columnConfig);
    Collections.sort(mergedBins);
    mergedBins = new CateDynamicBinning(expectBinNum).merge(mergedBins);

    // One slot per merged bin plus a trailing slot that takes over the last entry of the column's
    // original count lists (presumably the missing-value bin - confirm against ColumnConfig).
    final int mergedSize = mergedBins.size();
    long[] negCounts = new long[mergedSize + 1];
    long[] posCounts = new long[mergedSize + 1];
    int slot = 0;
    for(CategoricalBinInfo merged: mergedBins) {
        negCounts[slot] = merged.getNegativeCnt();
        posCounts[slot] = merged.getPositiveCnt();
        slot++;
    }
    negCounts[mergedSize] = columnConfig.getBinCountNeg().get(columnConfig.getBinCountNeg().size() - 1);
    posCounts[mergedSize] = columnConfig.getBinCountPos().get(columnConfig.getBinCountPos().size() - 1);

    ColumnStatsCalculator.ColumnMetrics metrics = ColumnStatsCalculator.calculateColumnMetrics(negCounts,
            posCounts);

    System.out.println(columnConfig.getColumnName() + ":");
    for(int bin = 0; bin < mergedSize; bin++) {
        CategoricalBinInfo merged = mergedBins.get(bin);
        System.out.println("\t" + merged.getValues()
                + " | posCount:" + merged.getPositiveCnt()
                + " | negCount:" + merged.getNegativeCnt()
                + " | posRate:" + merged.getPositiveRate()
                + " | woe:" + metrics.getBinningWoe().get(bin));
    }
    System.out.println("\t" + columnConfig.getColumnName() + " IV:" + metrics.getIv());
    System.out.println("\t" + columnConfig.getColumnName() + " KS:" + metrics.getKs());
    System.out.println("\t" + columnConfig.getColumnName() + " WOE:" + metrics.getWoe());

    return generateWoeMapping(columnConfig.getColumnName(), mergedBins, metrics, expectBinNum);
}
/**
 * Builds the WOE mapping expression of a re-binned categorical column. The result has the form
 * {@code ( case when <var> in ('a','b') then <woe> ... else <lastWoe> end ) as <var>_<expectBinNum>}.
 *
 * @param varName
 *            the column name used in the expression and in the alias
 * @param categoricalBinInfos
 *            the merged bins; bin {@code i} is mapped to the {@code i}-th binning WOE
 * @param columnMetrics
 *            the metrics computed for the merged bins; its last binning WOE is used for the
 *            {@code else} branch
 * @param expectBinNum
 *            the expected bin number, appended to the alias
 * @return the mapping expression
 */
private String generateWoeMapping(String varName, List<CategoricalBinInfo> categoricalBinInfos,
        ColumnStatsCalculator.ColumnMetrics columnMetrics, int expectBinNum) {
    List<?> binWoes = columnMetrics.getBinningWoe();
    StringBuilder sql = new StringBuilder();
    sql.append("( case \n");
    for(int i = 0; i < categoricalBinInfos.size(); i++) {
        appendWoeWhenClause(sql, varName, categoricalBinInfos.get(i).getValues(), binWoes.get(i));
    }
    appendWoeCaseTail(sql, varName, binWoes, expectBinNum);
    return sql.toString();
}

/**
 * Builds the WOE mapping expression of a categorical column from the bins and WOE values stored
 * in its column config. The output format is the same as for the re-binned overload.
 *
 * @param columnConfig
 *            the categorical column; category {@code i} is mapped to the {@code i}-th bin WOE and the
 *            last bin WOE is used for the {@code else} branch
 * @param expectBinNum
 *            the expected bin number, appended to the alias
 * @return the mapping expression
 */
private String generateWoeMapping(ColumnConfig columnConfig, int expectBinNum) {
    String varName = columnConfig.getColumnName();
    List<String> categories = columnConfig.getBinCategory();
    List<?> binWoes = columnConfig.getBinCountWoe();
    StringBuilder sql = new StringBuilder();
    sql.append("( case \n");
    for(int i = 0; i < categories.size(); i++) {
        appendWoeWhenClause(sql, varName, Collections.singletonList(categories.get(i)), binWoes.get(i));
    }
    appendWoeCaseTail(sql, varName, binWoes, expectBinNum);
    return sql.toString();
}

/**
 * Appends one {@code when <var> in (...) then <woe>} line. Every entry of {@code categoryGroups} is
 * expanded with {@link CommonUtils#flattenCatValGrp(String)} and each resulting value is wrapped
 * in single quotes.
 */
private void appendWoeWhenClause(StringBuilder sql, String varName, Iterable<String> categoryGroups,
        Object woe) {
    List<String> quotedValues = new ArrayList<String>();
    for(String group: categoryGroups) {
        for(String category: CommonUtils.flattenCatValGrp(group)) {
            // NOTE(review): values are quoted verbatim; a category containing a single quote would
            // break the generated expression. The escaping rule depends on the target SQL dialect.
            quotedValues.add("'" + category + "'");
        }
    }
    sql.append("\twhen ").append(varName).append(" in (").append(StringUtils.join(quotedValues, ','))
            .append(") then ").append(woe).append("\n");
}

/**
 * Appends the {@code else} branch, which uses the last WOE value, and the closing alias.
 */
private void appendWoeCaseTail(StringBuilder sql, String varName, List<?> binWoes, int expectBinNum) {
    sql.append("\telse ").append(binWoes.get(binWoes.size() - 1)).append("\n");
    sql.append(" end ) as ").append(varName).append("_").append(expectBinNum);
}
/**
 * Converts the bins of a categorical column into one {@link CategoricalBinInfo} per category,
 * carrying over the positive/negative counts and the weighted positive/negative values of that bin.
 *
 * @param columnConfig
 *            the categorical column whose bins are converted
 * @return a new list with one bin info per category, in category order
 */
private List<CategoricalBinInfo> genCategoricalBinInfos(ColumnConfig columnConfig) {
    List<String> categories = columnConfig.getBinCategory();
    final int categoryCount = categories.size();
    List<CategoricalBinInfo> result = new ArrayList<CategoricalBinInfo>(categoryCount);
    for(int bin = 0; bin < categoryCount; bin++) {
        // Each bin starts out with exactly one category in its own list.
        List<String> singleCategory = new ArrayList<String>();
        singleCategory.add(categories.get(bin));

        CategoricalBinInfo info = new CategoricalBinInfo();
        info.setValues(singleCategory);
        info.setPositiveCnt(columnConfig.getBinCountPos().get(bin));
        info.setNegativeCnt(columnConfig.getBinCountNeg().get(bin));
        info.setWeightPos(columnConfig.getBinWeightedPos().get(bin));
        info.setWeightNeg(columnConfig.getBinWeightedNeg().get(bin));
        result.add(info);
    }
    return result;
}
/**
 * Checks whether a column is one of the requested variables, comparing names with
 * {@link NSColumnUtils#isColumnEqual}.
 *
 * @param catVariables
 *            the requested variable names; must not be null
 * @param columnConfig
 *            the column to look up
 * @return true if any requested name equals the column name
 */
private boolean isRequestColumn(List<String> catVariables, ColumnConfig columnConfig) {
    for(String requestedName: catVariables) {
        if(NSColumnUtils.isColumnEqual(requestedName, columnConfig.getColumnName())) {
            return true;
        }
    }
    return false;
}
/**
 * Writes the statistics of every column to the local column stats file as comma-separated text:
 * one header line followed by one line per column. An existing file at that path is deleted first.
 *
 * @throws IOException
 *             if the old file cannot be deleted or the new file cannot be opened, written or closed
 */
private void saveColumnStatus() throws IOException {
    Path localColumnStatsPath = new Path(pathFinder.getLocalColumnStatsPath());
    log.info("Saving ColumnStatus to local file system: {}.", localColumnStatsPath);
    if(HDFSUtils.getLocalFS().exists(localColumnStatsPath)) {
        HDFSUtils.getLocalFS().delete(localColumnStatsPath, true);
    }

    BufferedWriter writer = null;
    try {
        writer = ShifuFileUtils.getWriter(localColumnStatsPath.toString(), SourceType.LOCAL);
        writer.write("dataSet,columnFlag,columnName,columnNum,iv,ks,max,mean,median,min,missingCount,"
                + "missingPercentage,stdDev,totalCount,distinctCount,weightedIv,weightedKs,weightedWoe,woe,"
                + "skewness,kurtosis,columnType,finalSelect,psi,unitstats,version\n");

        // A single builder is reused for every row; the field order must match the header above.
        StringBuilder builder = new StringBuilder(500);
        for(ColumnConfig columnConfig: columnConfigList) {
            builder.setLength(0);
            builder.append(modelConfig.getBasic().getName()).append(',');
            builder.append(columnConfig.getColumnFlag()).append(',');
            builder.append(columnConfig.getColumnName()).append(',');
            builder.append(columnConfig.getColumnNum()).append(',');
            builder.append(columnConfig.getIv()).append(',');
            builder.append(columnConfig.getKs()).append(',');
            builder.append(columnConfig.getColumnStats().getMax()).append(',');
            builder.append(columnConfig.getColumnStats().getMean()).append(',');
            builder.append(columnConfig.getColumnStats().getMedian()).append(',');
            builder.append(columnConfig.getColumnStats().getMin()).append(',');
            builder.append(columnConfig.getColumnStats().getMissingCount()).append(',');
            builder.append(columnConfig.getColumnStats().getMissingPercentage()).append(',');
            builder.append(columnConfig.getColumnStats().getStdDev()).append(',');
            builder.append(columnConfig.getColumnStats().getTotalCount()).append(',');
            builder.append(columnConfig.getColumnStats().getDistinctCount()).append(',');
            builder.append(columnConfig.getColumnStats().getWeightedIv()).append(',');
            builder.append(columnConfig.getColumnStats().getWeightedKs()).append(',');
            builder.append(columnConfig.getColumnStats().getWeightedWoe()).append(',');
            builder.append(columnConfig.getColumnStats().getWoe()).append(',');
            builder.append(columnConfig.getColumnStats().getSkewness()).append(',');
            builder.append(columnConfig.getColumnStats().getKurtosis()).append(',');
            builder.append(columnConfig.getColumnType()).append(',');
            builder.append(columnConfig.isFinalSelect()).append(',');
            builder.append(columnConfig.getPSI()).append(',');
            // Unit stats are joined with '|' so they stay inside one comma-separated field.
            builder.append(StringUtils.join(columnConfig.getUnitStats(), '|')).append(',');
            builder.append(modelConfig.getBasic().getVersion()).append("\n");
            writer.write(builder.toString());
        }
    } finally {
        // The writer is still null when getWriter() itself failed; closing it unconditionally would
        // raise a NullPointerException that hides the original failure.
        if(writer != null) {
            writer.close();
        }
    }
}
/**
 * Reads the {@link #IS_CONCISE} flag from the export parameters.
 *
 * @return the flag value when it is present as a {@link Boolean}; false when the parameters are
 *         null or empty, or the value is missing or of another type
 */
private boolean isConcise() {
    if(MapUtils.isEmpty(this.params)) {
        return false;
    }
    Object concise = this.params.get(IS_CONCISE);
    return concise instanceof Boolean && ((Boolean) concise).booleanValue();
}
/**
 * Reads the requested variable names from the {@link #REQUEST_VARS} parameter, which is expected
 * to be a comma-separated {@link String}. Surrounding whitespace of each name is removed and
 * empty names are dropped, so {@code "a, b"} yields {@code [a, b]}.
 *
 * @return the requested names; an empty list (never null) when the parameter is missing, not a
 *         string, or contains no names
 */
private List<String> getRequestVars() {
    List<String> requestVars = new ArrayList<String>();
    if(MapUtils.isNotEmpty(this.params) && this.params.get(REQUEST_VARS) instanceof String) {
        String rawVars = (String) this.params.get(REQUEST_VARS);
        for(String token: rawVars.split(",")) {
            String varName = token.trim();
            if(varName.length() > 0) {
                requestVars.add(varName);
            }
        }
    }
    return requestVars;
}
/**
 * Reads the expected bin number from the {@link #EXPECTED_BIN_NUM} parameter, which is expected to
 * be a {@link String} holding an integer. Surrounding whitespace is ignored.
 *
 * @return the parsed number; 0 when the parameter is missing, not a string, or not a valid integer
 *         (a warning is logged in the last case)
 */
private int getExpectBinNum() {
    if(MapUtils.isNotEmpty(this.params) && this.params.get(EXPECTED_BIN_NUM) instanceof String) {
        String expectBinNum = (String) this.params.get(EXPECTED_BIN_NUM);
        try {
            return Integer.parseInt(expectBinNum.trim());
        } catch (NumberFormatException e) {
            // Only a malformed number is tolerated here; anything else should surface.
            log.warn("Invalid expect bin num {}. Ignore it...", expectBinNum);
        }
    }
    return 0;
}
}