/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import ml.shifu.shifu.column.NSColumn;
import ml.shifu.shifu.container.ScoreObject;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;
import ml.shifu.shifu.executor.ExecutorManager;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.collections.CollectionUtils;
import org.encog.ml.BasicML;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.svm.SVM;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Scorer: runs a list of trained models against one input record and collects their scores.
 *
 * <p>Models are evaluated in parallel on an internal thread pool. Call {@link #close()} once the scorer is no
 * longer needed to release that pool; a JVM shutdown hook is registered only as a safety net.
 */
public class Scorer {

    private static final Logger log = LoggerFactory.getLogger(Scorer.class);

    public static final int DEFAULT_SCORE_SCALE = 1000;

    /**
     * Cutoff used by the 4-argument constructor and when {@code null} is passed as cutoff.
     */
    private static final double DEFAULT_CUTOFF = 4.0d;

    private String alg;

    private List<BasicML> models;

    private List<ColumnConfig> columnConfigList;

    private double cutoff = DEFAULT_CUTOFF;

    private ModelConfig modelConfig;

    private int scale = DEFAULT_SCORE_SCALE;

    /**
     * True when the computed input node count equals the candidate variable count, e.g. when no variable is
     * set to finalSelect=true and all candidate variables are taken as inputs.
     */
    private boolean noVarSelect = false;

    /**
     * For faster lookup of categorical bins: column number -> (category value -> bin index).
     */
    private Map<Integer, Map<String, Integer>> binCategoryMap = new HashMap<Integer, Map<String, Integer>>();

    /**
     * Runs models in parallel. Pool size is min(#processors, #models), or min(#processors, 5) with no models.
     */
    private final ExecutorManager<MLData> executorManager;

    /**
     * Safety-net hook that shuts the pool down at JVM exit if {@link #close()} was never called. It is
     * unregistered by {@link #close()} so discarded scorers are not kept reachable until JVM exit.
     */
    private final Thread shutdownHook;

    /**
     * Creates a scorer with the default cutoff (4.0).
     *
     * @param models
     *            models to evaluate; {@code null} is treated as an empty list
     * @param columnConfigList
     *            column configurations, may be {@code null}
     * @param algorithm
     *            algorithm name
     * @param modelConfig
     *            model configuration, must not be {@code null}
     * @throws IllegalArgumentException
     *             if {@code modelConfig} is {@code null}
     */
    public Scorer(List<BasicML> models, List<ColumnConfig> columnConfigList, String algorithm, ModelConfig modelConfig) {
        this(models, columnConfigList, algorithm, modelConfig, DEFAULT_CUTOFF);
    }

    /**
     * Creates a scorer.
     *
     * @param models
     *            models to evaluate; {@code null} is treated as an empty list
     * @param columnConfigList
     *            column configurations, may be {@code null}
     * @param algorithm
     *            algorithm name
     * @param modelConfig
     *            model configuration, must not be {@code null}
     * @param cutoff
     *            cutoff passed to input assembly; {@code null} means the default (4.0)
     * @throws IllegalArgumentException
     *             if {@code modelConfig} is {@code null}
     */
    public Scorer(List<BasicML> models, List<ColumnConfig> columnConfigList, String algorithm, ModelConfig modelConfig,
            Double cutoff) {
        if(modelConfig == null) {
            throw new IllegalArgumentException("modelConfig should not be null");
        }
        // A missing model list behaves like an empty one instead of failing with an NPE below.
        this.models = (models == null) ? new ArrayList<BasicML>() : models;
        this.columnConfigList = columnConfigList;
        // Avoid an NPE from auto-unboxing a null Double.
        this.cutoff = (cutoff == null) ? DEFAULT_CUTOFF : cutoff;
        this.alg = algorithm;
        this.modelConfig = modelConfig;

        if(this.columnConfigList != null) {
            int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
            int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
            int candidateCount = inputOutputIndex[2];
            this.noVarSelect = (inputNodeCount == candidateCount);
            // Categorical lookup tables are built for every algorithm; they are handed to
            // CommonUtils.assembleNsDataPair whenever inputs are assembled for scoring.
            this.binCategoryMap.putAll(buildBinCategoryMap(this.columnConfigList));
        }

        int poolSize = Math.min(Runtime.getRuntime().availableProcessors(),
                (this.models.isEmpty() ? 5 : this.models.size()));
        this.executorManager = new ExecutorManager<MLData>(poolSize);

        // Capture only the executor (not Scorer.this) so the hook does not pin models and configs in memory.
        final ExecutorManager<MLData> manager = this.executorManager;
        this.shutdownHook = new Thread(new Runnable() {
            @Override
            public void run() {
                manager.forceShutDown();
            }
        });
        // Safety net in case close() is never called.
        Runtime.getRuntime().addShutdownHook(this.shutdownHook);
    }

    /**
     * Builds, for every categorical column, a lookup from category value to bin index.
     *
     * <p>A {@code null} bin category is stored under the empty string. A non-null value is expanded with
     * {@link CommonUtils#flattenCatValGrp(String)} (presumably splitting a merged category group into its
     * members), and each resulting value maps to that bin. The unexpanded value is also kept as a key. When a
     * key appears in several bins, the last bin wins.
     *
     * @param columnConfigs
     *            column configurations, not {@code null}
     * @return column number -> (category value -> bin index), only for categorical columns
     */
    private static Map<Integer, Map<String, Integer>> buildBinCategoryMap(List<ColumnConfig> columnConfigs) {
        Map<Integer, Map<String, Integer>> result = new HashMap<Integer, Map<String, Integer>>();
        for(ColumnConfig columnConfig: columnConfigs) {
            if(!columnConfig.isCategorical()) {
                continue;
            }
            Map<String, Integer> map = new HashMap<String, Integer>();
            List<String> categories = columnConfig.getBinCategory();
            if(categories != null) {
                for(int i = 0; i < categories.size(); i++) {
                    String categoricalVal = categories.get(i);
                    if(categoricalVal == null) {
                        map.put("", i);
                    } else {
                        for(String cval: CommonUtils.flattenCatValGrp(categoricalVal)) {
                            map.put(cval, i);
                        }
                        // Keep the raw (unexpanded) value as a key as well.
                        map.put(categoricalVal, i);
                    }
                }
            }
            result.put(columnConfig.getColumnNum(), map);
        }
        return result;
    }

    /**
     * Releases the thread pool resources. Must be called once the scorer is no longer used.
     */
    public void close() {
        this.executorManager.forceShutDown();
        try {
            // The safety-net hook is no longer needed; unregistering it lets this scorer be garbage collected.
            Runtime.getRuntime().removeShutdownHook(this.shutdownHook);
        } catch (IllegalStateException ignored) {
            // The JVM is already shutting down, so hooks can no longer be removed. The hook only repeats the
            // forceShutDown() performed above.
        }
    }

    /**
     * Scores one record given as a raw data map.
     *
     * @param rawDataMap
     *            raw data map, converted with {@link CommonUtils#convertRawMapToNsDataMap}
     * @return model scores; see {@link #scoreNsData(MLDataPair, Map)} for when {@code null} is returned
     */
    public ScoreObject score(Map<String, String> rawDataMap) {
        return scoreNsData(CommonUtils.convertRawMapToNsDataMap(rawDataMap));
    }

    /**
     * Runs the models against a raw NSColumn data map to get scores.
     *
     * @param rawDataNsMap
     *            raw NSColumn data map
     * @return ScoreObject - model scores; see {@link #scoreNsData(MLDataPair, Map)} for when {@code null} is
     *         returned
     */
    public ScoreObject scoreNsData(Map<NSColumn, String> rawDataNsMap) {
        return scoreNsData(null, rawDataNsMap);
    }

    /**
     * Scores one record using a pre-assembled input pair plus the raw data map.
     *
     * @param pair
     *            pre-assembled input, or {@code null} to assemble it from {@code rawDataMap}
     * @param rawDataMap
     *            raw data map, converted with {@link CommonUtils#convertRawMapToNsDataMap}
     * @return model scores; see {@link #scoreNsData(MLDataPair, Map)} for when {@code null} is returned
     */
    public ScoreObject score(final MLDataPair pair, Map<String, String> rawDataMap) {
        return scoreNsData(pair, CommonUtils.convertRawMapToNsDataMap(rawDataMap));
    }

    /**
     * Runs every model, in parallel, against one record and collects the scores in model order.
     *
     * <p>If {@code inputPair} is {@code null} and the algorithm is not NN, a shared input is assembled from
     * {@code rawNsDataMap}. {@link BasicFloatNetwork} models always assemble their own input from
     * {@code rawNsDataMap} using their feature set, and ignore {@code inputPair}.
     *
     * @param inputPair
     *            pre-assembled input for non-network models, or {@code null}
     * @param rawNsDataMap
     *            raw NSColumn data map
     * @return the scores. Returns {@code null} if the executor returns no results, or if some (but not all)
     *         models were skipped for an input-size mismatch, since results can then no longer be matched to
     *         models by position. If every model was skipped, returns an object with no scores.
     * @throws RuntimeException
     *             for a tree-model input-size mismatch or an unsupported model type
     */
    public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
        if(inputPair == null && !NNConstants.NN_ALG_NAME.equalsIgnoreCase(this.alg)) {
            inputPair = CommonUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, columnConfigList,
                    rawNsDataMap, cutoff, alg);
        }
        final MLDataPair pair = inputPair;

        List<Callable<MLData>> tasks = new ArrayList<Callable<MLData>>();
        for(BasicML model: this.models) {
            Callable<MLData> task = createTask(model, pair, rawNsDataMap);
            if(task != null) {
                tasks.add(task);
            }
        }

        List<Double> scores = new ArrayList<Double>();
        List<Integer> rfTreeSizeList = new ArrayList<Integer>();
        if(CollectionUtils.isNotEmpty(tasks)) {
            List<MLData> modelResults = this.executorManager.submitTasksAndWaitResults(tasks);
            // Results are matched to models by index, which is only valid if every model produced a task.
            if(CollectionUtils.isEmpty(modelResults) || modelResults.size() != this.models.size()) {
                log.error("Get empty model results or model results size doesn't match with models size.");
                return null;
            }
            for(int i = 0; i < this.models.size(); i++) {
                appendScores(this.models.get(i), modelResults.get(i), scores, rfTreeSizeList);
            }
        }

        Integer tag = Constants.DEFAULT_IDEAL_VALUE;
        if(scores.isEmpty()) {
            log.warn("No Scores Calculated...");
        }
        return new ScoreObject(scores, tag, rfTreeSizeList);
    }

    /**
     * Builds the scoring task for one model.
     *
     * @param model
     *            model to run
     * @param pair
     *            shared assembled input for non-network models (may be {@code null} for NN algorithms)
     * @param rawNsDataMap
     *            raw data, used to assemble a network-specific input
     * @return the task, or {@code null} if a network/SVM/LR model's input size does not match (logged)
     * @throws RuntimeException
     *             on a tree-model input-size mismatch or an unsupported model type
     */
    private Callable<MLData> createTask(BasicML model, final MLDataPair pair, Map<NSColumn, String> rawNsDataMap) {
        if(model instanceof BasicFloatNetwork) {
            final BasicFloatNetwork network = (BasicFloatNetwork) model;
            // Each network's input is assembled against that network's own feature set.
            final MLDataPair networkPair = CommonUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig,
                    columnConfigList, rawNsDataMap, cutoff, alg, network.getFeatureSet());
            if(network.getFeatureSet().size() != networkPair.getInput().size()) {
                log.error("Network and input size mismatch: Network Size = {}; Input Size = {}",
                        network.getFeatureSet().size(), networkPair.getInput().size());
                return null;
            }
            return new Callable<MLData>() {
                @Override
                public MLData call() throws Exception {
                    return network.compute(networkPair.getInput());
                }
            };
        } else if(model instanceof SVM) {
            final SVM svm = (SVM) model;
            if(svm.getInputCount() != pair.getInput().size()) {
                log.error("SVM and input size mismatch: SVM Size = {}; Input Size = {}", svm.getInputCount(),
                        pair.getInput().size());
                return null;
            }
            return new Callable<MLData>() {
                @Override
                public MLData call() throws Exception {
                    return svm.compute(pair.getInput());
                }
            };
        } else if(model instanceof LR) {
            final LR lr = (LR) model;
            if(lr.getInputCount() != pair.getInput().size()) {
                log.error("LR and input size mismatch: LR Size = {}; Input Size = {}", lr.getInputCount(),
                        pair.getInput().size());
                return null;
            }
            return new Callable<MLData>() {
                @Override
                public MLData call() throws Exception {
                    return lr.compute(pair.getInput());
                }
            };
        } else if(model instanceof TreeModel) {
            final TreeModel tm = (TreeModel) model;
            if(tm.getInputCount() != pair.getInput().size()) {
                throw new RuntimeException("GBDT and input size mismatch: tm input Size = " + tm.getInputCount()
                        + "; data input Size = " + pair.getInput().size());
            }
            return new Callable<MLData>() {
                @Override
                public MLData call() throws Exception {
                    return tm.compute(pair.getInput());
                }
            };
        } else {
            throw new RuntimeException("unsupport models");
        }
    }

    /**
     * Converts one model's raw output into scores and appends them.
     *
     * @param model
     *            the model that produced {@code score}
     * @param score
     *            raw model output
     * @param scores
     *            receives the converted scores
     * @param rfTreeSizeList
     *            receives the tree count of non-classification, non-GBDT tree models
     * @throws RuntimeException
     *             for an unsupported model type
     */
    private void appendScores(BasicML model, MLData score, List<Double> scores, List<Integer> rfTreeSizeList) {
        if(model instanceof BasicNetwork) {
            if(modelConfig.isRegression()
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) {
                // Regression or one-vs-all classification: only the first output is used.
                scores.add(toScore(score.getData(0)));
            } else {
                for(double d: score.getData()) {
                    scores.add(toScore(d));
                }
            }
        } else if(model instanceof SVM || model instanceof LR) {
            scores.add(toScore(score.getData(0)));
        } else if(model instanceof TreeModel) {
            if(modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
                // Native multi-class tree output: every value is added, without scaling.
                for(double sc: score.getData()) {
                    scores.add(sc);
                }
            } else {
                // One-vs-all multiple classification or regression.
                scores.add(toScore(score.getData(0)));
            }
            TreeModel tm = (TreeModel) model;
            // Non-classification, non-GBDT tree model (RF regression): record its number of trees.
            if(!tm.isClassfication() && !tm.isGBDT()) {
                rfTreeSizeList.add(tm.getTrees().size());
            }
        } else {
            throw new RuntimeException("unsupport models");
        }
    }

    /**
     * Multiplies a raw model output by the current scale.
     */
    private double toScore(double d) {
        return d * scale;
    }

    /**
     * @return number of models held by this scorer
     */
    public int getModelCnt() {
        return ((models != null) ? this.models.size() : 0);
    }

    /**
     * @return factor applied to raw model outputs (default {@value #DEFAULT_SCORE_SCALE})
     */
    public int getScale() {
        return scale;
    }

    /**
     * Sets the factor applied to raw model outputs. Non-positive values are ignored.
     *
     * @param scale
     *            new scale, must be positive to take effect
     */
    public void setScale(int scale) {
        if(scale > 0) {
            this.scale = scale;
        }
    }
}