/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.processor;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.Thread.UncaughtExceptionHandler;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;

import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.hadoop.util.HDPUtils;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceClient;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants;
import ml.shifu.shifu.actor.AkkaSystemExecutor;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelBasicConf.RunMode;
import ml.shifu.shifu.container.obj.ModelTrainConf.MultipleClassification;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.AbstractTrainer;
import ml.shifu.shifu.core.TreeModel;
import ml.shifu.shifu.core.alg.LogisticRegressionTrainer;
import ml.shifu.shifu.core.alg.NNTrainer;
import ml.shifu.shifu.core.alg.SVMTrainer;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.core.dtrain.dt.DTMaster;
import ml.shifu.shifu.core.dtrain.dt.DTMasterParams;
import ml.shifu.shifu.core.dtrain.dt.DTOutput;
import ml.shifu.shifu.core.dtrain.dt.DTWorker;
import ml.shifu.shifu.core.dtrain.dt.DTWorkerParams;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionContants;
import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionMaster;
import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionOutput;
import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionParams;
import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionWorker;
import ml.shifu.shifu.core.dtrain.nn.ActivationReLU;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;
import ml.shifu.shifu.core.dtrain.nn.NNMaster;
import ml.shifu.shifu.core.dtrain.nn.NNOutput;
import ml.shifu.shifu.core.dtrain.nn.NNParams;
import ml.shifu.shifu.core.dtrain.nn.NNParquetWorker;
import ml.shifu.shifu.core.dtrain.nn.NNWorker;
import ml.shifu.shifu.core.validator.ModelInspector.ModelStep;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.guagua.GuaguaParquetMapReduceClient;
import ml.shifu.shifu.guagua.ShifuInputFormat;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import ml.shifu.shifu.util.Environment;
import ml.shifu.shifu.util.HDFSUtils;
import ml.shifu.shifu.util.JSONUtils;

import org.antlr.runtime.RecognitionException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.ListUtils;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.pig.LoadPushDown.RequiredField;
import org.apache.pig.LoadPushDown.RequiredFieldList;
import org.apache.pig.data.DataType;
import org.apache.pig.impl.PigContext;
import org.apache.pig.impl.util.JarManager;
import org.apache.pig.impl.util.ObjectSerializer;
import org.apache.zookeeper.ZooKeeper;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLOG;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSIN;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.BasicML;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.joda.time.ReadableInstant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xerial.snappy.Snappy;

import parquet.ParquetRuntimeException;
import parquet.column.ParquetProperties;
import parquet.column.values.bitpacking.Packer;
import parquet.encoding.Generator;
import parquet.format.PageType;
import parquet.hadoop.ParquetRecordReader;
import parquet.org.codehaus.jackson.Base64Variant;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;

/**
 * Train processor to produce models based on the normalized dataset.
 */
public class TrainModelProcessor extends BasicModelProcessor implements Processor {

    private final static Logger LOG = LoggerFactory.getLogger(TrainModelProcessor.class);

    private static final int VAR_SELECT_TRAINING_DECAY_EPOCHES_THRESHOLD = 400;

    public static final String SHIFU_DEFAULT_DTRAIN_PARALLEL = "true";

    private boolean isDryTrain, isDebug;

    private List<AbstractTrainer> trainers;

    private static final String LOGS = "./logs";

    /**
     * If for variable selection, only use bagging number 1 to train a single model.
     */
    private boolean isForVarSelect;

    private boolean isToShuffle = false;

    /**
     * Random generator used to sample features per each iteration.
     */
    private Random featureSamplingRandom = new Random();

    public TrainModelProcessor() {
    }

    public TrainModelProcessor(Map<String, Object> otherConfigs) {
        super.otherConfigs = otherConfigs;
    }

    /**
     * Constructor
     *
     * @param isDryTrain
     *            dry-run flag; if true, the training flow runs without producing real model outputs
     * @param isDebug
     *            debug flag; if true, shifu will create a log file to record each training status
     */
    public TrainModelProcessor(boolean isDryTrain, boolean isDebug) {
        super();
        this.isDebug = isDebug;
        this.isDryTrain = isDryTrain;
        trainers = new ArrayList<AbstractTrainer>();
    }

    /**
     * Training process entry point.
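     * <p>
     * Dispatches to distributed d-training for the MAPRED/DIST run modes or to local Akka-based training otherwise,
     * and syncs data and models back to HDFS once training is done.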
     */
    @Override
    public int run() throws Exception {
        int status = 0;
        if(!this.isForVarSelect()) {
            LOG.info("Step Start: train");
        }
        long start = System.currentTimeMillis();
        try {
            setUp(ModelStep.TRAIN);
            if(isDebug) {
                File file = new File(LOGS);
                if(!file.exists() && !file.mkdir()) {
                    throw new RuntimeException("Failed to create logs directory.");
                }
            }
            RunMode runMode = super.modelConfig.getBasic().getRunMode();
            switch(runMode) {
                case DIST:
                case MAPRED:
                    validateDistributedTrain();
                    syncDataToHdfs(super.modelConfig.getDataSet().getSource());
                    checkAndCleanDataForTreeModels(this.isToShuffle);
                    status = runDistributedTrain();
                    break;
                case LOCAL:
                default:
                    runAkkaTrain(isForVarSelect ? 1 : modelConfig.getBaggingNum());
                    break;
            }
            syncDataToHdfs(modelConfig.getDataSet().getSource());
            clearUp(ModelStep.TRAIN);
        } catch (Exception e) {
            LOG.error("Error:", e);
            return 1;
        }
        if(!this.isForVarSelect()) {
            LOG.info("Step Finished: train with {} ms", (System.currentTimeMillis() - start));
        }
        return status;
    }

    /**
     * Run the training process with a number of bags.
     *
     * @param numBags
     *            number of bags; it decides how many trainers will be started
     * @throws IOException
     */
    private void runAkkaTrain(int numBags) throws IOException {
        File models = new File("models");
        FileUtils.deleteDirectory(models);
        FileUtils.forceMkdir(models);
        trainers.clear();
        for(int i = 0; i < numBags; i++) {
            AbstractTrainer trainer;
            if(modelConfig.getAlgorithm().equalsIgnoreCase("NN")) {
                trainer = new NNTrainer(modelConfig, i, isDryTrain);
            } else if(modelConfig.getAlgorithm().equalsIgnoreCase("SVM")) {
                trainer = new SVMTrainer(this.modelConfig, i, isDryTrain);
            } else if(modelConfig.getAlgorithm().equalsIgnoreCase("LR")) {
                trainer = new LogisticRegressionTrainer(this.modelConfig, i, isDryTrain);
            } else {
                throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
            }
            trainers.add(trainer);
        }
        List<Scanner> scanners = null;
        if(modelConfig.getAlgorithm().equalsIgnoreCase("DT")) {
            LOG.info("Raw Data: " + modelConfig.getDataSetRawPath());
            try {
                scanners = ShifuFileUtils.getDataScanners(modelConfig.getDataSetRawPath(), modelConfig.getDataSet()
                        .getSource());
            } catch (IOException e) {
                throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, modelConfig.getDataSetRawPath());
            }
            if(CollectionUtils.isNotEmpty(scanners)) {
                AkkaSystemExecutor.getExecutor().submitDecisionTreeTrainJob(modelConfig, columnConfigList, scanners,
                        trainers);
            }
        } else {
            LOG.info("Normalized Data: " + pathFinder.getNormalizedDataPath());
            try {
                scanners = ShifuFileUtils.getDataScanners(pathFinder.getNormalizedDataPath(), modelConfig.getDataSet()
                        .getSource());
            } catch (IOException e) {
                throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, pathFinder.getNormalizedDataPath());
            }
            if(CollectionUtils.isNotEmpty(scanners)) {
                AkkaSystemExecutor.getExecutor().submitModelTrainJob(modelConfig, columnConfigList, scanners, trainers);
            }
        }
        // release
        closeScanners(scanners);
    }

    /**
     * Get the trainer list
     *
     * @return the trainer list
     */
    public List<AbstractTrainer> getTrainers() {
        return trainers;
    }

    /**
     * Get the trainer
     *
     * @param index
     *            the index of trainer
     * @return the trainer
     */
    public AbstractTrainer getTrainer(int index) {
        if(index >= trainers.size())
            throw new RuntimeException("Insufficient trainers: index out of range");
        return trainers.get(index);
    }

    private void validateDistributedTrain() throws IOException {
        String alg = super.getModelConfig().getTrain().getAlgorithm();
        if(!(NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg) // NN algorithm
                || LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg) // LR algorithm
                || CommonUtils.isTreeModel(alg))) { // RF or GBT algorithm
            throw new IllegalArgumentException(
                    "Currently we only support NN, LR, RF(RandomForest) and GBDT(Gradient Boost Decision Tree) distributed training.");
        }

        if((LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg) || CommonConstants.GBT_ALG_NAME
                .equalsIgnoreCase(alg))
                && modelConfig.isClassification()
                && modelConfig.getTrain().getMultiClassifyMethod() == MultipleClassification.NATIVE) {
            throw new IllegalArgumentException(
                    "Distributed LR and GBDT(Gradient Boost Decision Tree) only support binary classification; native multiple classification is not supported.");
        }

        if(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll() && !CommonUtils.isTreeModel(alg)
                && !NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
            throw new IllegalArgumentException("Only GBT, RF and NN support OneVsAll multiple classification.");
        }

        if(super.getModelConfig().getDataSet().getSource() != SourceType.HDFS) {
            throw new IllegalArgumentException("Currently we only support distributed training on HDFS source type.");
        }

        if(isDebug()) {
            LOG.warn("Debug logic is not implemented for distributed training yet; setting it has no effect.");
        }

        // check if parquet format norm output is consistent with current isParquet setting.
        boolean isParquetMetaFileExist = ShifuFileUtils.getFileSystemBySourceType(
                super.getModelConfig().getDataSet().getSource()).exists(
                new Path(super.getPathFinder().getNormalizedDataPath(), "_common_metadata"));
        if(super.modelConfig.getNormalize().getIsParquet() && !isParquetMetaFileExist) {
            throw new IllegalArgumentException("Your normalized input in "
                    + super.getPathFinder().getNormalizedDataPath()
                    + " is not in parquet format. Please keep isParquet as true and re-run the norm step before training, or change isParquet to false.");
        } else if(!super.modelConfig.getNormalize().getIsParquet() && isParquetMetaFileExist) {
            throw new IllegalArgumentException("Your normalized input in "
                    + super.getPathFinder().getNormalizedDataPath()
                    + " is in parquet format. Please keep isParquet as false and re-run the norm step, or change isParquet to true.");
        }

        GridSearch gridSearch = new GridSearch(modelConfig.getTrain().getParams());
        if(!LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg)
                && !NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg) && !CommonUtils.isTreeModel(alg)
                && gridSearch.hasHyperParam()) {
            // if grid search but not NN, not RF, not GBT, not LR
            throw new IllegalArgumentException("Grid search only supports NN, LR, GBT and RF algorithms.");
        }

        if(gridSearch.hasHyperParam() && super.getModelConfig().getDataSet().getSource() != SourceType.HDFS
                && modelConfig.isDistributedRunMode()) {
            // if grid search but not HDFS raw data set in mapred/dist run mode
            throw new IllegalArgumentException("Grid search only supports distributed training on HDFS source type.");
        }
    }

    protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
        LOG.info("Started {} d-training.", isDryTrain ?
"dry" : ""); int status = 0; Configuration conf = new Configuration(); SourceType sourceType = super.getModelConfig().getDataSet().getSource(); final List<String> args = new ArrayList<String>(); GridSearch gs = new GridSearch(modelConfig.getTrain().getParams()); prepareCommonParams(gs.hasHyperParam(), args, sourceType); String alg = super.getModelConfig().getTrain().getAlgorithm(); // add tmp models folder to config FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType); Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType( new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType))); args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString())); int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum(); if(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) { // one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL baggingNum = modelConfig.getTags().size(); if(baggingNum != super.getModelConfig().getBaggingNum()) { LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum); } } boolean isKFoldCV = false; Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if(kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; baggingNum = modelConfig.getTrain().getNumKFold(); if(baggingNum != super.getModelConfig().getBaggingNum()) { LOG.warn( "'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum); } } long start = System.currentTimeMillis(); LOG.info("Distributed trainning with baggingNum: {}", baggingNum); boolean isParallel = Boolean.valueOf( Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue(); GuaguaMapReduceClient guaguaClient; int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList); int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; int candidateCount = inputOutputIndex[2]; boolean isAfterVarSelect = (inputOutputIndex[0] != 0); // cache all feature list for sampling features List<Integer> allFeatures = CommonUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect); if(modelConfig.getNormalize().getIsParquet()) { guaguaClient = new GuaguaParquetMapReduceClient(); // set required field list to make sure we only load selected columns. 
            RequiredFieldList requiredFieldList = new RequiredFieldList();
            for(ColumnConfig columnConfig: super.columnConfigList) {
                if(columnConfig.isTarget()) {
                    requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(),
                            null, DataType.FLOAT));
                } else {
                    if(inputNodeCount == candidateCount) {
                        // no variables are selected yet
                        if(!columnConfig.isMeta() && !columnConfig.isTarget()
                                && CommonUtils.isGoodCandidate(columnConfig)) {
                            requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig
                                    .getColumnNum(), null, DataType.FLOAT));
                        }
                    } else {
                        if(!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                            requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig
                                    .getColumnNum(), null, DataType.FLOAT));
                        }
                    }
                }
            }
            // weight is added manually
            requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));

            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields",
                    serializeRequiredFieldList(requiredFieldList)));
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access",
                    "true"));
        } else {
            guaguaClient = new GuaguaMapReduceClient();
        }

        int parallelNum = Integer
                .parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
        int parallelGroups = 1;
        if(gs.hasHyperParam()) {
            parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size()
                    / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
        } else {
            parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
        }

        List<String> progressLogList = new ArrayList<String>(baggingNum);
        boolean isOneJobNotContinuous = false;
        for(int j = 0; j < parallelGroups; j++) {
            int currBags = baggingNum;
            if(gs.hasHyperParam()) {
                if(j == parallelGroups - 1) {
                    currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams()
                            .size() % parallelNum;
                } else {
                    currBags = parallelNum;
                }
            } else {
                if(j == parallelGroups - 1) {
                    currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
                } else {
                    currBags = parallelNum;
                }
            }
            for(int k = 0; k < currBags; k++) {
                int i = j * parallelNum + k;
                if(gs.hasHyperParam()) {
                    LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
                } else if(isKFoldCV) {
                    LOG.info("Start the {}th k-fold cross validation job with params.", i);
                }
                List<String> localArgs = new ArrayList<String>(args);

                // set name for each bagging job.
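                // e.g. a job may show up on the cluster as "Shifu Master-Workers NN Training Iteration: MyModel id:0"
                // (the model set name "MyModel" is hypothetical), which makes bagging jobs easy to identify.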
localArgs.add("-n"); localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super .getModelConfig().getModelSetName(), i)); LOG.info("Start trainer with id: {}", i); String modelName = getModelName(i); Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName)); // check if job is continunous training, this can be set multiple times and we only get last one boolean isContinous = false; if(gs.hasHyperParam()) { isContinous = false; } else { int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig .getTrain().getParams()); if(intContinuous == -1) { LOG.warn( "Model with index {} with size of trees is over treeNum, such training will not be started.", i); continue; } else { isContinous = (intContinuous == 1); } } // of course gs not support continuous model training, k-fold cross validation is not continuous model // training if(gs.hasHyperParam() || isKFoldCV) { isContinous = false; } if(!isContinous && !isOneJobNotContinuous) { isOneJobNotContinuous = true; // delete all old models if not continous String srcModelPath = super.getPathFinder().getModelsPath(sourceType); String mvModelPath = srcModelPath + "_" + System.currentTimeMillis(); LOG.info("Old model path has been moved to {}", mvModelPath); fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath)); fileSystem.mkdirs(new Path(srcModelPath)); FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true); } if(NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) { // tree related parameters initialization Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain() .getParams(); Object fssObj = params.get("FeatureSubsetStrategy"); FeatureSubsetStrategy featureSubsetStrategy = null; double featureSubsetRate = 0d; if(fssObj != null) { try { featureSubsetRate = Double.parseDouble(fssObj.toString()); // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector featureSubsetStrategy = null; } catch (NumberFormatException ee) { featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString()); } } else { LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default."); featureSubsetStrategy = FeatureSubsetStrategy.ALL; featureSubsetRate = 0; } Set<Integer> subFeatures = null; if(isContinous) { BasicFloatNetwork existingModel = (BasicFloatNetwork) CommonUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())); if(existingModel == null) { subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount)); } else { subFeatures = existingModel.getFeatureSet(); } } else { subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount)); } if(subFeatures == null || subFeatures.size() == 0) { localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, "")); } else { localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ','))); LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ',')); } } localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString())); if(gs.hasHyperParam() || isKFoldCV) { // k-fold cv need val 
                    Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(
                            sourceType), "val_error_" + i));
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                            CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
                }

                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID,
                        String.valueOf(i)));
                final String progressLogFile = getProgressLogFile(i);
                progressLogList.add(progressLogFile);
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                        CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
                String hdpVersion = HDPUtils.getHdpVersionForHDP224();
                if(StringUtils.isNotBlank(hdpVersion)) {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
                    HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
                    HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
                    HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
                    HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
                }

                if(isParallel) {
                    guaguaClient.addJob(localArgs.toArray(new String[0]));
                } else {
                    TailThread tailThread = startTailThread(new String[] { progressLogFile });
                    guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                    stopTailThread(tailThread);
                }
            }

            if(isParallel) {
                TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
                guaguaClient.run();
                stopTailThread(tailThread);
            }
        }

        if(isKFoldCV) {
            List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
            double sum = 0d;
            for(Double err: valErrs) {
                sum += err;
            }
            LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
            LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis()
                    - start);
        } else if(gs.hasHyperParam()) {
            // select the best parameter composite in grid search
            LOG.info("Original grid search params: {}", modelConfig.getParams());
            Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
            // TODO, copy top 5 models for evaluation? (no need further train)
            for(Entry<String, Object> entry: params.entrySet()) {
                modelConfig.getParams().put(entry.getKey(), entry.getValue());
            }
            super.pathFinder.getModelConfigPath(SourceType.LOCAL);
            // update ModelConfig.json
            JSONUtils.writeValue(new File(super.pathFinder.getModelConfigPath(SourceType.LOCAL)), modelConfig);
            LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
        } else {
            // copy all models to local after all jobs are finished
            if(!gs.hasHyperParam()) {
                // copy model files at last.
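                // One model file lands locally per bagging job, named model0.<alg>, model1.<alg>, ... by
                // getModelName(int) below.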
                for(int i = 0; i < baggingNum; i++) {
                    String modelName = getModelName(i);
                    Path modelPath = fileSystem.makeQualified(new Path(
                            super.getPathFinder().getModelsPath(sourceType), modelName));
                    if(ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
                        copyModelToLocal(modelName, modelPath, sourceType);
                    } else {
                        LOG.warn("Model {} does not exist, the job may have failed; for bagging this can be ignored.",
                                modelPath.toString());
                        status = 1;
                    }
                }

                // copy temp model files; for RF/GBT, tmp models are not copied by default because of the larger
                // space needed, while for other algorithms tmp models are copied to local by default
                boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(
                        Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
                if(CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(
                            Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "false"));
                    List<BasicML> models = CommonUtils.loadBasicModels(this.modelConfig, this.columnConfigList, null);
                    // compute feature importance and write to local file after models are trained
                    Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils
                            .computeTreeModelFeatureImportance(models);
                    CommonUtils.writeFeatureImportance(this.pathFinder.getLocalFeatureImportancePath(),
                            featureImportances);
                }
                if(copyTmpModelsToLocal) {
                    copyTmpModelsToLocal(tmpModelsPath, sourceType);
                } else {
                    LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
                }
                LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
            }
        }

        if(status != 0) {
            LOG.error("Errors may have occurred: some models were not generated. Please check!");
        }
        return status;
    }

    private Map<String, Object> findBestParams(SourceType sourceType, FileSystem fileSystem, GridSearch gs)
            throws IOException {
        // read validation errors and find the best one to update ModelConfig.
        double minValErr = Double.MAX_VALUE;
        int minIndex = -1;
        for(int i = 0; i < gs.getFlattenParams().size(); i++) {
            Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType),
                    "val_error_" + i));
            if(ShifuFileUtils.isFileExists(valErrPath.toString(), sourceType)) {
                double valErr;
                BufferedReader reader = null;
                try {
                    reader = ShifuFileUtils.getReader(valErrPath.toString(), sourceType);
                    String line = reader.readLine();
                    if(line == null) {
                        continue;
                    }
                    String valErrStr = line;
                    LOG.debug("valErrStr is {}", valErrStr);
                    valErr = Double.valueOf(valErrStr);
                } catch (NumberFormatException e) {
                    LOG.warn("Failed to parse validation error, ignoring it. Message: {}", e.getMessage());
                    continue;
                } finally {
                    if(reader != null) {
                        reader.close();
                    }
                }
                if(valErr < minValErr) {
                    minValErr = valErr;
                    minIndex = i;
                }
            }
        }
        Map<String, Object> params = gs.getParams(minIndex);
        LOG.info(
                "Param set {} is selected by grid search with params {}; please use it and set it in ModelConfig.json.",
                minIndex, params);
        return params;
    }

    private List<Double> readAllValidationErrors(SourceType sourceType, FileSystem fileSystem, int k)
            throws IOException {
        List<Double> valErrs = new ArrayList<Double>();
        for(int i = 0; i < k; i++) {
            Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType),
                    "val_error_" + i));
            if(ShifuFileUtils.isFileExists(valErrPath.toString(), sourceType)) {
                double valErr;
                BufferedReader reader = null;
                try {
                    reader = ShifuFileUtils.getReader(valErrPath.toString(), sourceType);
                    String line = reader.readLine();
                    if(line == null) {
                        continue;
                    }
                    String valErrStr = line;
                    LOG.debug("valErrStr is {}", valErrStr);
                    valErr = Double.valueOf(valErrStr);
                    valErrs.add(valErr);
                } catch (NumberFormatException e) {
                    LOG.warn("Failed to parse validation error, ignoring it. Message: {}", e.getMessage());
                    continue;
                } finally {
                    if(reader != null) {
                        reader.close();
                    }
                }
            }
        }
        return valErrs;
    }

    static String serializeRequiredFieldList(RequiredFieldList requiredFieldList) {
        try {
            return ObjectSerializer.serialize(requiredFieldList);
        } catch (IOException e) {
            throw new RuntimeException("Failed to serialize required fields.", e);
        }
    }

    /*
     * Return 1 for continuous training, 0 for non-continuous training, -1 if the existing GBT tree size is over
     * treeNum.
     */
    private int checkContinuousTraining(FileSystem fileSystem, List<String> localArgs, Path modelPath,
            Map<String, Object> modelParams) throws IOException {
        int finalContinuous = 0;
        if(Boolean.TRUE.toString().equals(this.modelConfig.getTrain().getIsContinuous().toString())) {
            // for var-select d-training or when no existing model is found, directly disable continuous training.
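            // Recap of the return contract: 1 resumes from the existing model, 0 starts from scratch, and -1 means a
            // GBT model already holds >= TreeNum trees, so the caller skips this bag entirely.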
            if(this.isForVarSelect) {
                finalContinuous = 0;
                LOG.warn("For the varSelect step, continuous model training is always disabled.");
            } else if(!fileSystem.exists(modelPath)) {
                finalContinuous = 0;
                LOG.info("No existing model, model training will start from scratch.");
            } else if(NNConstants.NN_ALG_NAME.equalsIgnoreCase(modelConfig.getAlgorithm())
                    && !inputOutputModelCheckSuccess(fileSystem, modelPath, modelParams)) {
                // TODO hidden layer size and activation functions should also be validated
                finalContinuous = 0;
                LOG.warn("Model training parameters like hidden nodes, activation and others are not consistent with current settings; model training will start from scratch.");
            } else if(CommonConstants.GBT_ALG_NAME.equalsIgnoreCase(modelConfig.getAlgorithm())) {
                TreeModel model = (TreeModel) CommonUtils.loadModel(this.modelConfig, modelPath, fileSystem);
                if(!model.getAlgorithm().equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    finalContinuous = 0;
                    LOG.warn("Only GBT supports continuous training; the existing model is not GBT, will start from scratch.");
                } else if(!model.getLossStr().equalsIgnoreCase(
                        this.modelConfig.getTrain().getParams().get("Loss").toString())) {
                    finalContinuous = 0;
                    LOG.warn("Loss is changed, continuous training is disabled, will start from scratch.");
                } else if(model.getTrees().size() == 0) {
                    finalContinuous = 0;
                } else if(model.getTrees().size() >= Integer.valueOf(modelConfig.getTrain().getParams().get("TreeNum")
                        .toString())) {
                    // if over TreeNum, return -1
                    finalContinuous = -1;
                } else {
                    finalContinuous = 1;
                }
            } else if(CommonConstants.RF_ALG_NAME.equalsIgnoreCase(modelConfig.getAlgorithm())) {
                finalContinuous = 0;
                LOG.warn("RF doesn't support continuous training.");
            } else {
                finalContinuous = 1;
            }
        } else {
            finalContinuous = 0;
        }
        localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.CONTINUOUS_TRAINING,
                finalContinuous == 1 ? "true" : "false"));
        return finalContinuous;
    }

    @SuppressWarnings("unchecked")
    private boolean inputOutputModelCheckSuccess(FileSystem fileSystem, Path modelPath, Map<String, Object> modelParams)
            throws IOException {
        BasicNetwork model = (BasicNetwork) CommonUtils.loadModel(this.modelConfig, modelPath, fileSystem);
        int[] outputCandidateCounts = DTrainUtils.getInputOutputCandidateCounts(getColumnConfigList());
        int inputs = outputCandidateCounts[0] == 0 ? outputCandidateCounts[2] : outputCandidateCounts[0];
        boolean isInputOutConsistent = model.getInputCount() == inputs
                && model.getOutputCount() == outputCandidateCounts[1];
        if(!isInputOutConsistent) {
            return false;
        }
        // same hidden layer count?
        boolean isHasSameHiddenLayer = (model.getLayerCount() - 2) == (Integer) modelParams
                .get(CommonConstants.NUM_HIDDEN_LAYERS);
        if(!isHasSameHiddenLayer) {
            return false;
        }
        // same hidden nodes?
        boolean isHasSameHiddenNodes = true;
        List<Integer> hiddenNodeList = (List<Integer>) modelParams.get(CommonConstants.NUM_HIDDEN_NODES);
        for(int i = 0; i < hiddenNodeList.size(); i++) {
            if(model.getLayerNeuronCount(i + 1) != hiddenNodeList.get(i)) {
                isHasSameHiddenNodes = false;
                break;
            }
        }
        if(!isHasSameHiddenNodes) {
            return false;
        }
        // same activations?
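        // For example, actFunc = ["tanh", "relu"] matches a network whose hidden layers 1 and 2 use ActivationTANH
        // and ActivationReLU; an unrecognized name falls back to the sigmoid comparison below.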
        boolean isHasSameHiddenActivation = true;
        List<String> actFunc = (List<String>) modelParams.get(CommonConstants.ACTIVATION_FUNC);
        for(int i = 0; i < actFunc.size(); i++) {
            ActivationFunction activation = model.getActivation(i + 1);
            if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_LINEAR)) {
                isHasSameHiddenActivation = ActivationLinear.class == activation.getClass();
            } else if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
                isHasSameHiddenActivation = ActivationSigmoid.class == activation.getClass();
            } else if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_TANH)) {
                isHasSameHiddenActivation = ActivationTANH.class == activation.getClass();
            } else if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_LOG)) {
                isHasSameHiddenActivation = ActivationLOG.class == activation.getClass();
            } else if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_SIN)) {
                isHasSameHiddenActivation = ActivationSIN.class == activation.getClass();
            } else if(actFunc.get(i).equalsIgnoreCase(NNConstants.NN_RELU)) {
                isHasSameHiddenActivation = ActivationReLU.class == activation.getClass();
            } else {
                isHasSameHiddenActivation = ActivationSigmoid.class == activation.getClass();
            }
            if(!isHasSameHiddenActivation) {
                break;
            }
        }
        if(!isHasSameHiddenActivation) {
            return false;
        }

        return true;
    }

    private String getProgressLogFile(int i) {
        return String.format("tmp/%s_%s.log", System.currentTimeMillis(), i);
    }

    private void stopTailThread(TailThread thread) throws IOException {
        thread.interrupt();
        try {
            thread.join(NNConstants.DEFAULT_JOIN_TIME);
        } catch (InterruptedException e) {
            LOG.error("Thread stopped!", e);
            Thread.currentThread().interrupt();
        }
        // delete progress files at last
        thread.deleteProgressFiles();
    }

    private TailThread startTailThread(final String[] progressLog) {
        TailThread thread = new TailThread(progressLog);
        thread.setName("Tail Progress Thread");
        thread.setDaemon(true);
        thread.setUncaughtExceptionHandler(new UncaughtExceptionHandler() {
            @Override
            public void uncaughtException(Thread t, Throwable e) {
                LOG.warn(String.format("Error in thread %s: %s", t.getName(), e.getMessage()));
            }
        });
        thread.start();
        return thread;
    }

    private void copyTmpModelsToLocal(final Path tmpModelsDir, final SourceType sourceType) throws IOException {
        // copy all tmp models to local; these tmp models are intermediate outputs from training checkpoints
        if(!this.isDryTrain()) {
            if(ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(tmpModelsDir)) {
                Path localTmpModelsFolder = new Path(Constants.TMP);
                HDFSUtils.getLocalFS().delete(localTmpModelsFolder, true);
                HDFSUtils.getLocalFS().mkdirs(localTmpModelsFolder);
                ShifuFileUtils.getFileSystemBySourceType(sourceType)
                        .copyToLocalFile(tmpModelsDir, localTmpModelsFolder);
            }
        }
    }

    private void prepareDTParams(final List<String> args, final SourceType sourceType) {
        args.add("-w");
        args.add(DTWorker.class.getName());
        args.add("-m");
        args.add(DTMaster.class.getName());
        args.add("-mr");
        args.add(DTMasterParams.class.getName());
        args.add("-wr");
        args.add(DTWorkerParams.class.getName());
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS,
                DTOutput.class.getName()));
    }

    private void prepareLRParams(final List<String> args, final SourceType sourceType) {
        args.add("-w");
        args.add(LogisticRegressionWorker.class.getName());
        args.add("-m");
        args.add(LogisticRegressionMaster.class.getName());
        args.add("-mr");
        args.add(LogisticRegressionParams.class.getName());
        args.add("-wr");
        args.add(LogisticRegressionParams.class.getName());
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS, LogisticRegressionOutput.class.getName()));
    }

    private void prepareNNParams(final List<String> args, final SourceType sourceType) {
        args.add("-w");
        if(modelConfig.getNormalize().getIsParquet()) {
            args.add(NNParquetWorker.class.getName());
        } else {
            args.add(NNWorker.class.getName());
        }
        args.add("-m");
        args.add(NNMaster.class.getName());
        args.add("-mr");
        args.add(NNParams.class.getName());
        args.add("-wr");
        args.add(NNParams.class.getName());
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS,
                NNOutput.class.getName()));
    }

    private void prepareCommonParams(boolean isGsMode, final List<String> args, final SourceType sourceType) {
        String alg = super.getModelConfig().getTrain().getAlgorithm();
        args.add("-libjars");
        addRuntimeJars(args);

        args.add("-i");
        if(CommonUtils.isTreeModel(alg)) {
            args.add(ShifuFileUtils.getFileSystemBySourceType(sourceType)
                    .makeQualified(new Path(super.getPathFinder().getCleanedDataPath())).toString());
        } else {
            args.add(ShifuFileUtils.getFileSystemBySourceType(sourceType)
                    .makeQualified(new Path(super.getPathFinder().getNormalizedDataPath())).toString());
        }

        if(StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
            args.add("-inputformat");
            args.add(ShifuInputFormat.class.getName());
        }

        String zkServers = Environment.getProperty(Environment.ZOO_KEEPER_SERVERS);
        if(StringUtils.isEmpty(zkServers)) {
            LOG.warn("No zookeeper servers are specified via zookeeperServers in the shifuConfig file; Guagua will start an embedded zookeeper server in the client process or on the master node. For fail-over, explicitly specified zookeeper servers are strongly recommended.");
        } else {
            args.add("-z");
            args.add(zkServers);
        }

        if(LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg)) {
            this.prepareLRParams(args, sourceType);
        } else if(NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
            this.prepareNNParams(args, sourceType);
        } else if(CommonUtils.isTreeModel(alg)) {
            this.prepareDTParams(args, sourceType);
        }

        args.add("-c");
        int numTrainEpoches = super.getModelConfig().getTrain().getNumTrainEpochs();
        // only for NN varselect, use half of the epochs for sensitivity analysis;
        // in grid search mode, half of the iterations are used as well
        LOG.debug("this.isForVarSelect() - {}, isGsMode - {}", this.isForVarSelect(), isGsMode);
        if(NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg) && (this.isForVarSelect() || isGsMode)
                && numTrainEpoches >= VAR_SELECT_TRAINING_DECAY_EPOCHES_THRESHOLD) {
            numTrainEpoches = numTrainEpoches / 2;
        }
        // for GBDT or RF, the iteration count is extended to make sure all trees can be built successfully without
        // hitting the max-iteration limitation
        if(CommonUtils.isTreeModel(alg) && numTrainEpoches <= 30000) {
            numTrainEpoches = 30000;
        }
        // the reason to add 1 is that the first iteration in the implementation is used for training preparation.
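        // For example, with numTrainEpochs = 100 guagua runs 101 iterations: the first one loads and prepares the
        // data, the remaining 100 do the real training, and the logs below still report 100.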
        numTrainEpoches = numTrainEpoches + 1;
        if(LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg)) {
            LOG.info("Number of train iterations is set to {}.", numTrainEpoches - 1);
        } else if(NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
            LOG.info("Number of train epochs is set to {}.", numTrainEpoches - 1);
        } else if(CommonUtils.isTreeModel(alg)) {
            LOG.info("Number of train iterations is set to {}.", numTrainEpoches - 1);
        }
        args.add(String.valueOf(numTrainEpoches));

        if(CommonUtils.isTreeModel(alg)) {
            // for tree models, use the cleaned validation data path
            args.add(String.format(
                    CommonConstants.MAPREDUCE_PARAM_FORMAT,
                    CommonConstants.CROSS_VALIDATION_DIR,
                    ShifuFileUtils.getFileSystemBySourceType(sourceType)
                            .makeQualified(new Path(super.getPathFinder().getCleanedValidationDataPath(sourceType)))
                            .toString()));
        } else {
            args.add(String.format(
                    CommonConstants.MAPREDUCE_PARAM_FORMAT,
                    CommonConstants.CROSS_VALIDATION_DIR,
                    ShifuFileUtils.getFileSystemBySourceType(sourceType)
                            .makeQualified(new Path(super.getPathFinder().getNormalizedValidationDataPath(sourceType)))
                            .toString()));
        }

        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, NNConstants.MAPRED_JOB_QUEUE_NAME,
                Environment.getProperty(Environment.HADOOP_JOB_QUEUE, Constants.DEFAULT_JOB_QUEUE)));
        args.add(String.format(
                CommonConstants.MAPREDUCE_PARAM_FORMAT,
                CommonConstants.SHIFU_MODEL_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(sourceType).makeQualified(
                        new Path(super.getPathFinder().getModelConfigPath(sourceType)))));
        args.add(String.format(
                CommonConstants.MAPREDUCE_PARAM_FORMAT,
                CommonConstants.SHIFU_COLUMN_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(sourceType).makeQualified(
                        new Path(super.getPathFinder().getColumnConfigPath(sourceType)))));
        args.add(String
                .format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.MODELSET_SOURCE_TYPE, sourceType));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DRY_DTRAIN, isDryTrain()));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, NNConstants.NN_POISON_SAMPLER,
                Environment.getProperty(NNConstants.NN_POISON_SAMPLER, "true")));
        // hard-code the computation time threshold to 60s; it can be changed in the shifuconfig file
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_COMPUTATION_TIME_THRESHOLD, 60 * 1000L));
        setHeapSizeAndSplitSize(args);

        // set the embedded zookeeper into the client by default to avoid mapper OOM: a master-mapper embedded
        // zookeeper may use 512M-1G memory, which can cause OOM issues.
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_ZK_EMBEDBED_IS_IN_CLIENT, "true"));

        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "mapreduce.map.cpu.vcores",
                modelConfig.getTrain().getWorkerThreadCount() == null ?
                        1 : modelConfig.getTrain().getWorkerThreadCount()));

        // one can set guagua conf in shifuconfig
        for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) {
            if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) {
                args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey().toString(), entry
                        .getValue().toString()));
            }
        }
    }

    private List<Integer> getSubsamplingFeatures(List<Integer> allFeatures,
            FeatureSubsetStrategy featureSubsetStrategy, double featureSubsetRate, int inputNum) {
        if(featureSubsetStrategy == null) {
            if(Double.compare(1d, featureSubsetRate) == 0) {
                return new ArrayList<Integer>();
            } else {
                return sampleFeaturesForNodeStats(allFeatures, (int) (allFeatures.size() * featureSubsetRate));
            }
        } else {
            switch(featureSubsetStrategy) {
                case HALF:
                    return sampleFeaturesForNodeStats(allFeatures, allFeatures.size() / 2);
                case ONETHIRD:
                    return sampleFeaturesForNodeStats(allFeatures, allFeatures.size() / 3);
                case TWOTHIRDS:
                    return sampleFeaturesForNodeStats(allFeatures, allFeatures.size() * 2 / 3);
                case SQRT:
                    return sampleFeaturesForNodeStats(allFeatures,
                            (int) (allFeatures.size() * Math.sqrt(inputNum) / inputNum));
                case LOG2:
                    return sampleFeaturesForNodeStats(allFeatures,
                            (int) (allFeatures.size() * Math.log(inputNum) / Math.log(2) / inputNum));
                case AUTO:
                case ALL:
                default:
                    return new ArrayList<Integer>();
            }
        }
    }

    private List<Integer> sampleFeaturesForNodeStats(List<Integer> allFeatures, int sample) {
        List<Integer> features = new ArrayList<Integer>(sample);
        for(int i = 0; i < sample; i++) {
            features.add(allFeatures.get(i));
        }
        for(int i = sample; i < allFeatures.size(); i++) {
            int replacementIndex = (int) (featureSamplingRandom.nextDouble() * i);
            if(replacementIndex >= 0 && replacementIndex < sample) {
                features.set(replacementIndex, allFeatures.get(i));
            }
        }
        return features;
    }

    private void setHeapSizeAndSplitSize(final List<String> args) {
        // these can be overridden in shifuconfig, so hard-coded defaults are acceptable here
        if(this.isDebug()) {
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                    GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS,
                    "-Xms2048m -Xmx2048m -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps"));
        } else {
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                    GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS,
                    "-Xms2048m -Xmx2048m -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps"));
            args.add(String
                    .format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                            "mapreduce.map.java.opts",
                            "-Xms2048m -Xmx2048m -server -XX:+UseParNewGC -XX:+UseConcMarkSweepGC "
                                    + "-XX:CMSInitiatingOccupancyFraction=70 -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps"));
        }
        if(super.modelConfig.getNormalize().getIsParquet()) {
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_COMBINABLE,
                    Environment.getProperty(GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, "false")));
        } else {
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_COMBINABLE,
                    Environment.getProperty(GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, "true")));
            long maxCombineSize = computeDynamicCombineSize();
            LOG.info("Dynamic worker size is tuned to {}. If it is not good for # of workers, configure it in "
                    + "SHIFU_HOME/conf/shifuconfig::guagua.split.maxCombinedSplitSize", maxCombineSize);
            args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT,
                    GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE,
                    Environment.getProperty(GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, maxCombineSize + "")));
        }
        // special tuning parameter for shifu: 0.97 means in each iteration the master waits for 97% of the workers
        // and then moves on to the next iteration.
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_RATIO, 0.97));
        // after the ratio is met, wait at most 2 more seconds for the remaining workers; both settings can be
        // overridden in shifuconfig
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_TIMEOUT,
                2 * 1000L));
    }

    private long computeDynamicCombineSize() {
        // set a dynamic combine size to save mappers; sometimes this may OOM, then users should tune
        // guagua.split.maxCombinedSplitSize in shifuconfig. By default it is about 200M; if a user selects only half
        // of the features, this number could be doubled to about 400M.
        int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
        int candidateCount = inputOutputIndex[2];
        // 1. set benchmark
        long maxCombineSize = CommonUtils.isTreeModel(modelConfig.getAlgorithm()) ? 209715200L : 168435456L;
        // defaults are 200M for GBT/RF and ~160M for NN: for NN all categorical data is normalized to numeric, which
        // saves disk space, while for RF/GBT categorical values are still strings and need the larger default.
        // 2. according to the ratio of (candidate count / benchmark 600 features), tune the combine size; 0.85 is a
        // damping factor that only takes effect when the ratio is over 2
        double ratio = candidateCount / 600d;
        if(ratio > 2d) {
            ratio = 0.85 * ratio;
        }
        maxCombineSize = Double.valueOf((maxCombineSize * 1d * (ratio))).longValue();
        return maxCombineSize;
    }

    private void copyModelToLocal(String modelName, Path modelPath, SourceType sourceType) throws IOException {
        if(!this.isDryTrain()) {
            ShifuFileUtils.getFileSystemBySourceType(sourceType).copyToLocalFile(modelPath,
                    new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL), modelName));
        }
    }

    // GuaguaOptionsParser doesn't support *.jar wildcards currently,
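    // so each runtime dependency below is resolved to a concrete jar path via JarManager/HDPUtils and shipped to the
    // MapReduce job through -libjars.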
    private void addRuntimeJars(final List<String> args) {
        List<String> jars = new ArrayList<String>(16);
        // jackson-databind-*.jar
        jars.add(JarManager.findContainingJar(ObjectMapper.class));
        // jackson-core-*.jar
        jars.add(JarManager.findContainingJar(JsonParser.class));
        // jackson-annotations-*.jar
        jars.add(JarManager.findContainingJar(JsonIgnore.class));
        // commons-compress-*.jar
        jars.add(JarManager.findContainingJar(BZip2CompressorInputStream.class));
        // commons-lang-*.jar
        jars.add(JarManager.findContainingJar(StringUtils.class));
        // commons-collections-*.jar
        jars.add(JarManager.findContainingJar(ListUtils.class));
        // common-io-*.jar
        jars.add(JarManager.findContainingJar(org.apache.commons.io.IOUtils.class));
        // guava-*.jar
        jars.add(JarManager.findContainingJar(Splitter.class));
        // encog-core-*.jar
        jars.add(JarManager.findContainingJar(MLDataSet.class));
        // shifu-*.jar
        jars.add(JarManager.findContainingJar(getClass()));
        // guagua-core-*.jar
        jars.add(JarManager.findContainingJar(GuaguaConstants.class));
        // guagua-mapreduce-*.jar
        jars.add(JarManager.findContainingJar(GuaguaMapReduceConstants.class));
        // zookeeper-*.jar
        jars.add(JarManager.findContainingJar(ZooKeeper.class));
        // netty-*.jar
        jars.add(JarManager.findContainingJar(ServerBootstrap.class));

        if(modelConfig.getNormalize().getIsParquet()) {
            // these jars are only needed for parquet format input
            // parquet-mr-*.jar
            jars.add(JarManager.findContainingJar(ParquetRecordReader.class));
            // parquet-pig-*.jar
            jars.add(JarManager.findContainingJar(parquet.pig.ParquetLoader.class));
            // pig-*.jar
            jars.add(JarManager.findContainingJar(PigContext.class));
            // parquet-common-*.jar
            jars.add(JarManager.findContainingJar(ParquetRuntimeException.class));
            // parquet-column-*.jar
            jars.add(JarManager.findContainingJar(ParquetProperties.class));
            // parquet-encoding-*.jar
            jars.add(JarManager.findContainingJar(Packer.class));
            // parquet-generator-*.jar
            jars.add(JarManager.findContainingJar(Generator.class));
            // parquet-format-*.jar
            jars.add(JarManager.findContainingJar(PageType.class));
            // snappy-*.jar
            jars.add(JarManager.findContainingJar(Snappy.class));
            // parquet-jackson-*.jar
            jars.add(JarManager.findContainingJar(Base64Variant.class));
            // antlr jar
            jars.add(JarManager.findContainingJar(RecognitionException.class));
            // joda-time jar
            jars.add(JarManager.findContainingJar(ReadableInstant.class));
        }

        String hdpVersion = HDPUtils.getHdpVersionForHDP224();
        if(StringUtils.isNotBlank(hdpVersion)) {
            jars.add(HDPUtils.findContainingFile("hdfs-site.xml"));
            jars.add(HDPUtils.findContainingFile("core-site.xml"));
            jars.add(HDPUtils.findContainingFile("mapred-site.xml"));
            jars.add(HDPUtils.findContainingFile("yarn-site.xml"));
        }

        args.add(StringUtils.join(jars, NNConstants.LIB_JAR_SEPARATOR));
    }

    /**
     * For RF/GBT models, normalization is not needed, but data cleaning and filtering are. Before real training, we
     * have to clean and filter the data.
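     * <p>
     * The cleaned output is written to the cleaned data path from PathFinder and is reused on later runs unless
     * 'shifu.tree.regeninput' is set to true in shifuconfig.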
     */
    protected void checkAndCleanDataForTreeModels(boolean isToShuffle) throws IOException {
        String alg = this.getModelConfig().getTrain().getAlgorithm();
        // only for tree models
        if(!CommonUtils.isTreeModel(alg)) {
            return;
        }

        // check if binBoundaries and binCategories are good and log errors
        for(ColumnConfig columnConfig: columnConfigList) {
            if(columnConfig.isFinalSelect() && !columnConfig.isTarget() && !columnConfig.isMeta()) {
                if(columnConfig.isNumerical() && columnConfig.getBinBoundary() == null) {
                    throw new IllegalArgumentException("Column " + columnConfig.getColumnName()
                            + " is final selected but binBoundary in ColumnConfig.json is null.");
                }
                if(columnConfig.isNumerical() && columnConfig.getBinBoundary().size() <= 1) {
                    LOG.warn(
                            "Column {} {} has only one or zero elements in binBoundary; such column will be ignored in tree model training.",
                            columnConfig.getColumnNum(), columnConfig.getColumnName());
                }
                if(columnConfig.isCategorical() && columnConfig.getBinCategory() == null) {
                    throw new IllegalArgumentException("Column " + columnConfig.getColumnName()
                            + " is final selected but binCategory in ColumnConfig.json is null.");
                }
                if(columnConfig.isCategorical() && columnConfig.getBinCategory().size() <= 0) {
                    LOG.warn(
                            "Column {} {} has zero elements in binCategory; such column will be ignored in tree model training.",
                            columnConfig.getColumnNum(), columnConfig.getColumnName());
                }
            }
        }

        // run data cleaning logic for model input
        SourceType sourceType = modelConfig.getDataSet().getSource();
        String cleanedDataPath = this.pathFinder.getCleanedDataPath();
        String needReGen = Environment.getProperty("shifu.tree.regeninput", Boolean.FALSE.toString());
        // 1. shifu.tree.regeninput = true: no matter what, regenerate;
        // 2. if cleanedDataPath does not exist, generate clean data for tree ensemble model training;
        // 3. if validationDataPath is not blank and cleanedValidationDataPath does not exist, generate clean data
        // for tree ensemble model training
        if(Boolean.TRUE.toString().equalsIgnoreCase(needReGen)
                || !ShifuFileUtils.isFileExists(cleanedDataPath, sourceType)
                || (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath()) && !ShifuFileUtils.isFileExists(
                        pathFinder.getCleanedValidationDataPath(), sourceType))) {
            runDataClean(isToShuffle);
        } else {
            // no need to regenerate data
            LOG.warn("For RF/GBT, training input in {} exists, no need to regenerate it.", cleanedDataPath);
            LOG.warn("To regenerate it anyway, please set shifu.tree.regeninput to true in shifuconfig.");
        }
    }

    /**
     * Get model name
     *
     * @param i
     *            index for model name
     * @return the ith model name
     */
    public String getModelName(int i) {
        String alg = super.getModelConfig().getTrain().getAlgorithm();
        return String.format("model%s.%s", i, alg.toLowerCase());
    }

    // d-train part ends here

    public boolean isDryTrain() {
        return isDryTrain;
    }

    public void setDryTrain(boolean isDryTrain) {
        this.isDryTrain = isDryTrain;
    }

    public boolean isDebug() {
        return isDebug;
    }

    public void setDebug(boolean isDebug) {
        this.isDebug = isDebug;
    }

    public void setToShuffle(boolean toShuffle) {
        isToShuffle = toShuffle;
    }

    /**
     * @return the isForVarSelect
     */
    public boolean isForVarSelect() {
        return isForVarSelect;
    }

    /**
     * @param isForVarSelect
     *            the isForVarSelect to set
     */
    public void setForVarSelect(boolean isForVarSelect) {
        this.isForVarSelect = isForVarSelect;
    }

    /**
     * A thread used to tail progress logs from HDFS log files.
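     * <p>
     * Workers and the master append progress lines into per-trainer HDFS files; this daemon thread polls them every
     * 2 seconds from the last read offset and echoes any new lines through the client-side LOG.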
     */
    private static class TailThread extends Thread {
        private long[] offset;
        private String[] progressLogs;

        public TailThread(String[] progressLogs) {
            this.progressLogs = progressLogs;
            this.offset = new long[this.progressLogs.length];
            for(String progressLog: progressLogs) {
                try {
                    // delete it firstly; it will be re-created and updated by the master
                    HDFSUtils.getFS().delete(new Path(progressLog), true);
                } catch (IOException e) {
                    LOG.error("Error in deleting progressLog", e);
                }
            }
        }

        @Override
        public void run() {
            while(!Thread.currentThread().isInterrupted()) {
                for(int i = 0; i < this.progressLogs.length; i++) {
                    try {
                        this.offset[i] = dumpFromOffset(new Path(this.progressLogs[i]), this.offset[i]);
                    } catch (FileNotFoundException e) {
                        // ignore because the file is not created by the worker yet.
                    } catch (IOException e) {
                        LOG.warn(String.format("Error in dumping progress log %s: %s", getName(), e.getMessage()));
                    } catch (Throwable e) {
                        LOG.warn(String.format("Error in thread %s: %s", getName(), e.getMessage()));
                    }
                }
                try {
                    Thread.sleep(2000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
            LOG.debug("DEBUG: Exit from tail thread.");
        }

        private long dumpFromOffset(Path item, long offset) throws IOException {
            if(!HDFSUtils.getFS().exists(item)) {
                // if the file is not there, just return the initial offset and wait until it is created
                return 0L;
            }
            FSDataInputStream in;
            try {
                in = HDFSUtils.getFS().open(item);
            } catch (Exception e) {
                // in hadoop 0.20.2, an InterruptedException was observed here which cannot be caught by run; ignore
                // such exceptions. It's ok to return the old offset and read the message twice.
                return offset;
            }
            ByteArrayOutputStream out = null;
            DataOutputStream dataOut = null;
            try {
                out = new ByteArrayOutputStream();
                dataOut = new DataOutputStream(out);
                in.seek(offset);
                // use conf so that the system-configured io block size is used
                IOUtils.copyBytes(in, out, HDFSUtils.getFS().getConf(), false);
                String msgs = new String(out.toByteArray(), Charset.forName("UTF-8")).trim();
                if(StringUtils.isNotEmpty(msgs)) {
                    for(String msg: Splitter.on('\n').split(msgs)) {
                        LOG.info(msg.trim());
                    }
                }
                offset = in.getPos();
            } catch (IOException e) {
                if(e.getMessage().indexOf("Cannot seek after EOF") < 0) {
                    throw e;
                }
                // else: seek beyond the current EOF, keep the old offset and retry on the next poll
            } finally {
                IOUtils.closeStream(in);
                IOUtils.closeStream(dataOut);
            }
            return offset;
        }

        public void deleteProgressFiles() throws IOException {
            for(String progressFile: this.progressLogs) {
                HDFSUtils.getFS().delete(new Path(progressFile), true);
            }
        }
    }
}