/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.validator;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import ml.shifu.shifu.container.meta.MetaFactory;
import ml.shifu.shifu.container.meta.ValidateResult;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelBasicConf.RunMode;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelNormalizeConf;
import ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType;
import ml.shifu.shifu.container.obj.ModelSourceDataConf;
import ml.shifu.shifu.container.obj.ModelStatsConf.BinningAlgorithm;
import ml.shifu.shifu.container.obj.ModelStatsConf.BinningMethod;
import ml.shifu.shifu.container.obj.ModelTrainConf;
import ml.shifu.shifu.container.obj.ModelTrainConf.MultipleClassification;
import ml.shifu.shifu.container.obj.ModelVarSelectConf;
import ml.shifu.shifu.container.obj.ModelVarSelectConf.PostCorrelationMetric;
import ml.shifu.shifu.container.obj.RawSourceData;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * ModelInspector class is to do Safety Testing for model.
 *
 * <p>
 * Safety Testing include: 1. validate the ModelConfig against its meta data
 * src/main/resources/store/ModelConfigMeta.json 2. check source data for training and evaluation 3. check the
 * prerequisite for each step
 */
public class ModelInspector {

    private static final Logger LOG = LoggerFactory.getLogger(ModelInspector.class);

    /** Pipeline steps that can be probed; each step has its own prerequisite checks. */
    public enum ModelStep {
        INIT, STATS, VARSELECT, NORMALIZE, TRAIN, POSTTRAIN, EVAL, EXPORT, COMBO
    }

    private static ModelInspector instance = new ModelInspector();

    // singleton class, avoid to create new instance
    private ModelInspector() {
    }

    /**
     * @return the inspector handler
     */
    public static ModelInspector getInspector() {
        return instance;
    }

