/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.util.CommonUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Reasoner, it helps to find the the majority contributor to the model */ public class Reasoner { private static Logger log = LoggerFactory.getLogger(Reasoner.class); private Integer numTopVariables = 5; private List<ScoreDiffObject> sdList; private List<String> reasons; private Map<String, String> reasonCodeMap; public Reasoner(Map<String, String> reasonCodeMap) { this.reasonCodeMap = reasonCodeMap; } public void calculateReasonCodes(List<ColumnConfig> columnConfigList, Map<String, String> rawDataMap) { if (columnConfigList == null || columnConfigList.size() == 0) { throw new RuntimeException("ColumnConfig is empty."); } sdList = new ArrayList<ScoreDiffObject>(); reasons = new ArrayList<String>(); for (ColumnConfig config : columnConfigList) { if (!config.isFinalSelect()) { continue; } String key = config.getColumnName(); if (!rawDataMap.containsKey(key)) { log.error("Variable Missing in Test Data: " + key); continue; } // add this, for user may not run post-process and he/she don't want to have reason code // make it common. Just skip this column if (config.getBinAvgScore() == null) { // log.info("No bin average score for " + config.getColumnName()); continue; } ScoreDiffObject sd = new ScoreDiffObject(); int binLength = config.getBinLength(); Integer binNum = binLength; if (config.isNumerical()) { if (rawDataMap.get(key).equals("")) { sd.varValue = config.getMean(); } else { sd.varValue = Double.parseDouble(rawDataMap.get(key)); } List<Double> binBoundary = config.getBinBoundary(); while ((--binNum) >= 0) { if (sd.varValue >= binBoundary.get(binNum)) { break; } } if (binNum == -1) { log.info(sd.varValue.toString()); log.info(binBoundary.toString()); break; } sd.binBoundary = config.getBinBoundary(); sd.binNum = binNum; } else if (config.isCategorical()) { List<String> binCategory = config.getBinCategory(); sd.varCategory = rawDataMap.get(key); while ((--binNum) >= 0) { if (CommonUtils.isCategoricalBinValue(binCategory.get(binNum), sd.varCategory)) { // if (sd.varCategory.equals(binCategory.get(binNum))) { break; } } if (binNum == -1) { log.info("Unknown value."); break; } sd.binCategory = config.getBinCategory(); sd.binNum = binNum; } sd.columnName = config.getColumnName(); sd.columnNum = config.getColumnNum(); sd.scoreDiff = config.getBinAvgScore().get(binNum); sd.binAvgScore = config.getBinAvgScore(); sd.binCountNeg = config.getBinCountNeg(); sd.binCountPos = config.getBinCountPos(); sdList.add(sd); } Collections.sort(sdList, new ScoreDiffComparator()); String reason = null; int n = numTopVariables; if (n > sdList.size()) { n = sdList.size(); } for (int i = 0; i < n; i++) { log.debug(sdList.get(i).columnName + "==>" + sdList.get(i).scoreDiff); reason = reasonCodeMap.get(sdList.get(i).columnName); if (!reasons.contains(reason)) { reasons.add(reason); } } } public List<String> getReasonCodes() { return reasons; } public Map<String, Object> getReasonDetails() { Map<String, Object> map = new HashMap<String, Object>(); map.put("details", sdList.subList(0, numTopVariables)); map.put("reasons", reasons); return map; } static class ScoreDiffObject { public String columnName; public Integer columnNum; public Integer binNum; public Integer scoreDiff; public Double varValue; public String varCategory; public List<Double> binBoundary; public List<String> binCategory; public List<Integer> binAvgScore; public List<Integer> binCountPos; public List<Integer> binCountNeg; } public void setNumTopVariables(Integer numTopVariables) { this.numTopVariables = numTopVariables; } static class ScoreDiffComparator implements Comparator<ScoreDiffObject>, Serializable { private static final long serialVersionUID = 652346402551215269L; public int compare(ScoreDiffObject a, ScoreDiffObject b) { if (!a.scoreDiff.equals(b.scoreDiff)) { return b.scoreDiff.compareTo(a.scoreDiff); } else { return a.columnNum.compareTo(b.columnNum); } } } }