/*
 * Copyright [2013-2015] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain.dt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.PriorityQueue;
import java.util.Properties;
import java.util.Queue;
import java.util.Random;

import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.io.BytableSerializer;
import ml.shifu.guagua.master.AbstractMasterComputable;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.TreeModel;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy;
import ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Random forest and gradient boosted decision tree {@link MasterComputable} implementation.
 *
 * <p>
 * {@link #isRF} and {@link #isGBDT} indicate whether RF or GBDT is being trained; by default RF is trained.
 *
 * <p>
 * In each iteration, node statistics are updated and the best split is determined for each tree node to be split.
 * Besides node statistics, error and count info are also collected for metrics display.
 *
 * <p>
 * In each iteration, a new node group whose estimated memory consumption stays within a configured limit is sent out
 * to all workers for feature statistics.
 *
 * <p>
 * For gradient boosted decision trees, one tree is updated at a time; after one tree is finalized, a new tree is
 * started. Both random forest and gradient boosted decision trees are stored in {@link #trees}.
 *
 * <p>
 * Terminal condition: for random forest, all nodes in all trees are collected from all workers, and training
 * terminates once no tree can be split any further. If a tree cannot be split given the instance-count threshold and
 * a meaningful impurity, that tree is finalized and stops updating. For gradient boosted decision trees, only one
 * tree is trained at a time; if the last tree cannot be split, training is stopped. Early stopping is enabled by the
 * validationTolerance setting in the train part.
 *
 * <p>
 * The current {@link DTMaster} holds states like {@link #trees} and {@link #toDoQueue}.
 * Such states are checkpointed to HDFS for fault tolerance and can be fully recovered on master failover.
 *
 * @author Zhang David (pengzhang@paypal.com)
 */
public class DTMaster extends AbstractMasterComputable<DTMasterParams, DTWorkerParams> {

    private static final Logger LOG = LoggerFactory.getLogger(DTMaster.class);

    /**
     * Model configuration loaded from configuration file.
     */
    private ModelConfig modelConfig;

    /**
     * Column configuration loaded from configuration file.
     */
    private List<ColumnConfig> columnConfigList;

    /**
     * Number of trees for both RF and GBDT.
     */
    private int treeNum;

    /**
     * Feature subsampling strategy; this is combined with {@link #featureSubsetRate}: if
     * {@link #featureSubsetStrategy} is null, {@link #featureSubsetRate} is used, otherwise
     * {@link #featureSubsetStrategy} is used.
     */
    private FeatureSubsetStrategy featureSubsetStrategy = FeatureSubsetStrategy.ALL;

    /**
     * FeatureSubsetStrategy in train#params can be set to a double or text; if a double, the double value is used
     * and {@link #featureSubsetStrategy} is set to null.
     */
    private double featureSubsetRate;
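
    // The two fields above correspond to the "FeatureSubsetStrategy" entry of train#params as parsed in init();
    // a sketch of the two accepted forms (values here are illustrative, not defaults):
    //   "FeatureSubsetStrategy": "SQRT"  ->  featureSubsetStrategy = FeatureSubsetStrategy.SQRT
    //   "FeatureSubsetStrategy": 0.8     ->  featureSubsetStrategy = null, featureSubsetRate = 0.8,
    //                                        i.e. about 80% of features sampled for each node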
    /**
     * Max depth of a tree, by default 10.
     */
    private int maxDepth;

    /**
     * Max leaves of a tree, by default -1. If maxLeaves is set > 0, leaf-wise tree building is enabled no matter
     * what {@link #maxDepth} is set to.
     */
    private int maxLeaves = -1;

    /**
     * If {@link #maxLeaves} > 0, {@link #isLeafWise} is set to true; otherwise trees are built level-wise.
     */
    private boolean isLeafWise = false;

    /**
     * Max stats memory used to group nodes.
     */
    private long maxStatsMemory;

    /**
     * Whether variable selection has been done; if not, variables that are good candidates are selected.
     */
    private boolean isAfterVarSelect;

    /**
     * Different {@link Impurity} implementations for nodes: {@link Entropy} and {@link Gini} are mostly for
     * classification, {@link Variance} is mostly for regression.
     */
    private Impurity impurity;

    /**
     * Whether this is a random forest run; RF is the default for this master.
     */
    private boolean isRF = true;

    /**
     * Whether this is a gradient boosted decision tree run; for GBDT, each time a tree is trained, the next tree is
     * trained on gradient labels from the previous tree.
     */
    private boolean isGBDT = false;

    /**
     * Learning rate for GBDT.
     */
    private double learningRate;

    /**
     * Number of workers; this is used for memory usage estimation.
     */
    private int workerNumber;

    /**
     * Number of input features.
     */
    private int inputNum;

    /**
     * Cache of all features by feature index, for sampling.
     */
    private List<Integer> allFeatures;

    /**
     * Whether to enable continuous model training based on existing models.
     */
    private boolean isContinuousEnabled = false;

    /**
     * For continuous model training, update this to the existing tree size; by default 0, which has no impact on
     * the existing process.
     */
    private int existingTreeSize = 0;

    /**
     * Every checkpoint interval, checkpoint {@link #trees}, {@link #toDoQueue} and the MasterParams of that
     * iteration.
     */
    @SuppressWarnings("unused")
    private int checkpointInterval;

    /**
     * Checkpoint output HDFS file.
     */
    private Path checkpointOutput;

    /**
     * Common conf to avoid creating new Configuration instances.
     */
    @SuppressWarnings("unused")
    private Configuration conf;

    /**
     * Checkpointed master params; if only the queue were recovered on failover, some todo nodes in the master
     * result would be ignored. This is used to recover the whole state of that iteration. In
     * {@link #doCompute(MasterContext)}, if {@link #cpMasterParams} is not null, it is returned directly and then
     * set to null to avoid sending it again in the next iteration.
     */
    private DTMasterParams cpMasterParams;

    /**
     * Max batch split size in leaf-wise tree growth; this only works when {@link #isLeafWise} = true.
     */
    private int maxBatchSplitSize = 16;

    /**
     * DTEarlyStopDecider decides automatically whether further training is needed; this is only for GBDT.
     */
    private DTEarlyStopDecider dtEarlyStopDecider;

    /**
     * Whether early stop is enabled, by default false.
     */
    private boolean enableEarlyStop = false;

    /**
     * Validation tolerance for early stop; by default 0d, which means early stop by tolerance is not enabled.
     */
    private double validationTolerance = 0d;

    /**
     * Random generator for sampling features in each iteration.
     */
    private Random featureSamplingRandom = new Random();

    /**
     * The best validation error seen so far, for error computing.
     */
    private double bestValidationError = Double.MAX_VALUE;

    // ############################################################################################################
    // ## These fields are states; on failover such instances should be recovered in {@link #init(MasterContext)}.
    // ############################################################################################################

    /**
     * All trees trained in this master.
     */
    private List<TreeNode> trees;

    /**
     * TreeNodes with splits are added to this queue, and after that a batch of nodes is split in the same
     * iteration; this only works when {@link #isLeafWise} = true.
     */
    private Queue<TreeNode> toSplitQueue;

    /**
     * TreeNodes whose statistics need to be collected from workers.
     */
    private Queue<TreeNode> toDoQueue;
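
    // Note on node addressing used throughout this class (an assumption based on how Node.leftIndex,
    // Node.rightIndex and Node.indexToLevel are used below, not verified against Node itself): nodes appear to be
    // numbered heap-style, with ROOT_INDEX at 1, children of id at 2 * id and 2 * id + 1, and indexToLevel(id)
    // returning the 1-based level, so e.g. ids 4..7 would sit at level 3 and a tree with maxDepth = 10 would top
    // out at 2^10 - 1 = 1023 nodes.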
    @Override
    public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
        if(context.isFirstIteration()) {
            return buildInitialMasterParams();
        }

        if(this.cpMasterParams != null) {
            DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
            // set it to null to avoid sending it in the next iteration
            this.cpMasterParams = null;
            if(this.isGBDT) {
                // no need to send full trees because workers will get existing models from HDFS;
                // only set the last tree for node stats, and no need to check switching to the next tree because
                // the message may have been sent to workers already
                tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
                // set tmp trees for DTOutput
                tmpMasterParams.setTmpTrees(this.trees);
            }
            return tmpMasterParams;
        }

        boolean isFirst = false;
        Map<Integer, NodeStats> nodeStatsMap = null;
        double trainError = 0d, validationError = 0d;
        double weightedTrainCount = 0d, weightedValidationCount = 0d;
        for(DTWorkerParams params: context.getWorkerResults()) {
            if(!isFirst) {
                isFirst = true;
                nodeStatsMap = params.getNodeStatsMap();
            } else {
                Map<Integer, NodeStats> currNodeStatsmap = params.getNodeStatsMap();
                for(Entry<Integer, NodeStats> entry: nodeStatsMap.entrySet()) {
                    NodeStats resultNodeStats = entry.getValue();
                    mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
                }
                // set to null after merging to release memory at the earliest stage
                params.setNodeStatsMap(null);
            }
            trainError += params.getTrainError();
            validationError += params.getValidationError();
            weightedTrainCount += params.getTrainCount();
            weightedValidationCount += params.getValidationCount();
        }

        for(Entry<Integer, NodeStats> entry: nodeStatsMap.entrySet()) {
            NodeStats nodeStats = entry.getValue();
            int treeId = nodeStats.getTreeId();
            Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
            // doneNode, NodeStats
            Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();

            List<GainInfo> gainList = new ArrayList<GainInfo>();
            for(Entry<Integer, double[]> gainEntry: statistics.entrySet()) {
                int columnNum = gainEntry.getKey();
                ColumnConfig config = this.columnConfigList.get(columnNum);
                double[] statsArray = gainEntry.getValue();
                GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
                if(gainInfo != null) {
                    gainList.add(gainInfo);
                }
            }

            GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
            if(maxGainInfo == null) {
                // null gain info: set to leaf and continue with the next stats
                doneNode.setLeaf(true);
                continue;
            }
            populateGainInfoToNode(treeId, doneNode, maxGainInfo);

            if(this.isLeafWise) {
                boolean isNotSplit = maxGainInfo.getGain() <= 0d;
                if(!isNotSplit) {
                    this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
                } else {
                    LOG.info("Node {} in tree {} is not to be split", doneNode.getId(), treeId);
                }
            } else {
                boolean isLeaf = maxGainInfo.getGain() <= 0d
                        || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
                doneNode.setLeaf(isLeaf);
                // level-wise splits a node as soon as its stats are ready
                splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
            }
        }

        if(this.isLeafWise) {
            // poll nodes from toSplitQueue and split them in a batch
            int currSplitIndex = 0;
            while(!toSplitQueue.isEmpty() && currSplitIndex < this.maxBatchSplitSize) {
                TreeNode treeNode = this.toSplitQueue.poll();
                splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
                // count this split so that at most maxBatchSplitSize nodes are split per iteration
                currSplitIndex += 1;
            }
        }

        Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
        double averageValidationError = validationError / weightedValidationCount;
        if(this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
            this.dtEarlyStopDecider.add(averageValidationError);
            averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
        }

        boolean vtTriggered = false;
        // if validationTolerance == 0d, the vt check is not enabled
        if(validationTolerance > 0d
                && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance
                        * averageValidationError) {
            LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}",
                    bestValidationError, averageValidationError, validationTolerance);
            vtTriggered = true;
        }
        if(averageValidationError < this.bestValidationError) {
            this.bestValidationError = averageValidationError;
        }
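
        // Worked example of the tolerance check above, with illustrative numbers: for bestValidationError = 0.100,
        // averageValidationError = 0.0999 and validationTolerance = 0.01, the distance 0.0001 is below
        // 0.01 * 0.0999 = 0.000999, so vtTriggered is set and training halts once the todo queue drains.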
        // the reported validation error is averageValidationError * weightedValidationCount because
        // averageValidationError has already been divided by the validation count
        DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount,
                averageValidationError * weightedValidationCount);

        if(toDoQueue.isEmpty()) {
            if(this.isGBDT) {
                TreeNode treeNode = this.trees.get(this.trees.size() - 1);
                Node node = treeNode.getNode();
                if(this.trees.size() >= this.treeNum) {
                    // if all trees, including trees read from an existing model, reach treeNum, stop the whole
                    // process
                    masterParams.setHalt(true);
                    LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
                } else if(node.getLeft() == null && node.getRight() == null) {
                    // with very good performance there can be an issue here: say you'd like 5 trees, but the 2nd
                    // tree turns out perfect; no need to continue, but warn users about it: setting
                    // BaggingSampleRate to something other than 1 can avoid such overfitting
                    masterParams.setHalt(true);
                    LOG.warn(
                            "Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.",
                            context.getCurrentIteration());
                } else if(this.dtEarlyStopDecider != null
                        && (this.enableEarlyStop && this.dtEarlyStopDecider.canStop())) {
                    masterParams.setHalt(true);
                    LOG.info("Early stop identified, training is stopped in iteration {}.",
                            context.getCurrentIteration());
                } else if(vtTriggered) {
                    LOG.info("Early stop training by validation tolerance.");
                    masterParams.setHalt(true);
                } else {
                    // set first tree to true even after the ROOT node is set in the next tree
                    masterParams.setFirstTree(this.trees.size() == 1);
                    // current tree is finished, no need to keep its feature information
                    treeNode.setFeatures(null);
                    TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX),
                            this.learningRate);
                    LOG.info("The {} tree is to be built.", this.trees.size());
                    this.trees.add(newRootNode);
                    newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy,
                            this.featureSubsetRate));
                    // only one node
                    todoNodes.put(0, newRootNode);
                    masterParams.setTodoNodes(todoNodes);
                    // set switch flag
                    masterParams.setSwitchToNextTree(true);
                }
            } else {
                // for rf
                masterParams.setHalt(true);
                LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            }
        } else {
            int nodeIndexInGroup = 0;
            long currMem = 0L;
            List<Integer> depthList = new ArrayList<Integer>();
            if(this.isGBDT) {
                depthList.add(-1);
            }
            if(isRF) {
                for(int i = 0; i < this.trees.size(); i++) {
                    depthList.add(-1);
                }
            }
            while(!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
                TreeNode node = this.toDoQueue.poll();
                int treeId = node.getTreeId();
                int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
                int currDepth = Node.indexToLevel(node.getNode().getId());
                if(currDepth > oldDepth) {
                    if(this.isGBDT) {
                        // gbdt only tracks the depth of the last tree
                        depthList.set(0, currDepth);
                    } else {
                        depthList.set(treeId, currDepth);
                    }
                }

                List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy,
                        this.featureSubsetRate);
                node.setFeatures(subsetFeatures);
                currMem += getStatsMem(subsetFeatures);
                todoNodes.put(nodeIndexInGroup, node);
                nodeIndexInGroup += 1;
            }
            masterParams.setTreeDepth(depthList);
            masterParams.setTodoNodes(todoNodes);
            masterParams.setSwitchToNextTree(false);
            masterParams.setContinuousRunningStart(false);
            masterParams.setFirstTree(this.trees.size() == 1);
            LOG.info("Todo node size is {}", todoNodes.size());
        }
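
        // treeDepth encodes which trees carry todo nodes in this iteration: one entry per tree for RF, a single
        // entry for GBDT, with -1 meaning "no todo node polled from this tree". For example, with three RF trees
        // and todo nodes only from tree 1 at level 4, the list would be [-1, 4, -1].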
        if(this.isGBDT) {
            if(masterParams.isSwitchToNextTree()) {
                // send the last fully grown tree and the current ROOT-node-only tree
                masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
            } else {
                // only send the current tree
                masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            }
        }
        if(this.isRF) {
            // for rf, reset trees sent to workers to only the trees with todo nodes, which saves message space;
            // since elements in todoTrees are the same references as in this.trees, reuse the same objects to save
            // memory
            if(masterParams.getTreeDepth().size() == this.trees.size()) {
                // normal iteration
                List<TreeNode> todoTrees = new ArrayList<TreeNode>();
                for(int i = 0; i < trees.size(); i++) {
                    if(masterParams.getTreeDepth().get(i) >= 0) {
                        // such a tree's depth in the current iteration is not -1, add it to todoTrees
                        todoTrees.add(trees.get(i));
                    } else {
                        // mock a TreeNode instance to make sure there is no surprise in further serialization; in
                        // fact meaningless
                        todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
                    }
                }
                masterParams.setTrees(todoTrees);
            } else {
                // last iteration, without maxDepthList
                masterParams.setTrees(trees);
            }
        }
        if(this.isGBDT) {
            // set tmp trees for DTOutput
            masterParams.setTmpTrees(this.trees);
        }

        if(context.getCurrentIteration() % 100 == 0) {
            // every 100 iterations run gc explicitly to avoid one case:
            // mapper memory is 2048M and final in our cluster; if -Xmx is 2G, an oom issue shows up occasionally.
            // to fix this issue: 1. set -Xmx to 1800m; 2. call gc to drop unused memory at an early stage.
            // this is ugly, and if it is stable with 1800m, this should be removed
            Thread gcThread = new Thread(new Runnable() {
                @Override
                public void run() {
                    System.gc();
                }
            });
            gcThread.setDaemon(true);
            gcThread.start();
        }

        // before returning the master result, do a checkpoint according to the iteration interval set by the user
        doCheckPoint(context, masterParams, context.getCurrentIteration());
        LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}",
                weightedTrainCount, weightedValidationCount, trainError, validationError);
        return masterParams;
    }

    /**
     * Split a node into left and right for leaf-wise tree growth; doneNode should be populated by
     * {@link #populateGainInfoToNode(int, Node, GainInfo)}.
     */
    private void splitNodeForLeafWisedTree(int treeId, Node doneNode) {
        boolean isOverMaxLeaves = this.trees.get(treeId).getNodeNum() + 1 > this.maxLeaves;
        boolean canSplit = !isOverMaxLeaves && Double.compare(doneNode.getLeftImpurity(), 0d) != 0;
        // if the node can be split, create the left and right children at the same time
        if(canSplit) {
            int leftIndex = Node.leftIndex(doneNode.getId());
            Node left = new Node(leftIndex, doneNode.getLeftPredict(), doneNode.getLeftImpurity(), true);

            doneNode.setLeft(left);
            this.trees.get(treeId).incrNodeNum();
            this.toDoQueue.offer(new TreeNode(treeId, left));

            int rightIndex = Node.rightIndex(doneNode.getId());
            Node right = new Node(rightIndex, doneNode.getRightPredict(), doneNode.getRightImpurity(), true);
            doneNode.setRight(right);
            this.trees.get(treeId).incrNodeNum();
            this.toDoQueue.offer(new TreeNode(treeId, right));
        }
    }
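
    // In leaf-wise growth the candidates split above come out of toSplitQueue, a priority queue ordered in init()
    // by wgtCntRatio * gain (largest first), and at most maxBatchSplitSize of them are split per iteration; the
    // node promising the biggest weighted impurity reduction is therefore always expanded first.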
    /**
     * Split a node into left and right for level-wise tree growth; doneNode should be populated by
     * {@link #populateGainInfoToNode(int, Node, GainInfo)}.
     */
    private void splitNodeForLevelWisedTree(boolean isLeaf, int treeId, Node doneNode) {
        if(!isLeaf) {
            boolean leftChildIsLeaf = Node.indexToLevel(doneNode.getId()) + 1 == this.maxDepth
                    || Double.compare(doneNode.getLeftImpurity(), 0d) == 0;
            // such a node just has isLeaf set to true; a new node is created with the leaf flag but may be changed
            // to a final leaf in a later iteration
            int leftIndex = Node.leftIndex(doneNode.getId());
            Node left = new Node(leftIndex, doneNode.getLeftPredict(), doneNode.getLeftImpurity(), true);
            doneNode.setLeft(left);
            // update nodeNum
            this.trees.get(treeId).incrNodeNum();
            if(!leftChildIsLeaf) {
                this.toDoQueue.offer(new TreeNode(treeId, left));
            } else {
                LOG.debug("Left node {} in tree {} is set to leaf and not submitted to workers", leftIndex, treeId);
            }

            boolean rightChildIsLeaf = Node.indexToLevel(doneNode.getId()) + 1 == this.maxDepth
                    || Double.compare(doneNode.getRightImpurity(), 0d) == 0;
            // such a node just has isLeaf set to true
            int rightIndex = Node.rightIndex(doneNode.getId());
            Node right = new Node(rightIndex, doneNode.getRightPredict(), doneNode.getRightImpurity(), true);
            doneNode.setRight(right);
            // update nodeNum
            this.trees.get(treeId).incrNodeNum();
            if(!rightChildIsLeaf) {
                this.toDoQueue.offer(new TreeNode(treeId, right));
            } else {
                LOG.debug("Right node {} in tree {} is set to leaf and not submitted to workers", rightIndex,
                        treeId);
            }
        } else {
            LOG.info("Done node {} in tree {} is finally set to leaf", doneNode.getId(), treeId);
        }
    }

    private DTMasterParams rebuildRecoverMasterResultDepthList() {
        DTMasterParams tmpMasterParams = this.cpMasterParams;
        List<Integer> depthList = new ArrayList<Integer>();
        if(isRF) {
            for(int i = 0; i < this.treeNum; i++) {
                depthList.add(-1);
            }
        } else if(isGBDT) {
            depthList.add(-1);
        }
        for(Entry<Integer, TreeNode> entry: tmpMasterParams.getTodoNodes().entrySet()) {
            int treeId = entry.getValue().getTreeId();
            int oldDepth = isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(entry.getValue().getNode().getId());
            if(currDepth > oldDepth) {
                if(isGBDT) {
                    depthList.set(0, currDepth);
                }
                if(isRF) {
                    depthList.set(treeId, currDepth);
                }
            }
        }
        tmpMasterParams.setTreeDepth(depthList);
        return tmpMasterParams;
    }

    /**
     * Do a checkpoint of master states; this is for master failover.
     */
    private void doCheckPoint(final MasterContext<DTMasterParams, DTWorkerParams> context,
            final DTMasterParams masterParams, int iteration) {
        LOG.info("Do checkpoint at hdfs file {} at iteration {}.", this.checkpointOutput, iteration);
        final Queue<TreeNode> finalTodoQueue = this.toDoQueue;
        final Queue<TreeNode> finalToSplitQueue = this.toSplitQueue;
        final boolean finalIsLeaf = this.isLeafWise;

        long start = System.currentTimeMillis();
        final List<TreeNode> finalTrees = new ArrayList<TreeNode>();
        for(TreeNode treeNode: this.trees) {
            BytableSerializer<TreeNode> bs = new BytableSerializer<TreeNode>();
            // clone by serialization
            byte[] bytes = bs.objectToBytes(treeNode);
            TreeNode newTreeNode = bs.bytesToObject(bytes, TreeNode.class.getName());
            finalTrees.add(newTreeNode);
        }
        LOG.info("Do checkpoint at clone trees in iteration {} with run time {}", context.getCurrentIteration(),
                (System.currentTimeMillis() - start));

        Thread cpPersistThread = new Thread(new Runnable() {
            @Override
            public void run() {
                writeStatesToHdfs(DTMaster.this.checkpointOutput, masterParams, finalTrees, finalIsLeaf,
                        finalTodoQueue, finalToSplitQueue);
            }
        }, "Master checkpoint thread");
        cpPersistThread.setDaemon(true);
        cpPersistThread.start();
    }
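
    // Checkpoint file layout, as written below and read back symmetrically in recoverMasterStatus():
    //   int treeCount, treeCount x TreeNode,
    //   int todoSize, todoSize x TreeNode,
    //   int toSplitSize, toSplitSize x TreeNode   (only when isLeafWise),
    //   DTMasterParams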
    /**
     * Write {@link #trees}, {@link #toDoQueue} and MasterParams to HDFS.
     */
    private void writeStatesToHdfs(Path out, DTMasterParams masterParams, List<TreeNode> trees, boolean isLeafWise,
            Queue<TreeNode> toDoQueue, Queue<TreeNode> toSplitQueue) {
        FSDataOutputStream fos = null;
        try {
            fos = FileSystem.get(new Configuration()).create(out);

            // trees
            int treeLength = trees.size();
            fos.writeInt(treeLength);
            for(TreeNode treeNode: trees) {
                treeNode.write(fos);
            }

            // todo queue
            fos.writeInt(toDoQueue.size());
            for(TreeNode treeNode: toDoQueue) {
                treeNode.write(fos);
            }

            if(isLeafWise && toSplitQueue != null) {
                fos.writeInt(toSplitQueue.size());
                for(TreeNode treeNode: toSplitQueue) {
                    treeNode.write(fos);
                }
            }

            // master result
            masterParams.write(fos);
        } catch (Throwable e) {
            LOG.error("Error in writing output.", e);
        } finally {
            IOUtils.closeStream(fos);
            fos = null;
        }
    }

    private void populateGainInfoToNode(int treeId, Node doneNode, GainInfo maxGainInfo) {
        doneNode.setPredict(maxGainInfo.getPredict());
        doneNode.setSplit(maxGainInfo.getSplit());
        doneNode.setGain(maxGainInfo.getGain());
        doneNode.setImpurity(maxGainInfo.getImpurity());
        doneNode.setLeftImpurity(maxGainInfo.getLeftImpurity());
        doneNode.setRightImpurity(maxGainInfo.getRightImpurity());
        doneNode.setLeftPredict(maxGainInfo.getLeftPredict());
        doneNode.setRightPredict(maxGainInfo.getRightPredict());

        if(Node.isRootNode(doneNode)) {
            this.trees.get(treeId).setRootWgtCnt(maxGainInfo.getWgtCnt());
        } else {
            double rootWgtCnt = this.trees.get(treeId).getRootWgtCnt();
            doneNode.setWgtCntRatio(maxGainInfo.getWgtCnt() / rootWgtCnt);
        }
    }

    private long getStatsMem(List<Integer> subsetFeatures) {
        long statsMem = 0L;
        List<Integer> tempFeatures = subsetFeatures;
        if(subsetFeatures.size() == 0) {
            tempFeatures = getAllFeatureList(this.columnConfigList, this.isAfterVarSelect);
        }
        for(Integer columnNum: tempFeatures) {
            ColumnConfig config = this.columnConfigList.get(columnNum);
            // 2 is overhead to avoid oom
            if(config.isNumerical()) {
                statsMem += config.getBinBoundary().size() * this.impurity.getStatsSize() * 8L * 2;
            } else if(config.isCategorical()) {
                statsMem += (config.getBinCategory().size() + 1) * this.impurity.getStatsSize() * 8L * 2;
            }
        }
        // scale by worker number to avoid oom in master; as DTWorkerParams is combinable, use half of the worker
        // number
        statsMem = statsMem * this.workerNumber / 2;
        return statsMem;
    }
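
    // Rough arithmetic for the estimate above (illustrative only; the real statsSize depends on the Impurity
    // implementation): assuming getStatsSize() == 3, a numerical feature with 10 bin boundaries contributes
    // 10 * 3 * 8 * 2 = 480 bytes, so 100 such features on one node with 20 workers cost
    // 100 * 480 * 20 / 2 = 480,000 bytes of the maxStatsMemory budget.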
    private void mergeNodeStats(NodeStats resultNodeStats, NodeStats nodeStats) {
        Map<Integer, double[]> featureStatistics = resultNodeStats.getFeatureStatistics();
        for(Entry<Integer, double[]> entry: nodeStats.getFeatureStatistics().entrySet()) {
            double[] statistics = featureStatistics.get(entry.getKey());
            for(int i = 0; i < statistics.length; i++) {
                statistics[i] += entry.getValue()[i];
            }
        }
    }

    private DTMasterParams buildInitialMasterParams() {
        Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>(treeNum, 1.0f);
        int nodeIndexInGroup = 0;
        List<Integer> depthList = new ArrayList<Integer>();
        DTMasterParams masterParams = new DTMasterParams(trees, todoNodes);
        if(isRF) {
            // for RF, depth should be set for all trees
            for(int i = 0; i < this.treeNum; i++) {
                depthList.add(-1);
            }
            for(TreeNode treeNode: trees) {
                List<Integer> features = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
                treeNode.setFeatures(features);
                todoNodes.put(nodeIndexInGroup, treeNode);
                int treeId = treeNode.getTreeId();
                int oldDepth = depthList.get(treeId);
                int currDepth = Node.indexToLevel(treeNode.getNode().getId());
                if(currDepth > oldDepth) {
                    depthList.set(treeId, currDepth);
                }
                nodeIndexInGroup += 1;
            }
            // for RF, whole trees are sent each time
            masterParams.setTrees(this.trees);
        } else if(isGBDT) {
            // for gbdt, only the depth of the last tree is stored
            depthList.add(-1);
            List<Integer> features = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            TreeNode treeNode = trees.get(trees.size() - 1); // only for the last tree
            treeNode.setFeatures(features);
            todoNodes.put(nodeIndexInGroup, treeNode);
            int oldDepth = depthList.get(0);
            int currDepth = Node.indexToLevel(treeNode.getNode().getId());
            if(currDepth > oldDepth) {
                depthList.set(0, currDepth);
            }
            nodeIndexInGroup += 1;
            // isContinuousEnabled == true means this is the first iteration of continuous model training; workers
            // should recover predict values from existing models
            masterParams.setContinuousRunningStart(this.isContinuousEnabled);
            // switch to the next new tree for ROOT-node-only stats
            masterParams.setSwitchToNextTree(true);
            // whether the current tree is the first tree
            masterParams.setFirstTree(this.trees.size() == 1);
            // gbdt only sends the last tree to workers
            if(this.trees.size() > 0) {
                masterParams.setTrees(this.trees.subList(this.trees.size() - 1, this.trees.size()));
            }
            // tmp trees are not sent to workers, only to DTOutput for model saving
            masterParams.setTmpTrees(this.trees);
        }
        masterParams.setTreeDepth(depthList);
        return masterParams;
    }

    private List<Integer> getSubsamplingFeatures(FeatureSubsetStrategy featureSubsetStrategy,
            double featureSubsetRate) {
        if(featureSubsetStrategy == null) {
            return sampleFeaturesForNodeStats(this.allFeatures,
                    (int) (this.allFeatures.size() * featureSubsetRate));
        } else {
            switch(featureSubsetStrategy) {
                case HALF:
                    return sampleFeaturesForNodeStats(this.allFeatures, this.allFeatures.size() / 2);
                case ONETHIRD:
                    return sampleFeaturesForNodeStats(this.allFeatures, this.allFeatures.size() / 3);
                case TWOTHIRDS:
                    return sampleFeaturesForNodeStats(this.allFeatures, this.allFeatures.size() * 2 / 3);
                case SQRT:
                    return sampleFeaturesForNodeStats(this.allFeatures,
                            (int) (this.allFeatures.size() * Math.sqrt(this.inputNum) / this.inputNum));
                case LOG2:
                    return sampleFeaturesForNodeStats(this.allFeatures,
                            (int) (this.allFeatures.size() * Math.log(this.inputNum) / Math.log(2) / this.inputNum));
                case AUTO:
                    if(this.treeNum > 1) {
                        return sampleFeaturesForNodeStats(this.allFeatures, this.allFeatures.size() / 2);
                    } else {
                        return new ArrayList<Integer>();
                    }
                case ALL:
                default:
                    return new ArrayList<Integer>();
            }
        }
    }
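
    // Resulting sample sizes for the strategies above, assuming allFeatures.size() == inputNum == 100:
    //   HALF -> 50, ONETHIRD -> 33, TWOTHIRDS -> 66, SQRT -> (int) (100 * 10 / 100) = 10,
    //   LOG2 -> (int) (100 * 6.64 / 100) = 6, and ALL/default -> an empty list, which downstream code
    //   (see getStatsMem) treats as "use all features".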
    private List<Integer> getAllFeatureList(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
        List<Integer> features = new ArrayList<Integer>();
        for(ColumnConfig config: columnConfigList) {
            if(isAfterVarSelect) {
                if(config.isFinalSelect() && !config.isTarget() && !config.isMeta()) {
                    // only select numerical features with getBinBoundary().size() larger than 1
                    // or categorical features with getBinCategory().size() larger than 0
                    if((config.isNumerical() && config.getBinBoundary().size() > 1)
                            || (config.isCategorical() && config.getBinCategory().size() > 0)) {
                        features.add(config.getColumnNum());
                    }
                }
            } else {
                if(!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config)) {
                    // only select numerical features with getBinBoundary().size() larger than 1
                    // or categorical features with getBinCategory().size() larger than 0
                    if((config.isNumerical() && config.getBinBoundary().size() > 1)
                            || (config.isCategorical() && config.getBinCategory().size() > 0)) {
                        features.add(config.getColumnNum());
                    }
                }
            }
        }
        return features;
    }

    private List<Integer> sampleFeaturesForNodeStats(List<Integer> allFeatures, int sample) {
        List<Integer> features = new ArrayList<Integer>(sample);
        for(int i = 0; i < sample; i++) {
            features.add(allFeatures.get(i));
        }

        for(int i = sample; i < allFeatures.size(); i++) {
            int replacementIndex = (int) (featureSamplingRandom.nextDouble() * i);
            if(replacementIndex >= 0 && replacementIndex < sample) {
                features.set(replacementIndex, allFeatures.get(i));
            }
        }
        return features;
    }
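
    // The method above is a reservoir-sampling style selection (close to Algorithm R): the first `sample` features
    // seed the result, and each later feature i replaces a uniformly chosen slot with probability sample / i, so a
    // near-uniform random subset is drawn without shuffling the whole feature list.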
    @Override
    public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
        Properties props = context.getProps();

        // load model config and column config list first
        SourceType sourceType;
        try {
            sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE,
                    SourceType.HDFS.toString()));
            this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                    sourceType);
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // worker number is used to estimate nodes per iteration for stats
        this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);

        // check if variables are final-selected
        int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
        this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
        this.isAfterVarSelect = (inputOutputIndex[3] == 1);
        // cache the full feature list for sampling features
        this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);

        int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
        // if grid search, select valid parameters; if not, parameters are what is in ModelConfig.json
        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams());
        Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
        if(gs.hasHyperParam()) {
            validParams = gs.getParams(trainerId);
            LOG.info("Start grid search master with params: {}", validParams);
        }

        Object vtObj = validParams.get("ValidationTolerance");
        if(vtObj != null) {
            try {
                validationTolerance = Double.parseDouble(vtObj.toString());
                LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
            } catch (NumberFormatException ee) {
                validationTolerance = 0d;
                LOG.warn(
                        "Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.",
                        vtObj);
            }
        } else {
            LOG.warn("Validation by tolerance isn't enabled.");
        }

        // tree-related parameter initialization
        Object fssObj = validParams.get("FeatureSubsetStrategy");
        if(fssObj != null) {
            try {
                this.featureSubsetRate = Double.parseDouble(fssObj.toString());
                // no need to validate that featureSubsetRate is in (0,1], as this is already validated in
                // ModelInspector
                this.featureSubsetStrategy = null;
            } catch (NumberFormatException ee) {
                this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
            }
        } else {
            LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHIRDS by default in DTMaster.");
            this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
            this.featureSubsetRate = 0;
        }

        // max depth
        Object maxDepthObj = validParams.get("MaxDepth");
        if(maxDepthObj != null) {
            this.maxDepth = Integer.valueOf(maxDepthObj.toString());
        } else {
            this.maxDepth = 10;
        }

        // max leaves, used for leaf-wise tree building; TODO add more benchmarks
        Object maxLeavesObj = validParams.get("MaxLeaves");
        if(maxLeavesObj != null) {
            this.maxLeaves = Integer.valueOf(maxLeavesObj.toString());
        } else {
            this.maxLeaves = -1;
        }

        // enable leaf-wise tree building once maxLeaves is configured
        if(this.maxLeaves > 0) {
            this.isLeafWise = true;
        }

        // maxBatchSplitSize is the number of nodes split in each batch
        Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
        if(maxBatchSplitSizeObj != null) {
            this.maxBatchSplitSize = Integer.valueOf(maxBatchSplitSizeObj.toString());
        } else {
            // by default split at most 32 in a batch
            this.maxBatchSplitSize = 32;
        }

        assert this.maxDepth > 0 && this.maxDepth <= 20;

        // hidden parameter to avoid OOM issues in each iteration
        Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
        if(maxStatsMemoryMB != null) {
            this.maxStatsMemory = Long.valueOf(validParams.get("MaxStatsMemoryMB").toString()) * 1024 * 1024;
        } else {
            // by default it is 1/2 of the heap, about 1.5G with current Shifu settings
            this.maxStatsMemory = Runtime.getRuntime().maxMemory() / 2L;
        }
        // assert this.maxStatsMemory <= Math.min(Runtime.getRuntime().maxMemory() * 0.6, 800 * 1024 * 1024L);

        this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
        this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
        this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
        if(this.isGBDT) {
            // learning rate is only effective in gbdt
            this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        }

        // initialize impurity type according to regression or classification
        String imStr = validParams.get("Impurity").toString();
        int numClasses = 2;
        if(this.modelConfig.isClassification()) {
            numClasses = this.modelConfig.getTags().size();
        }

        // these two parameters control when tree growth stops
        int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
        double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
        if(imStr.equalsIgnoreCase("entropy")) {
            impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
        } else if(imStr.equalsIgnoreCase("gini")) {
            impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
        } else {
            impurity = new Variance(minInstancesPerNode, minInfoGain);
        }

        // checkpoint folder and interval (checkpoint every # iterations)
        this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(
                CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
        this.checkpointOutput = new Path(context.getProps().getProperty(
                CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));

        // cache conf to avoid creating new instances
        this.conf = new Configuration();

        // whether continuous model training is enabled
        this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(
                context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));

        this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
        if(validParams.containsKey("EnableEarlyStop")
                && Boolean.valueOf(validParams.get("EnableEarlyStop").toString().toLowerCase())) {
            this.enableEarlyStop = true;
        }

        LOG.info(
                "Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, "
                        + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, "
                        + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.",
                isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum,
                imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT,
                this.isContinuousEnabled, this.enableEarlyStop);

        this.toDoQueue = new LinkedList<TreeNode>();

        if(this.isLeafWise) {
            this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {
                @Override
                public int compare(TreeNode o1, TreeNode o2) {
                    return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(),
                            o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
                }
            });
        }
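
        // The comparator orders by wgtCntRatio * gain, descending; e.g. a node covering half of the root weight
        // with gain 0.2 (priority 0.1) is polled before a node covering 10% with gain 0.5 (priority 0.05).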
        // initialize trees
        if(context.isFirstIteration()) {
            if(this.isRF) {
                // for random forest, trees are trained in parallel
                this.trees = new ArrayList<TreeNode>(treeNum);
                for(int i = 0; i < treeNum; i++) {
                    this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
                }
            }
            if(this.isGBDT) {
                if(isContinuousEnabled) {
                    TreeModel existingModel;
                    try {
                        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
                        existingModel = (TreeModel) CommonUtils.loadModel(modelConfig, modelPath,
                                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
                        if(existingModel == null) {
                            // null means no existing model file, or the model file is in the wrong format
                            this.trees = new ArrayList<TreeNode>(treeNum);
                            // learning rate is 1 for the 1st tree
                            this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
                            LOG.info("Starting to train model from scratch and existing model is empty.");
                        } else {
                            this.trees = existingModel.getTrees();
                            this.existingTreeSize = this.trees.size();
                            // starting from existing models, the first new tree uses the current learning rate
                            this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX),
                                    this.existingTreeSize == 0 ? 1d : this.learningRate));
                            LOG.info("Starting to train model from existing model {} with existing trees {}.",
                                    modelPath, existingTreeSize);
                        }
                    } catch (IOException e) {
                        throw new GuaguaRuntimeException(e);
                    }
                } else {
                    this.trees = new ArrayList<TreeNode>(treeNum);
                    // for GBDT, initialize the first tree; trees are trained sequentially, and the first tree's
                    // learning rate is 1
                    this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
                }
            }
        } else {
            // recover all states once master fails over
            LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
            recoverMasterStatus(sourceType);
        }
    }

    private void recoverMasterStatus(SourceType sourceType) {
        FSDataInputStream stream = null;
        FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
        try {
            stream = fs.open(this.checkpointOutput);
            int treeSize = stream.readInt();
            this.trees = new ArrayList<TreeNode>(treeSize);
            for(int i = 0; i < treeSize; i++) {
                TreeNode treeNode = new TreeNode();
                treeNode.readFields(stream);
                this.trees.add(treeNode);
            }

            int queueSize = stream.readInt();
            for(int i = 0; i < queueSize; i++) {
                TreeNode treeNode = new TreeNode();
                treeNode.readFields(stream);
                this.toDoQueue.offer(treeNode);
            }

            if(this.isLeafWise && this.toSplitQueue != null) {
                queueSize = stream.readInt();
                for(int i = 0; i < queueSize; i++) {
                    TreeNode treeNode = new TreeNode();
                    treeNode.readFields(stream);
                    this.toSplitQueue.offer(treeNode);
                }
            }

            this.cpMasterParams = new DTMasterParams();
            this.cpMasterParams.readFields(stream);
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        } finally {
            org.apache.commons.io.IOUtils.closeQuietly(stream);
        }
    }
}