    /**
     * Probe the status of model for each step.
     * It will check the setting in @ModelConfig to make sure all setting from user are correct.
     * After that it will do different checking for different steps.
     *
     * @param modelConfig
     *            - the model configuration that want to probe
     * @param modelStep
     *            - the steps
     * @return the result of probe
     *         if everything is OK, the status of ValidateResult is TRUE
     *         else the status of ValidateResult is FALSE, and the reasons will in the clauses of ValidateResult
     * @throws Exception
     *             any exception in validation
     */
    public ValidateResult probe(ModelConfig modelConfig, ModelStep modelStep) throws Exception {
        ValidateResult result = checkMeta(modelConfig);
        if(!result.getStatus()) {
            // meta validation failed: no point in running step-specific checks
            return result;
        }

        if(modelConfig.isClassification()) {
            if(modelConfig.getBasic().getRunMode() == RunMode.LOCAL
                    || modelConfig.getDataSet().getSource() == SourceType.LOCAL) {
                // warning only: status stays TRUE, just record the cause for the user
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.addCause("Multiple classification is only effective in MAPRED runmode and HDFS source type.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
        }

        if(modelConfig.getDataSet().getSource() == SourceType.LOCAL && modelConfig.isMapReduceRunMode()) {
            // this combination used to be rejected; the rejection is intentionally disabled now
            ValidateResult tmpResult = new ValidateResult(true);
            // tmpResult.setStatus(false);
            // tmpResult.getCauses().add(
            // "'LOCAL' data set (dataSet.source) cannot be run with 'mapred' run mode(basic.runMode)");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(ModelStep.INIT.equals(modelStep)) {
            result = ValidateResult.mergeResult(result, checkTrainData(modelConfig.getDataSet()));
            result = ValidateResult.mergeResult(result, checkVarSelect(modelConfig, modelConfig.getVarSelect()));
            if(result.getStatus()) {
                result = ValidateResult.mergeResult(result, checkColumnConf(modelConfig));
            }
        } else if(ModelStep.STATS.equals(modelStep)) {
            result = ValidateResult.mergeResult(result,
                    checkFile("ColumnConfig.json", SourceType.LOCAL, "ColumnConfig.json : "));
            result = ValidateResult.mergeResult(result, checkStatsConf(modelConfig));
        } else if(ModelStep.VARSELECT.equals(modelStep)) {
            result = ValidateResult.mergeResult(result, checkVarSelect(modelConfig, modelConfig.getVarSelect()));
            if(result.getStatus()) {
                // user may add configure file between steps
                // add validation to avoid user to make mistake
                result = ValidateResult.mergeResult(result, checkColumnConf(modelConfig));
            }
        } else if(ModelStep.NORMALIZE.equals(modelStep)) {
            result = ValidateResult.mergeResult(result, checkNormSetting(modelConfig, modelConfig.getNormalize()));
        } else if(ModelStep.TRAIN.equals(modelStep)) {
            result = ValidateResult.mergeResult(result, checkTrainSetting(modelConfig, modelConfig.getTrain()));
            if(modelConfig.isClassification()
                    && modelConfig.getTrain().getMultiClassifyMethod() == MultipleClassification.NATIVE) {
                if(!"nn".equalsIgnoreCase((modelConfig.getTrain().getAlgorithm()))
                        && !CommonConstants.RF_ALG_NAME.equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.addCause("Native multiple classification is only effective in neural network (nn) "
                            + "or random forest (rf) training method.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            if(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
                if(!CommonUtils.isTreeModel(modelConfig.getAlgorithm())
                        && !modelConfig.getAlgorithm().equalsIgnoreCase("nn")) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.addCause("OneVSAll multiple classification is only effective in gradient boosted trees "
                            + "(GBT) or random forest (RF) or Neural Network (NN) training method.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
        } else if(ModelStep.POSTTRAIN.equals(modelStep)) {
            // TODO
        } else if(ModelStep.EVAL.equals(modelStep)) {
            if(CollectionUtils.isNotEmpty(modelConfig.getEvals())) {
                for(EvalConfig evalConfig: modelConfig.getEvals()) {
                    result = ValidateResult.mergeResult(result,
                            checkRawData(evalConfig.getDataSet(), "Eval Set - " + evalConfig.getName() + ": "));
                    if(StringUtils.isNotBlank(evalConfig.getScoreMetaColumnNameFile())) {
                        result = ValidateResult.mergeResult(result,
                                checkFile(evalConfig.getScoreMetaColumnNameFile(), SourceType.LOCAL,
                                        "Eval Set - " + evalConfig.getName() + ": "));
                    }
                }
            }
        }

        return result;
    }

    /**
     * Check the settings in @ModelConfig against the constrains in @MetaFactory.
     *
     * @param modelConfig
     *            - model configuration to check
     * @return - the validation result
     * @throws Exception
     *             Exception when checking model configuration
     */
    public ValidateResult checkMeta(ModelConfig modelConfig) throws Exception {
        return MetaFactory.validate(modelConfig);
    }

    /**
     * Check the target column in @ModelConfig, it shouldn't be null or empty
     * - the target column shouldn't be meta column
     * - the target column shouldn't be force select column
     * - the target column shouldn't be force remove column
     * <p/>
     * - a column shouldn't exist in more than one list - metaColumns, forceSelectColumns, forceRemoveColumns
     *
     * @param modelConfig
     *            - model configuration to check
     * @return - the validation result
     * @throws IOException
     *             Exception when checking model configuration
     */
    private ValidateResult checkColumnConf(ModelConfig modelConfig) throws IOException {
        ValidateResult result = new ValidateResult(true);

        if(StringUtils.isBlank(modelConfig.getTargetColumnName())) {
            result.addCause("The target column name is null or empty.");
        } else {
            List<String> metaColumns = modelConfig.getMetaColumnNames();
            List<String> forceRemoveColumns = modelConfig.getListForceRemove();
            List<String> forceSelectColumns = modelConfig.getListForceSelect();

            if(CollectionUtils.isNotEmpty(metaColumns) && metaColumns.contains(modelConfig.getTargetColumnName())) {
                result.addCause("The target column name shouldn't be in the meta column conf.");
            }

            // force remove/select lists only matter when force-enable is on
            if(Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable())
                    && CollectionUtils.isNotEmpty(forceRemoveColumns)
                    && forceRemoveColumns.contains(modelConfig.getTargetColumnName())) {
                result.addCause("The target column name shouldn't be in the force remove conf.");
            }

            if(Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable())
                    && CollectionUtils.isNotEmpty(forceSelectColumns)
                    && forceSelectColumns.contains(modelConfig.getTargetColumnName())) {
                result.addCause("The target column name shouldn't be in the force select conf.");
            }

            if(Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable())) {
                // a column may appear in at most one of the three lists
                String conflictColumn = CommonUtils.containsAny(metaColumns, forceRemoveColumns);
                if(conflictColumn != null) {
                    result.addCause("Column - " + conflictColumn
                            + " exists both in meta column conf and force remove conf.");
                }

                conflictColumn = CommonUtils.containsAny(metaColumns, forceSelectColumns);
                if(conflictColumn != null) {
                    result.addCause("Column - " + conflictColumn
                            + " exists both in meta column conf and force select conf.");
                }

                conflictColumn = CommonUtils.containsAny(forceSelectColumns, forceRemoveColumns);
                if(conflictColumn != null) {
                    result.addCause("Column - " + conflictColumn
                            + " exists both in force select conf and force remove conf.");
                }
            }
        }

        return result;
    }

    /**
     * Check the stats configuration: binning method/algorithm restrictions for multiple classification,
     * and the valid range of stats#maxNumBin.
     *
     * @param modelConfig
     *            - model configuration to check
     * @return - the validation result
     * @throws IOException
     *             Exception when checking model configuration
     */
    private ValidateResult checkStatsConf(ModelConfig modelConfig) throws IOException {
        ValidateResult result = new ValidateResult(true);

        if(modelConfig.isClassification()
                && (modelConfig.getBinningMethod() == BinningMethod.EqualPositive
                        || modelConfig.getBinningMethod() == BinningMethod.EqualNegtive
                        || modelConfig.getBinningMethod() == BinningMethod.WeightEqualPositive
                        || modelConfig.getBinningMethod() == BinningMethod.WeightEqualNegative)) {
            ValidateResult tmpResult = new ValidateResult(false,
                    Arrays.asList("Multiple classification cannot leverage EqualNegtive and EqualPositive binning."));
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(modelConfig.isClassification() && modelConfig.getBinningAlgorithm() != BinningAlgorithm.SPDTI) {
            result = ValidateResult.mergeResult(result, new ValidateResult(false,
                    Arrays.asList("Only SPDTI binning algorithm are supported with multiple classification.")));
        }

        // maxNumBin should be less than Short.MAX_VALUE, larger maxNumBin need more computing and no meaningful.
        if(modelConfig.getStats().getMaxNumBin() > Short.MAX_VALUE || modelConfig.getStats().getMaxNumBin() <= 0) {
            result = ValidateResult.mergeResult(result,
                    new ValidateResult(false, Arrays.asList("stats#maxNumBin should be in (0, 32767].")));
        }

        return result;
    }

    /**
     * Check the prerequisite for variable selection
     * 1. if the force remove is not empty, check the conf file exists or not
     * 2. if the force select is not empty, check the conf file exists or not
     *
     * @param modelConfig
     *            - model configuration to check
     * @param varSelect
     *            - @ModelVarSelectConf settings for variable selection
     * @return - the result of validation
     * @throws IOException
     *             IOException may be thrown when checking file
     */
    private ValidateResult checkVarSelect(ModelConfig modelConfig, ModelVarSelectConf varSelect) throws IOException {
        ValidateResult result = new ValidateResult(true);

        if(Boolean.TRUE.equals(varSelect.getForceEnable())) {
            if(StringUtils.isNotBlank(varSelect.getForceRemoveColumnNameFile())) {
                result = ValidateResult.mergeResult(result,
                        checkFile(varSelect.getForceRemoveColumnNameFile(), SourceType.LOCAL,
                                "forceRemove columns configuration "));
            }

            if(StringUtils.isNotBlank(varSelect.getForceSelectColumnNameFile())) {
                result = ValidateResult.mergeResult(result,
                        checkFile(varSelect.getForceSelectColumnNameFile(), SourceType.LOCAL,
                                "forceSelect columns configuration"));
            }
        }

        // postCorrelationMetric SE only makes sense when filterBy is SE as well
        PostCorrelationMetric corrMetric = varSelect.getPostCorrelationMetric();
        if(!varSelect.getFilterBy().equals("SE") && corrMetric != null && corrMetric == PostCorrelationMetric.SE) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add(
                    "VarSelect#filterBy and VarSelect#postCorrelationMetric should be both set to SE.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        return result;
    }

    /**
     * Check the Data Set - to check the data exists or not
     * to check the header of data exists or not
     *
     * @param dataSet
     *            - @RawSourceData to check
     * @param prefix
     *            - the prefix to generate readable clauses
     * @return @ValidateResult
     * @throws IOException
     *             IOException may be thrown when checking file
     */
    private ValidateResult checkRawData(RawSourceData dataSet, String prefix) throws IOException {
        ValidateResult result = new ValidateResult(true);
        result = ValidateResult.mergeResult(result,
                checkFile(dataSet.getDataPath(), dataSet.getSource(), prefix + "data path "));
        if(!StringUtils.isBlank(dataSet.getHeaderPath())) {
            result = ValidateResult.mergeResult(result,
                    checkFile(dataSet.getHeaderPath(), dataSet.getSource(), prefix + "header path "));
        } else {
            // empty header path is allowed: schema will be detected from first line of input
            LOG.warn("Header file is set to empty, shifu will try to detect schema by first line of input and header "
                    + "delimiter.");
        }
        return result;
    }

    /**
     * Check the training data for model.
     * First of all, it checks the @RawSourceData.
     * Then, it checks conf file for categorical column exists or not, if the setting is not empty.
     * Then, it checks conf file for meta column exists or not, if the setting is not empty.
     *
     * @param dataSet
     *            - @ModelSourceDataConf to check
     * @return @ValidateResult
     * @throws IOException
     *             IOException may be thrown when checking file
     */
    private ValidateResult checkTrainData(ModelSourceDataConf dataSet) throws IOException {
        ValidateResult result = checkRawData(dataSet, "Train Set:");

        if(StringUtils.isNotBlank(dataSet.getCategoricalColumnNameFile())) {
            result = ValidateResult.mergeResult(result,
                    checkFile(dataSet.getCategoricalColumnNameFile(), SourceType.LOCAL,
                            "categorical columns configuration "));
        }

        if(StringUtils.isNotBlank(dataSet.getMetaColumnNameFile())) {
            result = ValidateResult.mergeResult(result,
                    checkFile(dataSet.getMetaColumnNameFile(), SourceType.LOCAL, "meta columns configuration "));
        }

        return result;
    }

    /**
     * Check the setting for model normalize.
     * It will make sure the following condition:
     *
     * <p>
     * <ul>
     * <li>stdDevCutOff &gt; 0</li>
     * <li>0 &lt; sampleRate &lt;= 1</li>
     * <li>sampleNegOnly is either true or false</li>
     * <li>normType contains valid value among [ZSCALE, WOE, WEIGHT_WOE, HYBRID, WEIGHT_HYBRID]</li>
     * </ul>
     * </p>
     *
     * @param modelConfig
     *            - model configuration to check
     * @param norm
     *            {@link ModelNormalizeConf} instance.
     * @return check result instance {@link ValidateResult}.
     */
    private ValidateResult checkNormSetting(ModelConfig modelConfig, ModelNormalizeConf norm) {
        ValidateResult result = new ValidateResult(true);

        if(norm.getStdDevCutOff() == null || norm.getStdDevCutOff() <= 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("stdDevCutOff should be positive value in normalize configuration");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(norm.getSampleRate() == null || norm.getSampleRate() <= 0 || norm.getSampleRate() > 1) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("sampleRate should be positive value in normalize configuration");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(norm.getSampleNegOnly() == null) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("sampleNegOnly should be true/false in normalize configuration");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(norm.getNormType() == null) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add(
                    "normType should be one of [ZSCALE, WOE, WEIGHT_WOE, HYBRID, WEIGHT_HYBRID] in normalize configuration");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        // multiple classification only supports z-score family norm types
        boolean isZScore = modelConfig.getNormalize().getNormType() == NormType.ZSCALE
                || modelConfig.getNormalize().getNormType() == NormType.ZSCORE
                || modelConfig.getNormalize().getNormType() == NormType.OLD_ZSCALE
                || modelConfig.getNormalize().getNormType() == NormType.OLD_ZSCORE;
        if(modelConfig.isClassification() && !isZScore) {
            ValidateResult tmpResult = new ValidateResult(false);
            tmpResult.getCauses().add("NormType 'ZSCALE|ZSCORE' is the only norm type for multiple classification.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        return result;
    }

    /**
     * Check the setting for model training.
     * It will make sure (num_of_layers &gt; 0
     * &amp;&amp; num_of_layers = hidden_nodes_size
     * &amp;&amp; num_of_layers = active_func_size)
     *
     * @param modelConfig
     *            - model configuration to check
     * @param train
     *            - @ModelTrainConf to check
     * @return @ValidateResult
     */
    @SuppressWarnings("unchecked")
    private ValidateResult checkTrainSetting(ModelConfig modelConfig, ModelTrainConf train) {
        ValidateResult result = new ValidateResult(true);

        // NOTE(review): condition allows baggingNum == 0 although the message says "greater than zero" -
        // confirm whether 0 is a legal bagging number before tightening the check.
        if(train.getBaggingNum() == null || train.getBaggingNum() < 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("Bagging number should be greater than zero in train configuration");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getNumKFold() != null && train.getNumKFold() > 20) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("numKFold should be in (0, 20] or <=0 (not dp k-crossValidation)");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getBaggingSampleRate() == null || train.getBaggingSampleRate().compareTo(Double.valueOf(0)) <= 0
                || train.getBaggingSampleRate().compareTo(Double.valueOf(1)) > 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("Bagging sample rate number should be in (0, 1].");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getValidSetRate() == null || train.getValidSetRate().compareTo(Double.valueOf(0)) < 0
                || train.getValidSetRate().compareTo(Double.valueOf(1)) >= 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("Validation set rate number should be in [0, 1).");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getNumTrainEpochs() == null || train.getNumTrainEpochs() <= 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("Epochs should be larger than 0.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getEpochsPerIteration() != null && train.getEpochsPerIteration() <= 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("'epochsPerIteration' should be larger than 0 if set.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getWorkerThreadCount() != null
                && (train.getWorkerThreadCount() <= 0 || train.getWorkerThreadCount() > 32)) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("'workerThreadCount' should be in (0, 32] if set.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(train.getConvergenceThreshold() != null && train.getConvergenceThreshold().compareTo(0.0) < 0) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("'threshold' should be larger than or equal to 0.0 if set.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(modelConfig.isClassification() && train.isOneVsAll() && !CommonUtils.isTreeModel(train.getAlgorithm())
                && !train.getAlgorithm().equalsIgnoreCase("nn")) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add(
                    "'one vs all' or 'one vs rest' is only enabled with 'RF' or 'GBT' or 'NN' algorithm");
            result = ValidateResult.mergeResult(result, tmpResult);
        }

        if(modelConfig.isClassification() && train.getMultiClassifyMethod() == MultipleClassification.NATIVE
                && train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
            Object impurity = train.getParams().get("Impurity");
            if(impurity != null && !"entropy".equalsIgnoreCase(impurity.toString())
                    && !"gini".equalsIgnoreCase(impurity.toString())) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add(
                        "Impurity should be in [entropy,gini] if native mutiple classification in RF.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
        }

        GridSearch gs = new GridSearch(train.getParams());
        // such parameter validation only in regression and not grid search mode
        if(modelConfig.isRegression() && !gs.hasHyperParam()) {
            if(train.getAlgorithm().equalsIgnoreCase("nn")) {
                Map<String, Object> params = train.getParams();

                int layerCnt = (Integer) params.get(CommonConstants.NUM_HIDDEN_LAYERS);
                if(layerCnt < 0) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("The number of hidden layers should be >= 0 in train configuration");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                List<Integer> hiddenNode = (List<Integer>) params.get(CommonConstants.NUM_HIDDEN_NODES);
                List<String> activateFucs = (List<String>) params.get(CommonConstants.ACTIVATION_FUNC);
                if(hiddenNode.size() != activateFucs.size() || layerCnt != activateFucs.size()) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            CommonConstants.NUM_HIDDEN_LAYERS + "/SIZE(" + CommonConstants.NUM_HIDDEN_NODES + ")"
                                    + "/SIZE(" + CommonConstants.ACTIVATION_FUNC + ")"
                                    + " should be equal in train configuration");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                // NOTE(review): throws NPE if LearningRate is absent from params - presumably it is mandatory
                // for NN training; confirm against the default ModelConfig template.
                Double learningRate = Double.valueOf(params.get(CommonConstants.LEARNING_RATE).toString());
                if(learningRate.compareTo(Double.valueOf(0)) <= 0) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Learning rate should be larger than 0.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object learningDecayO = params.get("LearningDecay");
                if(learningDecayO != null) {
                    Double learningDecay = Double.valueOf(learningDecayO.toString());
                    if((learningDecay.compareTo(Double.valueOf(0)) < 0)
                            || (learningDecay.compareTo(Double.valueOf(1)) >= 0)) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("Learning decay should be in [0, 1) if set.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object elmObject = params.get(DTrainUtils.IS_ELM);
                boolean isELM = elmObject == null ? false : "true".equalsIgnoreCase(elmObject.toString());
                if(isELM && layerCnt != 1) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            "If ELM(extreme learning machine), hidden layer should only be one layer.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
                if(dropoutObj != null) {
                    Double dropoutRate = Double.valueOf(dropoutObj.toString());
                    if(dropoutRate < 0d || dropoutRate >= 1d) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }
            }

            if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)
                    || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)
                    || train.getAlgorithm().equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
                Map<String, Object> params = train.getParams();
                Object fssObj = params.get("FeatureSubsetStrategy");
                if(fssObj == null) {
                    // mandatory for tree models only; NN may omit it
                    if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)
                            || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("FeatureSubsetStrategy is not set in RF/GBT algorithm.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    // the value is either a fraction in (0, 1] or one of the FeatureSubsetStrategy enum names
                    boolean isNumber = false;
                    double doubleFss = 0;
                    try {
                        doubleFss = Double.parseDouble(fssObj.toString());
                        isNumber = true;
                    } catch (Exception e) {
                        isNumber = false;
                    }

                    if(isNumber && (doubleFss <= 0d || doubleFss > 1d)) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("FeatureSubsetStrategy if double should be in (0, 1]");
                        // FIX: result was previously never merged, so invalid values passed validation
                        result = ValidateResult.mergeResult(result, tmpResult);
                    } else {
                        boolean fssInEnum = false;
                        for(FeatureSubsetStrategy fss: FeatureSubsetStrategy.values()) {
                            if(fss.toString().equalsIgnoreCase(fssObj.toString())) {
                                fssInEnum = true;
                                break;
                            }
                        }
                        if(!fssInEnum) {
                            ValidateResult tmpResult = new ValidateResult(true);
                            tmpResult.setStatus(false);
                            tmpResult.getCauses().add(
                                    "FeatureSubsetStrategy if string should be in ['ALL', 'HALF', 'ONETHIRD' , 'TWOTHIRDS' , 'AUTO' , 'SQRT' , 'LOG2']");
                            // FIX: result was previously never merged, so invalid values passed validation
                            result = ValidateResult.mergeResult(result, tmpResult);
                        }
                    }
                }
            }

            if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)
                    || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
                Map<String, Object> params = train.getParams();

                if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                    Object loss = params.get("Loss");
                    if(loss != null && !"log".equalsIgnoreCase(loss.toString())
                            && !"squared".equalsIgnoreCase(loss.toString())
                            && !"halfgradsquared".equalsIgnoreCase(loss.toString())
                            && !"absolute".equalsIgnoreCase(loss.toString())) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        // message kept in sync with the accepted values above
                        tmpResult.getCauses().add("Loss should be in [log,squared,halfgradsquared,absolute].");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                    if(loss == null) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add(
                                "'Loss' parameter isn't be set in train#parameters in GBT training.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object maxDepthObj = params.get("MaxDepth");
                if(maxDepthObj != null) {
                    int maxDepth = Integer.valueOf(maxDepthObj.toString());
                    if(maxDepth <= 0 || maxDepth > 20) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("MaxDepth should in [1, 20].");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object vtObj = params.get("ValidationTolerance");
                if(vtObj != null) {
                    double validationTolerance = Double.valueOf(vtObj.toString());
                    if(validationTolerance < 0d || validationTolerance >= 1d) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("ValidationTolerance should in [0, 1).");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object maxLeavesObj = params.get("MaxLeaves");
                if(maxLeavesObj != null) {
                    int maxLeaves = Integer.valueOf(maxLeavesObj.toString());
                    if(maxLeaves <= 0) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("MaxLeaves should in [1, Integer.MAX_VALUE].");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                // trees need at least one stopping criterion
                if(maxDepthObj == null && maxLeavesObj == null) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            "'MaxDepth' or 'MaxLeaves' parameters at least one of both should be set in train#parameters in GBT training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object maxStatsMemoryMBObj = params.get("MaxStatsMemoryMB");
                if(maxStatsMemoryMBObj != null) {
                    int maxStatsMemoryMB = Integer.valueOf(maxStatsMemoryMBObj.toString());
                    if(maxStatsMemoryMB <= 0) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("MaxStatsMemoryMB should > 0.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
                if(dropoutObj != null) {
                    Double dropoutRate = Double.valueOf(dropoutObj.toString());
                    if(dropoutRate < 0d || dropoutRate >= 1d) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                    Object learningRateObj = params.get(CommonConstants.LEARNING_RATE);
                    if(learningRateObj != null) {
                        Double learningRate = Double.valueOf(learningRateObj.toString());
                        if(learningRate.compareTo(Double.valueOf(0)) <= 0) {
                            ValidateResult tmpResult = new ValidateResult(true);
                            tmpResult.setStatus(false);
                            tmpResult.getCauses().add("Learning rate should be larger than 0.");
                            result = ValidateResult.mergeResult(result, tmpResult);
                        }
                    } else {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add(
                                "'LearningRate' parameter isn't be set in train#parameters in GBT training.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }

                Object minInstancesPerNodeObj = params.get("MinInstancesPerNode");
                if(minInstancesPerNodeObj != null) {
                    int minInstancesPerNode = Integer.valueOf(minInstancesPerNodeObj.toString());
                    if(minInstancesPerNode <= 0) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("MinInstancesPerNode should > 0.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            "'MinInstancesPerNode' parameter isn't be set in train#parameters in GBT/RF training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object treeNumObj = params.get("TreeNum");
                if(treeNumObj != null) {
                    int treeNum = Integer.valueOf(treeNumObj.toString());
                    if(treeNum <= 0 || treeNum > 10000) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("TreeNum should be in [1, 10000].");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            "'TreeNum' parameter isn't be set in train#parameters in GBT/RF training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object minInfoGainObj = params.get("MinInfoGain");
                if(minInfoGainObj != null) {
                    Double minInfoGain = Double.valueOf(minInfoGainObj.toString());
                    if(minInfoGain.compareTo(Double.valueOf(0)) < 0) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("MinInfoGain should be >= 0.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add(
                            "'MinInfoGain' parameter isn't be set in train#parameters in GBT/RF training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }

                Object impurityObj = params.get("Impurity");
                if(impurityObj == null) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Impurity is not set in RF/GBT algorithm.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                } else {
                    if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                        if(!"variance".equalsIgnoreCase(impurityObj.toString())
                                && !"friedmanmse".equalsIgnoreCase(impurityObj.toString())) {
                            ValidateResult tmpResult = new ValidateResult(true);
                            tmpResult.setStatus(false);
                            // message kept in sync with the accepted values above
                            tmpResult.getCauses().add("GBDT supports 'variance|friedmanmse' impurity types.");
                            result = ValidateResult.mergeResult(result, tmpResult);
                        }
                    }
                    if(train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
                        if(!"friedmanmse".equalsIgnoreCase(impurityObj.toString())
                                && !"entropy".equalsIgnoreCase(impurityObj.toString())
                                && !"variance".equalsIgnoreCase(impurityObj.toString())
                                && !"gini".equalsIgnoreCase(impurityObj.toString())) {
                            ValidateResult tmpResult = new ValidateResult(true);
                            tmpResult.setStatus(false);
                            // message kept in sync with the accepted values above
                            tmpResult.getCauses().add(
                                    "RF supports 'variance|friedmanmse|entropy|gini' impurity types.");
                            result = ValidateResult.mergeResult(result, tmpResult);
                        }
                    }
                }
            }
        }

        return result;
    }

    /**
     * check the file exists or not
     *
     * @param dataPath
     *            - the path of data
     * @param sourceType
     *            - the source type of data [local/hdfs/s3]
     * @param prefix
     *            - the prefix to generate readable clauses
     * @return @ValidateResult
     * @throws IOException
     *             IOException may be thrown when checking file existence
     */
    private ValidateResult checkFile(String dataPath, SourceType sourceType, String prefix) throws IOException {
        ValidateResult result = new ValidateResult(true);

        if(StringUtils.isBlank(dataPath)) {
            result.addCause(prefix + "is null or empty - " + dataPath);
        } else if(dataPath.trim().contains("~")) {
            // '~' expansion is shell-specific and not resolved by the file system APIs
            result.addCause(prefix + "contains ~, which is not allowed - " + dataPath);
        } else if(!ShifuFileUtils.isFileExists(dataPath, sourceType)) {
            result.addCause(prefix + "doesn't exist - " + dataPath);
        }

        return result;
    }
}