/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.common;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ml.shifu.shifu.container.meta.ValidateResult;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.validator.ModelInspector;
import ml.shifu.shifu.core.validator.ModelInspector.ModelStep;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.fs.PathFinder;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Abstract {@link Step} for all Shifu stages. A {@link Step} performs basic loading and validation of the
 * ModelConfig and ColumnConfig in its constructor; concrete stages implement {@link #process()}.
 *
 * @param <STEP_RESULT>
 *            type of the result returned by {@link #process()}
 * @author Zhang David (pengzhang@paypal.com)
 */
public abstract class Step<STEP_RESULT> {

    private final static Logger LOG = LoggerFactory.getLogger(Step.class);

    /** Model configuration, validated in the constructor. */
    protected final ModelConfig modelConfig;

    /** Column configuration list; checked in the constructor for every step other than {@code INIT}. */
    protected List<ColumnConfig> columnConfigList;

    /** Additional configuration, also handed to {@link PathFinder}. */
    protected final Map<String, Object> otherConfigs;

    /** The Shifu step this instance represents. */
    protected final ModelStep step;

    /** Path helper built from {@link #modelConfig} and {@link #otherConfigs}. */
    protected final PathFinder pathFinder;

    /**
     * Creates a step after validating its configuration.
     *
     * @param step
     *            the Shifu step being run
     * @param modelConfig
     *            the model config; validated for every step (and may get default algorithm params installed)
     * @param columnConfigList
     *            the column config list; validated for every step except {@code INIT}, a {@code null} list skips
     *            the check
     * @param otherConfigs
     *            additional configuration, passed through to {@link PathFinder}
     * @throws RuntimeException
     *             wrapping any exception thrown by {@link #validateModelConfig(ModelConfig, ModelStep)}
     * @throws IllegalArgumentException
     *             if the target column, or a non-blank weight column, is absent from {@code columnConfigList}
     */
    public Step(ModelStep step, ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            Map<String, Object> otherConfigs) {
        // 1. validate model config.
        // NOTE: validateModelConfig is protected and overridable; an override runs here, before any field of this
        // object (or of a subclass) has been assigned.
        try {
            validateModelConfig(modelConfig, step);
        } catch (Exception e) {
            // Every failure, checked or unchecked, surfaces to callers as a RuntimeException with the original cause.
            throw new RuntimeException(e);
        }

        // 2. validate column config (skipped for INIT)
        switch(step) {
            case INIT:
                break;
            default:
                validateColumnConfig(modelConfig, columnConfigList);
                break;
        }

        this.step = step;
        this.modelConfig = modelConfig;
        this.columnConfigList = columnConfigList;
        this.otherConfigs = otherConfigs;
        this.pathFinder = new PathFinder(modelConfig, otherConfigs);
    }

    /**
     * Runs this step.
     *
     * @return the step result
     * @throws IOException
     *             on I/O failure while processing
     */
    public abstract STEP_RESULT process() throws IOException;

    /**
     * @return the modelConfig
     */
    public ModelConfig getModelConfig() {
        return modelConfig;
    }

    /**
     * @return the columnConfigList
     */
    public List<ColumnConfig> getColumnConfigList() {
        return columnConfigList;
    }

    /**
     * @return the otherConfigs
     */
    public Map<String, Object> getOtherConfigs() {
        return otherConfigs;
    }

    /**
     * Validate the modelconfig if it's well written, then install default algorithm parameters when the
     * algorithm's key parameter is missing.
     *
     * @param modelConfig
     *            the model config; {@code null} is reported as "not loaded" and fails validation
     * @param step
     *            step in Shifu, forwarded to {@link ModelInspector}
     * @throws Exception
     *             any exception in validation; this implementation throws a {@link ShifuException} with
     *             {@code ERROR_MODELCONFIG_NOT_VALIDATION} when inspection fails, or with
     *             {@code ERROR_UNSUPPORT_ALG} when the algorithm is not supported
     */
    protected void validateModelConfig(ModelConfig modelConfig, ModelStep step) throws Exception {
        ValidateResult result = new ValidateResult(false);
        if(modelConfig == null) {
            result.getCauses().add("The ModelConfig is not loaded!");
        } else {
            result = ModelInspector.getInspector().probe(modelConfig, step);
        }

        if(!result.getStatus()) {
            LOG.error("ModelConfig Validation - Fail! See below:");
            for(String cause: result.getCauses()) {
                LOG.error("\t!!! {}", cause);
            }
            throw new ShifuException(ShifuErrorCode.ERROR_MODELCONFIG_NOT_VALIDATION);
        }
        LOG.info("ModelConfig Validation - OK");

        checkAlgParameter(modelConfig);
    }

    /**
     * Installs default training parameters for the configured algorithm when its key parameter is absent.
     * <p>
     * When the key parameter is missing, or no params are configured at all, the whole params map is REPLACED
     * with the defaults below via {@code modelConfig.setParams}; existing entries are not merged in.
     *
     * @param modelConfig
     *            a non-null, already inspected model config; may be mutated
     * @throws ShifuException
     *             with {@code ERROR_UNSUPPORT_ALG} if the algorithm is {@code null} or not one of
     *             LR, NN, SVM, DT, RF, GBT (case-insensitive)
     */
    private void checkAlgParameter(ModelConfig modelConfig) {
        String alg = modelConfig.getAlgorithm();
        Map<String, Object> params = modelConfig.getParams();
        LOG.info("Check algorithm parameter");

        // Constant-first comparisons: a null algorithm falls through to the "unsupported" branch instead of an NPE.
        if("LR".equalsIgnoreCase(alg)) {
            if(isMissing(params, "LearningRate")) {
                Map<String, Object> defaults = new LinkedHashMap<String, Object>();
                defaults.put("LearningRate", 0.1);
                modelConfig.setParams(defaults);
            }
        } else if("NN".equalsIgnoreCase(alg)) {
            if(isMissing(params, "Propagation")) {
                Map<String, Object> defaults = new LinkedHashMap<String, Object>();
                defaults.put("Propagation", "Q");
                defaults.put("LearningRate", 0.1);
                defaults.put("NumHiddenLayers", 2);
                List<Integer> nodes = new ArrayList<Integer>();
                nodes.add(20);
                nodes.add(10);
                defaults.put("NumHiddenNodes", nodes);
                List<String> func = new ArrayList<String>();
                func.add("tanh");
                func.add("tanh");
                defaults.put("ActivationFunc", func);
                modelConfig.setParams(defaults);
            }
        } else if("SVM".equalsIgnoreCase(alg)) {
            if(isMissing(params, "Kernel")) {
                Map<String, Object> defaults = new LinkedHashMap<String, Object>();
                defaults.put("Kernel", "linear");
                defaults.put("Gamma", 1.);
                defaults.put("Const", 1.);
                modelConfig.setParams(defaults);
            }
        } else if("DT".equalsIgnoreCase(alg)) {
            // do nothing
        } else if("RF".equalsIgnoreCase(alg)) {
            if(isMissing(params, "FeatureSubsetStrategy")) {
                modelConfig.setParams(defaultTreeParams());
            }
        } else if("GBT".equalsIgnoreCase(alg)) {
            if(isMissing(params, "FeatureSubsetStrategy")) {
                Map<String, Object> defaults = defaultTreeParams();
                defaults.put("Loss", "squared");
                modelConfig.setParams(defaults);
            }
        } else {
            throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
        }
    }

    /** Returns true when {@code params} is null or lacks {@code key}. */
    private static boolean isMissing(Map<String, Object> params, String key) {
        return params == null || !params.containsKey(key);
    }

    /** Builds a fresh, insertion-ordered map of the tree defaults shared by RF and GBT. */
    private static Map<String, Object> defaultTreeParams() {
        Map<String, Object> defaults = new LinkedHashMap<String, Object>();
        defaults.put("FeatureSubsetStrategy", "all");
        defaults.put("MaxDepth", 10);
        defaults.put("MaxStatsMemoryMB", 256);
        defaults.put("Impurity", "entropy");
        return defaults;
    }

    /**
     * Checks that the target column, and the weight column when one is configured, exist in the column config.
     * Duplicate column names only produce a warning here; this method does not rename them.
     *
     * @param modelConfig
     *            the model config supplying target/weight column names
     * @param columnConfigList
     *            the column config list; {@code null} skips validation
     * @throws IllegalArgumentException
     *             if the target column, or a non-blank weight column, is not found
     */
    private void validateColumnConfig(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) {
        if(columnConfigList == null) {
            return;
        }

        Set<String> names = new HashSet<String>();
        for(ColumnConfig config: columnConfigList) {
            String name = config.getColumnName();
            // Set.add returns false when the name was already present.
            if(!names.add(name)) {
                LOG.warn("Duplicated {} in ColumnConfig.json file, later one will be append index to make it unique.",
                        name);
            }
        }

        if(!names.contains(modelConfig.getTargetColumnName())) {
            throw new IllegalArgumentException("target column " + modelConfig.getTargetColumnName()
                    + " does not exist.");
        }

        if(StringUtils.isNotBlank(modelConfig.getWeightColumnName())
                && !names.contains(modelConfig.getWeightColumnName())) {
            throw new IllegalArgumentException("weight column " + modelConfig.getWeightColumnName()
                    + " does not exist.");
        }
    }

    /**
     * @return the pathFinder
     */
    public PathFinder getPathFinder() {
        return pathFinder;
    }

    /**
     * @return the step
     */
    public ModelStep getStep() {
        return step;
    }
}