/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.processor; import java.io.BufferedReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Scanner; import ml.shifu.guagua.GuaguaConstants; import ml.shifu.guagua.hadoop.util.HDPUtils; import ml.shifu.guagua.mapreduce.GuaguaMapReduceClient; import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelVarSelectConf.PostCorrelationMetric; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.VariableSelector; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.nn.NNConstants; import ml.shifu.shifu.core.dvarsel.VarSelMaster; import ml.shifu.shifu.core.dvarsel.VarSelMasterResult; import ml.shifu.shifu.core.dvarsel.VarSelOutput; import ml.shifu.shifu.core.dvarsel.VarSelWorker; import ml.shifu.shifu.core.dvarsel.VarSelWorkerResult; import ml.shifu.shifu.core.dvarsel.wrapper.CandidateGenerator; import ml.shifu.shifu.core.dvarsel.wrapper.WrapperMasterConductor; import ml.shifu.shifu.core.dvarsel.wrapper.WrapperWorkerConductor; import ml.shifu.shifu.core.mr.input.CombineInputFormat; import ml.shifu.shifu.core.validator.ModelInspector.ModelStep; import ml.shifu.shifu.core.varselect.ColumnInfo; import ml.shifu.shifu.core.varselect.ColumnStatistics; import ml.shifu.shifu.core.varselect.VarSelectMapper; import ml.shifu.shifu.core.varselect.VarSelectReducer; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.fs.PathFinder; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; import ml.shifu.shifu.util.Environment; import org.apache.commons.collections.ListUtils; import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; import org.apache.commons.io.IOUtils; import org.apache.commons.jexl2.JexlException; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang3.tuple.MutablePair; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs; import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; import org.apache.hadoop.util.GenericOptionsParser; import org.apache.pig.impl.util.JarManager; import org.apache.zookeeper.ZooKeeper; import org.encog.ml.BasicML; import org.encog.ml.data.MLDataSet; import org.jboss.netty.bootstrap.ServerBootstrap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; /** * Variable selection processor, select the variable based on KS/IV value, or * * <p> * Selection variable based on the wrapper training processor. * * <p> * For sensitive variable selection, each time wrapperRatio percent of variables will be removed. If continue do * variable selection, continue to run varselect command. Current design will do variable selection continuously. */ public class VarSelectModelProcessor extends BasicModelProcessor implements Processor { private final static Logger log = LoggerFactory.getLogger(VarSelectModelProcessor.class); private boolean isToReset = false; private boolean isToList = false; public VarSelectModelProcessor() { // default constructor } public VarSelectModelProcessor(Map<String, Object> otherConfigs) { super.otherConfigs = otherConfigs; } public VarSelectModelProcessor(boolean isToReset) { this.isToReset = isToReset; } @SuppressWarnings("unused") private static final double BAD_IV_THRESHOLD = 0.02d; /** * SE stats mao for correlation variable selection,if not se, this field will be null. */ private Map<Integer, ColumnStatistics> seStatsMap; private void validateParameters() throws Exception { // String alg = super.getModelConfig().getTrain().getAlgorithm(); String filterBy = this.modelConfig.getVarSelectFilterBy(); if(filterBy.equalsIgnoreCase(Constants.FILTER_BY_SE) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_ST)) { validateSEParameters(); validateNormalize(); } } public void setToList(boolean toList) { isToList = toList; } /** * Run for the variable selection */ @Override public int run() throws Exception { log.info("Step Start: varselect"); long start = System.currentTimeMillis(); try { setUp(ModelStep.VARSELECT); validateParameters(); // reset all selections if user specify or select by absolute number if(isToReset) { log.info("Reset all selections data including type final select etc!"); resetAllFinalSelect(); } else if(isToList) { log.info("Below variables are selected - "); for(ColumnConfig columnConfig: this.columnConfigList) { if(columnConfig.isFinalSelect()) { log.info(columnConfig.getColumnName()); } } log.info("----- Done -----"); } else { // sync to make sure load from hdfs config is consistent with local configuration syncDataToHdfs(super.modelConfig.getDataSet().getSource()); if(modelConfig.isRegression()) { VariableSelector selector = new VariableSelector(this.modelConfig, this.columnConfigList); String filterBy = this.modelConfig.getVarSelectFilterBy(); if(filterBy.equalsIgnoreCase(Constants.FILTER_BY_KS) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_IV) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_PARETO) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_MIX)) { this.columnConfigList = selector.selectByFilter(); } else if(filterBy.equalsIgnoreCase(Constants.FILTER_BY_FI)) { if(!CommonUtils.isTreeModel(modelConfig.getAlgorithm())) { throw new IllegalArgumentException( "Filter by FI only works well in GBT/RF. Please check your modelconfig::train."); } selectByFeatureImportance(); } else if(filterBy.equalsIgnoreCase(Constants.FILTER_BY_SE) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_ST)) { if(!Constants.NN.equalsIgnoreCase(modelConfig.getAlgorithm()) && !Constants.LR.equalsIgnoreCase(modelConfig.getAlgorithm())) { throw new IllegalArgumentException( "Filter by SE/ST only works well in NN/LR. Please check your modelconfig::train."); } distributedSEWrapper(); } else if(filterBy.equalsIgnoreCase(Constants.FILTER_BY_VOTED)) { votedVariablesSelection(); } } else { // multiple classification, select all candidate at first, TODO add SE for multi-classification for(ColumnConfig config: this.columnConfigList) { if(CommonUtils.isGoodCandidate(modelConfig.isRegression(), config)) { config.setFinalSelect(true); } } } } // save column config to file and sync to clearUp(ModelStep.VARSELECT); } catch (Exception e) { log.error("Error:", e); return -1; } log.info("Step Finished: varselect with {} ms", (System.currentTimeMillis() - start)); return 0; } private void selectByFeatureImportance() throws Exception { List<BasicML> models = null; if(!super.modelConfig.getVarSelect().getFilterEnable()) { models = CommonUtils.loadBasicModels(this.modelConfig, this.columnConfigList, null); } if(models == null || models.size() < 1) { TrainModelProcessor trainModelProcessor = new TrainModelProcessor(); trainModelProcessor.setForVarSelect(true); trainModelProcessor.run(); models = CommonUtils.loadBasicModels(this.modelConfig, this.columnConfigList, null); } // compute feature importance and write to local file Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils .computeTreeModelFeatureImportance(models); CommonUtils.writeFeatureImportance(this.pathFinder.getLocalFeatureImportancePath(), featureImportances); if(super.modelConfig.getVarSelect().getFilterEnable()) { this.postProcessFIVarSelect(featureImportances); } } public void resetAllFinalSelect() throws IOException { log.info("!!! Reset all variables finalSelect = false"); for(ColumnConfig columnConfig: this.columnConfigList) { columnConfig.setFinalSelect(false); columnConfig.setColumnFlag(null); } saveColumnConfigList(); } private void validateNormalize() throws IOException { if(!ShifuFileUtils.isFileExists( new PathFinder(modelConfig).getNormalizedDataPath(this.modelConfig.getDataSet().getSource()), this.modelConfig.getDataSet().getSource())) { throw new IllegalStateException("Cannot find normalized data, please do 'Shifu normalize' firstly."); } } private void validateSEParameters() { if(!NNConstants.NN_ALG_NAME.equalsIgnoreCase(super.getModelConfig().getTrain().getAlgorithm()) && !"LR".equalsIgnoreCase(super.getModelConfig().getTrain().getAlgorithm())) { throw new IllegalArgumentException( "Currently we only support NN and LR distributed training to do wrapper by analyzing variable selection."); } if(super.getModelConfig().getDataSet().getSource() != SourceType.HDFS) { throw new IllegalArgumentException( "Currently we only support distributed wrapper by analyzing on HDFS source type."); } if(!super.getModelConfig().isMapReduceRunMode()) { throw new IllegalArgumentException( "Currently we only support distributed wrapper by on MAPRED or DIST mode."); } } private void votedVariablesSelection() throws ClassNotFoundException, IOException, InterruptedException { log.info("Start voted variables selection "); // sync data back to hdfs super.syncDataToHdfs(modelConfig.getDataSet().getSource()); SourceType sourceType = super.getModelConfig().getDataSet().getSource(); Configuration conf = new Configuration(); final List<String> args = new ArrayList<String>(); // prepare parameter prepareVarSelParams(args, sourceType); Path columnIdsPath = getVotedSelectionPath(sourceType); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, ml.shifu.shifu.util.Constants.VAR_SEL_COLUMN_IDS_OUPUT, columnIdsPath.toString())); long start = System.currentTimeMillis(); GuaguaMapReduceClient guaguaClient = new GuaguaMapReduceClient(); String hdpVersion = HDPUtils.getHdpVersionForHDP224(); if(StringUtils.isNotBlank(hdpVersion)) { // for hdp 2.2.4, hdp.version should be set and configuration files should be add to container class path conf.set("hdp.version", hdpVersion); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf); } guaguaClient.createJob(args.toArray(new String[0])).waitForCompletion(true); log.info("Voted variables selection finished in {}ms.", System.currentTimeMillis() - start); persistColumnIds(columnIdsPath); super.syncDataToHdfs(sourceType); } private int persistColumnIds(Path path) { try { List<Scanner> scanners = ShifuFileUtils.getDataScanners(path.toString(), modelConfig.getDataSet() .getSource()); List<Integer> ids = null; for(Scanner scanner: scanners) { while(scanner.hasNextLine()) { String[] raw = scanner.nextLine().trim().split("\\|"); @SuppressWarnings("unused") int idSize = Integer.parseInt(raw[0]); ids = CommonUtils.stringToIntegerList(raw[1]); } } // prevent multiply running setting for(ColumnConfig config: columnConfigList) { if(!config.isForceSelect()) { config.setFinalSelect(Boolean.FALSE); } } for(Integer id: ids) { this.columnConfigList.get(id).setFinalSelect(Boolean.TRUE); } super.saveColumnConfigList(); } catch (IOException e) { e.printStackTrace(); return -1; } catch (IllegalArgumentException e) { e.printStackTrace(); return -1; } return 0; } private Path getVotedSelectionPath(SourceType sourceType) { return ShifuFileUtils.getFileSystemBySourceType(sourceType).makeQualified( new Path(getPathFinder().getVarSelsPath(sourceType), "VarSels")); } @SuppressWarnings("unused") private void prepareVarSelParams(final List<String> args, final SourceType sourceType) { args.add("-libjars"); args.add(addRuntimeJars()); args.add("-i"); args.add(ShifuFileUtils.getFileSystemBySourceType(sourceType) .makeQualified(new Path(modelConfig.getDataSetRawPath())).toString()); String zkServers = Environment.getProperty(Environment.ZOO_KEEPER_SERVERS); if(StringUtils.isEmpty(zkServers)) { log.warn("No specified zookeeper settings from zookeeperServers in shifuConfig file, Guagua will set embeded zookeeper server in client process. For big data applications, specified zookeeper servers are strongly recommended."); } else { args.add("-z"); args.add(zkServers); } // setting the class args.add("-w"); args.add(VarSelWorker.class.getName()); args.add("-m"); args.add(VarSelMaster.class.getName()); args.add("-c"); // the reason to add 1 is that the first iteration in D-NN implementation is used for training preparation. // FIXME, how to set iteration number int forceSelectCount = 0; int candidateCount = 0; for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.isForceSelect()) { forceSelectCount++; } if(CommonUtils.isGoodCandidate(columnConfig)) { candidateCount++; } } int iterationCnt = (Integer) this.modelConfig.getVarSelect().getParams() .get(CandidateGenerator.POPULATION_MULTIPLY_CNT) + 1; args.add(Integer.toString(iterationCnt)); args.add("-mr"); args.add(VarSelMasterResult.class.getName()); args.add("-wr"); args.add(VarSelWorkerResult.class.getName()); // setting conductor args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, ml.shifu.shifu.util.Constants.VAR_SEL_MASTER_CONDUCTOR, Environment.getProperty(Environment.VAR_SEL_MASTER_CONDUCTOR, WrapperMasterConductor.class.getName()))); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, ml.shifu.shifu.util.Constants.VAR_SEL_WORKER_CONDUCTOR, Environment.getProperty(Environment.VAR_SEL_MASTER_CONDUCTOR, WrapperWorkerConductor.class.getName()))); // setting queue args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, NNConstants.MAPRED_JOB_QUEUE_NAME, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, ml.shifu.shifu.util.Constants.DEFAULT_JOB_QUEUE))); // MAPRED timeout args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, NNConstants.MAPRED_TASK_TIMEOUT, Environment .getInt(NNConstants.MAPRED_TASK_TIMEOUT, ml.shifu.shifu.util.Constants.DEFAULT_MAPRED_TIME_OUT))); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS, VarSelOutput.class.getName())); // setting model config column config args.add(String.format( CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_MODEL_CONFIG, ShifuFileUtils.getFileSystemBySourceType(sourceType).makeQualified( new Path(super.getPathFinder().getModelConfigPath(sourceType))))); args.add(String.format( CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_COLUMN_CONFIG, ShifuFileUtils.getFileSystemBySourceType(sourceType).makeQualified( new Path(super.getPathFinder().getColumnConfigPath(sourceType))))); // source type args.add(String .format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.MODELSET_SOURCE_TYPE, sourceType)); // computation time args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_COMPUTATION_TIME_THRESHOLD, 60 * 60 * 1000l)); setHeapSizeAndSplitSize(args); // one can set guagua conf in shifuconfig for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) { if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) { args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey().toString(), entry .getValue().toString())); } } } // GuaguaOptionsParser doesn't to support *.jar currently. private String addRuntimeJars() { List<String> jars = new ArrayList<String>(16); // jackson-databind-*.jar jars.add(JarManager.findContainingJar(ObjectMapper.class)); // jackson-core-*.jar jars.add(JarManager.findContainingJar(JsonParser.class)); // jackson-annotations-*.jar jars.add(JarManager.findContainingJar(JsonIgnore.class)); // commons-compress-*.jar jars.add(JarManager.findContainingJar(BZip2CompressorInputStream.class)); // commons-lang-*.jar jars.add(JarManager.findContainingJar(StringUtils.class)); // commons-collections-*.jar jars.add(JarManager.findContainingJar(ListUtils.class)); // common-io-*.jar jars.add(JarManager.findContainingJar(org.apache.commons.io.IOUtils.class)); // guava-*.jar jars.add(JarManager.findContainingJar(Splitter.class)); // encog-core-*.jar jars.add(JarManager.findContainingJar(MLDataSet.class)); // shifu-*.jar jars.add(JarManager.findContainingJar(getClass())); // guagua-core-*.jar jars.add(JarManager.findContainingJar(GuaguaConstants.class)); // guagua-mapreduce-*.jar jars.add(JarManager.findContainingJar(GuaguaMapReduceConstants.class)); // zookeeper-*.jar jars.add(JarManager.findContainingJar(ZooKeeper.class)); // netty-*.jar jars.add(JarManager.findContainingJar(ServerBootstrap.class)); jars.add(JarManager.findContainingJar(JexlException.class)); String hdpVersion = HDPUtils.getHdpVersionForHDP224(); if(StringUtils.isNotBlank(hdpVersion)) { // for hdp 2.2.4, hdp.version should be set and configuration files should be add to container class path jars.add(HDPUtils.findContainingFile("hdfs-site.xml")); jars.add(HDPUtils.findContainingFile("core-site.xml")); jars.add(HDPUtils.findContainingFile("mapred-site.xml")); jars.add(HDPUtils.findContainingFile("yarn-site.xml")); } return StringUtils.join(jars, NNConstants.LIB_JAR_SEPARATOR); } /** * Wrapper through {@link TrainModelProcessor} and a MapReduce job to analyze biggest sensitivity RMS. */ private void distributedSEWrapper() throws Exception { // 1. Train a model using current selected variables, if no variables selected, use all candidate variables. TrainModelProcessor trainModelProcessor = new TrainModelProcessor(); trainModelProcessor.setForVarSelect(true); trainModelProcessor.run(); // 2. Submit a MapReduce job to analyze sensitivity RMS. SourceType source = this.modelConfig.getDataSet().getSource(); Configuration conf = new Configuration(); // 2.1 prepare se job conf prepareSEJobConf(source, conf); // 2.2 get output path String varSelectMSEOutputPath = super.getPathFinder().getVarSelectMSEOutputPath(source); // 2.3 create se job Job job = createSEMapReduceJob(source, conf, varSelectMSEOutputPath); // 2.4 clean output firstly ShifuFileUtils.deleteFile(varSelectMSEOutputPath, source); // 2.5 submit job if(job.waitForCompletion(true)) { // 2.6 post process 4 var select if(super.modelConfig.getVarSelect().getFilterEnable()) { postProcess4SEVarSelect(source, varSelectMSEOutputPath); } else { log.info("Only print sensitivity analysis report."); log.info( "Sensitivity analysis report is in {}/{}-* file(s) with format 'column_index\tcolumn_name\tmean\trms\tvariance'.", varSelectMSEOutputPath, Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME); } } else { log.error("VarSelect SE hadoop job is failed, please re-try varselect step."); } } private Job createSEMapReduceJob(SourceType source, Configuration conf, String varSelectMSEOutputPath) throws IOException { @SuppressWarnings("deprecation") Job job = new Job(conf, "Shifu: Variable Selection Wrapper Job : " + this.modelConfig.getModelSetName()); job.setJarByClass(getClass()); boolean isSEVarSelMulti = Boolean.TRUE.toString().equalsIgnoreCase( Environment.getProperty(Constants.SHIFU_VARSEL_SE_MULTI, Constants.SHIFU_DEFAULT_VARSEL_SE_MULTI)); if(isSEVarSelMulti) { job.setMapperClass(MultithreadedMapper.class); MultithreadedMapper.setMapperClass(job, VarSelectMapper.class); int threads; try { threads = Integer.parseInt(Environment.getProperty(Constants.SHIFU_VARSEL_SE_MULTI_THREAD, Constants.SHIFU_DEFAULT_VARSEL_SE_MULTI_THREAD + "")); } catch (Exception e) { log.warn("'shifu.varsel.se.multi.thread' should be a int value, set default value: {}", Constants.SHIFU_DEFAULT_VARSEL_SE_MULTI_THREAD); threads = Constants.SHIFU_DEFAULT_VARSEL_SE_MULTI_THREAD; } conf.setInt("mapreduce.map.cpu.vcores", threads); MultithreadedMapper.setNumberOfThreads(job, threads); } else { job.setMapperClass(VarSelectMapper.class); } job.setMapOutputKeyClass(LongWritable.class); job.setMapOutputValueClass(ColumnInfo.class); job.setInputFormatClass(CombineInputFormat.class); FileInputFormat.setInputPaths( job, ShifuFileUtils.getFileSystemBySourceType(source).makeQualified( new Path(super.getPathFinder().getNormalizedDataPath()))); job.setReducerClass(VarSelectReducer.class); // Only one reducer, no need set combiner because of distinct keys in map outputs. job.setNumReduceTasks(1); job.setOutputKeyClass(Text.class); job.setOutputValueClass(Text.class); job.setOutputFormatClass(TextOutputFormat.class); FileOutputFormat.setOutputPath(job, new Path(varSelectMSEOutputPath)); MultipleOutputs.addNamedOutput(job, Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, TextOutputFormat.class, Text.class, Text.class); return job; } private void prepareSEJobConf(SourceType source, Configuration conf) throws IOException { // add jars to hadoop mapper and reducer new GenericOptionsParser(conf, new String[] { "-libjars", addRuntimeJars() }); conf.setBoolean(GuaguaMapReduceConstants.MAPRED_MAP_TASKS_SPECULATIVE_EXECUTION, true); conf.setBoolean(GuaguaMapReduceConstants.MAPRED_REDUCE_TASKS_SPECULATIVE_EXECUTION, true); conf.setBoolean(GuaguaMapReduceConstants.MAPREDUCE_MAP_SPECULATIVE, true); conf.setBoolean(GuaguaMapReduceConstants.MAPREDUCE_REDUCE_SPECULATIVE, true); conf.set( Constants.SHIFU_MODEL_CONFIG, ShifuFileUtils.getFileSystemBySourceType(source) .makeQualified(new Path(super.getPathFinder().getModelConfigPath(source))).toString()); conf.set( Constants.SHIFU_COLUMN_CONFIG, ShifuFileUtils.getFileSystemBySourceType(source) .makeQualified(new Path(super.getPathFinder().getColumnConfigPath(source))).toString()); conf.set(NNConstants.MAPRED_JOB_QUEUE_NAME, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, "default")); conf.set(Constants.SHIFU_MODELSET_SOURCE_TYPE, source.toString()); // set mapreduce.job.max.split.locations to 100 to suppress warnings conf.setInt(GuaguaMapReduceConstants.MAPREDUCE_JOB_MAX_SPLIT_LOCATIONS, 5000); // Tmp set to false because of some cluster by default use gzip while CombineInputFormat will split gzip file (a // bug) conf.setBoolean(CombineInputFormat.SHIFU_VS_SPLIT_COMBINABLE, false); conf.setBoolean("mapreduce.input.fileinputformat.input.dir.recursive", true); conf.set("mapred.reduce.slowstart.completed.maps", Environment.getProperty("mapred.reduce.slowstart.completed.maps", "0.9")); conf.set(Constants.SHIFU_VARSELECT_FILTEROUT_TYPE, modelConfig.getVarSelectFilterBy()); Float filterOutRatio = this.modelConfig.getVarSelect().getFilterOutRatio(); if(filterOutRatio == null) { log.warn("filterOutRatio in var select is not set. Using default value 0.05."); filterOutRatio = 0.05f; } if(filterOutRatio.compareTo(Float.valueOf(1.0f)) >= 0) { throw new IllegalArgumentException("WrapperRatio should be in (0, 1)."); } conf.setFloat(Constants.SHIFU_VARSELECT_FILTEROUT_RATIO, filterOutRatio); conf.setInt(Constants.SHIFU_VARSELECT_FILTER_NUM, this.modelConfig.getVarSelectFilterNum()); String hdpVersion = HDPUtils.getHdpVersionForHDP224(); if(StringUtils.isNotBlank(hdpVersion)) { // for hdp 2.2.4, hdp.version should be set and configuration files should be add to container class path conf.set("hdp.version", hdpVersion); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf); } // one can set guagua conf in shifuconfig for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) { if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) { conf.set(entry.getKey().toString(), entry.getValue().toString()); } } } private void postProcessFIVarSelect(Map<Integer, MutablePair<String, Double>> importances) { int selectCnt = 0; for(ColumnConfig config: super.columnConfigList) { // enable ForceSelect if(config.isForceSelect()) { config.setFinalSelect(true); selectCnt++; log.info("Variable {} is selected, since it is in ForceSelect list.", config.getColumnName()); } } VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList); int targetCnt = this.modelConfig.getVarSelectFilterNum(); List<Integer> candidateColumnIdList = new ArrayList<Integer>(); candidateColumnIdList.addAll(importances.keySet()); int i = 0; int candidateCount = candidateColumnIdList.size(); // try to select another (targetCnt - selectCnt) variables, but we need to exclude those // force-selected variables for(ColumnConfig columnConfig: this.columnConfigList) { if(columnConfig.isFinalSelect()) { columnConfig.setFinalSelect(false); } } while(selectCnt < targetCnt && i < targetCnt) { if(i >= candidateCount) { log.warn("Var select finish due to feature importance count {} is less than target var count {}", candidateCount, targetCnt); break; } Integer columnId = candidateColumnIdList.get(i++); ColumnConfig columnConfig = this.columnConfigList.get(columnId); if(!columnConfig.isForceSelect() && !columnConfig.isForceRemove()) { columnConfig.setFinalSelect(true); selectCnt++; log.info("Variable {} is selected.", columnConfig.getColumnName()); } } log.info("{} variables are selected.", selectCnt); } private void postProcess4SEVarSelect(SourceType source, String varSelectMSEOutputPath) throws IOException { String outputFilePattern = varSelectMSEOutputPath + Path.SEPARATOR + "part-r-*"; if(!ShifuFileUtils.isFileExists(outputFilePattern, source)) { throw new RuntimeException("Var select MSE stats output file not exist."); } int selectCnt = 0; for(ColumnConfig config: super.columnConfigList) { if(config.isFinalSelect()) { config.setFinalSelect(false); } // enable ForceSelect if(config.isForceSelect()) { config.setFinalSelect(true); selectCnt++; log.info("Variable {} is selected, since it is in ForceSelect list.", config.getColumnName()); } } List<Scanner> scanners = null; try { // here only works for 1 reducer FileStatus[] globStatus = ShifuFileUtils.getFileSystemBySourceType(source).globStatus( new Path(outputFilePattern)); if(globStatus == null || globStatus.length == 0) { throw new RuntimeException("Var select MSE stats output file not exist."); } scanners = ShifuFileUtils.getDataScanners(globStatus[0].getPath().toString(), source); String str = null; int targetCnt = 0; // total variable count that user want to select List<Integer> candidateColumnIdList = new ArrayList<Integer>(); Scanner scanner = scanners.get(0); while(scanner.hasNext()) { ++targetCnt; str = scanner.nextLine().trim(); candidateColumnIdList.add(Integer.parseInt(str)); } int i = 0; int candidateCount = candidateColumnIdList.size(); // try to select another (targetCnt - selectCnt) variables, but we need to exclude those // force-selected variables while(selectCnt < targetCnt && i < targetCnt) { if(i >= candidateCount) { log.warn("Var select finish due candidate column {} is less than target var count {}", candidateCount, targetCnt); break; } Integer columnId = candidateColumnIdList.get(i++); ColumnConfig columnConfig = this.columnConfigList.get(columnId); if(!columnConfig.isForceSelect() && !columnConfig.isForceRemove()) { columnConfig.setFinalSelect(true); selectCnt++; log.info("Variable {} is selected.", columnConfig.getColumnName()); } } log.info("{} variables are selected.", selectCnt); log.info( "Sensitivity analysis report is in {}/{}-* file(s) with format 'column_index\tcolumn_name\tmean\trms\tvariance'.", varSelectMSEOutputPath, Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME); this.seStatsMap = readSEValuesToMap(varSelectMSEOutputPath + Path.SEPARATOR + Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME + "-*", source); } finally { if(scanners != null) { for(Scanner scanner: scanners) { if(scanner != null) { scanner.close(); } } } } } private Map<Integer, ColumnStatistics> readSEValuesToMap(String seOutputFiles, SourceType source) throws IOException { // here only works for 1 reducer FileStatus[] globStatus = ShifuFileUtils.getFileSystemBySourceType(source).globStatus(new Path(seOutputFiles)); if(globStatus == null || globStatus.length == 0) { throw new RuntimeException("Var select MSE stats output file not exist."); } Map<Integer, ColumnStatistics> map = new HashMap<Integer, ColumnStatistics>(); List<Scanner> scanners = null; try { scanners = ShifuFileUtils.getDataScanners(globStatus[0].getPath().toString(), source); for(Scanner scanner: scanners) { String str = null; while(scanner.hasNext()) { str = scanner.nextLine().trim(); String[] splits = CommonUtils.split(str, "\t"); if(splits.length == 5) { map.put(Integer.parseInt(splits[0].trim()), new ColumnStatistics(Double.parseDouble(splits[2]), Double.parseDouble(splits[3]), Double.parseDouble(splits[4]))); } } } } finally { if(scanners != null) { for(Scanner scanner: scanners) { if(scanner != null) { scanner.close(); } } } } return null; } @Override protected void clearUp(ModelStep step) throws IOException { if(!isToReset) { autoVarSelCondition(); } try { this.saveColumnConfigList(); } catch (Exception e) { throw new ShifuException(ShifuErrorCode.ERROR_WRITE_COLCONFIG, e); } this.syncDataToHdfs(this.modelConfig.getDataSet().getSource()); } /** * To do some auto variable selection like remove ID-like variables, remove variable with high missing rate. * * @throws IOException * any IO exception */ private void autoVarSelCondition() throws IOException { // here we do loop again as it is not bad for variables less than 100,000 // 1. check missing rate for(ColumnConfig config: columnConfigList) { if(!config.isTarget() && !config.isMeta() && !config.isForceSelect() && config.isFinalSelect() && isHighMissingRateColumn(config)) { log.warn( "Column {} is with very high missing rate, set final select to false. If not, you can check it manually in ColumnConfig.json", config.getColumnName()); config.setFinalSelect(false); } } // 2. check correlation value: if(!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) { return; } varSelectByCorrelation(); // 3. check KS and IV min threshold value for(ColumnConfig config: columnConfigList) { if(!config.isTarget() && !config.isMeta() && !config.isForceSelect() && config.isFinalSelect()) { float minIvThreshold = super.modelConfig.getVarSelect().getMinIvThreshold() == null ? 0f : super.modelConfig.getVarSelect().getMinIvThreshold(); if(config.getIv() != null && config.getIv() < minIvThreshold) { log.warn( "IV of column {} is less than minimal IV threshold, set final select to false. If not, you can check it manually in ColumnConfig.json", config.getColumnName()); config.setFinalSelect(false); } float minKsThreshold = super.modelConfig.getVarSelect().getMinKsThreshold() == null ? 0f : super.modelConfig.getVarSelect().getMinKsThreshold(); if(config.getKs() != null && config.getKs() < minKsThreshold) { log.warn( "KS of column {} is less than minimal KS threshold, set final select to false. If not, you can check it manually in ColumnConfig.json", config.getColumnName()); config.setFinalSelect(false); } } } } private void varSelectByCorrelation() throws IOException { BufferedReader reader = ShifuFileUtils.getReader(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL); int lineNum = 0; try { String line = null; while((line = reader.readLine()) != null) { lineNum += 1; if(lineNum <= 2) { // skip first 2 lines which are indexes and names continue; } String[] columns = CommonUtils.split(line, ","); if(columns != null && columns.length == columnConfigList.size() + 2) { int columnIndex = Integer.parseInt(columns[0].trim()); ColumnConfig config = this.columnConfigList.get(columnIndex); // only check final-selected non-meta columns if(config.isFinalSelect() || config.isTarget()) { double[] corrArray = getCorrArray(columns); for(int i = 0; i < corrArray.length; i++) { // only check column larger than current column index and already final selected if(config.getColumnNum() < i && (columnConfigList.get(i).isTarget() || columnConfigList.get(i).isFinalSelect())) { // * 1.000005d is to avoid some value like 1.0000000002 in correlation value if(Math.abs(corrArray[i]) > (modelConfig.getVarSelect().getCorrelationThreshold() * 1.000005d)) { if(config.isTarget() && columnConfigList.get(i).isFinalSelect()) { log.warn( "{} and {} has high correlated value while {} is target, {} is set to NOT final-selected no matter it is force-selected or not.", columnIndex, i, i); columnConfigList.get(i).setFinalSelect(false); } else if(config.isFinalSelect() && columnConfigList.get(i).isTarget()) { log.warn( "{} and {} has high correlated value while {} is target, {} is set to NOT final-selected no matter it is force-selected or not.", columnIndex, i, columnIndex); config.setFinalSelect(false); } else { // both columns are not target and all final selected ColumnConfig dropConfig = null; PostCorrelationMetric corrMetric = modelConfig.getVarSelect() .getPostCorrelationMetric(); if(checkCorrelationMetric(config, columnConfigList.get(i), corrMetric)) { dropConfig = columnConfigList.get(i); } else { dropConfig = config; } // if SE filterBy and SE postcorrelationMetric, seStatsMap has stats, do // correlation comparison by SE RMS value if(this.modelConfig.getVarSelectFilterBy().equals("SE") && corrMetric == PostCorrelationMetric.SE && this.seStatsMap != null && this.seStatsMap.get(config.getColumnNum()) != null && this.seStatsMap.get(columnConfigList.get(i).getColumnNum()) != null) { log.warn( "Absolute correlation value {} in column pair ({}, {}) ({}, {}) are larger than correlationThreshold value {} set in VarSelect#correlationThreshold, column {} name {} with smaller SE RMS value will not be selected, set finalSelect to false.", Math.abs(corrArray[i]), config.getColumnNum(), i, config.getColumnName(), columnConfigList.get(i).getColumnName(), modelConfig.getVarSelect().getCorrelationThreshold(), dropConfig.getColumnNum(), dropConfig.getColumnName()); } else { log.info( "Absolute correlation value {} in column pair ({}, {}) ({}, {}) are larger than correlationThreshold value {} set in VarSelect#correlationThreshold, column {} name {} with smaller {} value will not be selected, set finalSelect to false.", Math.abs(corrArray[i]), config.getColumnNum(), i, config.getColumnName(), columnConfigList.get(i).getColumnName(), modelConfig.getVarSelect().getCorrelationThreshold(), dropConfig.getColumnNum(), dropConfig.getColumnName(), corrMetric); } // de-select column which is dropped in current logic dropConfig.setFinalSelect(false); } } } } } } } } finally { IOUtils.closeQuietly(reader); } } private boolean checkCorrelationMetric(ColumnConfig config1, ColumnConfig config2, PostCorrelationMetric metric) { if(metric == null) { return config1.getIv() > config2.getIv(); } switch(metric) { case KS: return config1.getKs() > config2.getKs(); case SE: if(this.modelConfig.getVarSelectFilterBy().equals("SE") && this.seStatsMap != null && this.seStatsMap.get(config1.getColumnNum()) != null && this.seStatsMap.get(config2.getColumnNum()) != null) { // if bigger SE rms, means it is much important column, smaller will be dropped return this.seStatsMap.get(config1.getColumnNum()).getRms() > this.seStatsMap.get( config2.getColumnNum()).getRms(); } else { // not valid se, take iv return config1.getIv() > config2.getIv(); } case IV: default: return config1.getIv() > config2.getIv(); } } private double[] getCorrArray(String[] columns) { double[] corr = new double[columns.length - 2]; for(int i = 2; i < corr.length; i++) { corr[i - 2] = Double.parseDouble(columns[i].trim()); } return corr; } /** * Check is missing rate is over threshold. */ private boolean isHighMissingRateColumn(ColumnConfig config) { Double missingPercentage = config.getMissingPercentage(); if(missingPercentage != null && missingPercentage >= modelConfig.getVarSelect().getMissingRateThreshold()) { return true; } return false; } /** * Check if column is ID-like. */ @SuppressWarnings("unused") private boolean isIDLikeVariable(ColumnConfig config) { Long distinctCount = config.getColumnStats().getDistinctCount(); Long totalCount = config.getColumnStats().getTotalCount(); if(totalCount != null && distinctCount != null && totalCount >= 10000 && distinctCount * 1.0 / totalCount >= 0.97d) { return true; } return false; } private void setHeapSizeAndSplitSize(final List<String> args) { // args.add(String.format(NNConstants.MAPREDUCE_PARAM_FORMAT, GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS, // "-Xmn128m -Xms1G -Xmx1G -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps")); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS, "-Xmn128m -Xms1G -Xmx1G")); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, Environment.getProperty(GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, "true"))); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, Environment.getProperty(GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, "268435456"))); } }