/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Properties; import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import ml.shifu.guagua.ComputableMonitor; import ml.shifu.guagua.GuaguaRuntimeException; import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader; import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter; import ml.shifu.guagua.io.Bytable; import ml.shifu.guagua.io.GuaguaFileSplit; import ml.shifu.guagua.util.MemoryLimitedList; import ml.shifu.guagua.util.NumberFormatUtils; import ml.shifu.guagua.worker.AbstractWorkerComputable; import ml.shifu.guagua.worker.WorkerContext; import ml.shifu.guagua.worker.WorkerContext.WorkerCompletionCallBack; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.TreeModel; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats; import ml.shifu.shifu.core.dtrain.gs.GridSearch; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.ClassUtils; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.lang.StringUtils; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Splitter; /** * {@link DTWorker} collects node statistics for nodes with sub-sampled features from master {@link DTMaster}. * * <p> * Both random forest and gradient boosted decision trees are supported in this worker. For RF, it only collects * statistics for nodes from the master. For GBDT, extra label and predict values are updated in each iteration. * * <p> * For GBDT, loaded data instances are also updated with new predict and label values, which means such data can only * be stored in memory. To support this, the predict and label fields in Data are always set, even for RF. Data are * stored as float to reduce memory consumption.
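 * * <p> * As an illustrative sketch (simplified, not the exact statements in this class), the per-instance GBDT update when one tree is finished is roughly {@code data.predict += learningRate * treePredict; data.output = -loss.computeGradient(data.predict, data.label);}: the new output is the negative gradient (pseudo-residual) that the next tree is trained to fit.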
* * <p> * For GBDT, when a new tree is transferred to worker, data predict and label values are all updated, and such values * can be recomputed from the trees and the learning rate. * * <p> * For RF, bagging with replacement is enabled by {@link PoissonDistribution}. * * <p> * Weighted training is supported in our RF and GBDT implementations; in this worker, data.significance is the weight * field read from input. If there is no weight column, the value is set to 1. * * <p> * Bin index is stored in each Data object as short to save memory; especially for categorical features, much memory is * saved compared to String. With the short type, the number of categories is limited to Short.MAX_VALUE. * * @author Zhang David (pengzhang@paypal.com) */ @ComputableMonitor(timeUnit = TimeUnit.SECONDS, duration = 300) public class DTWorker extends AbstractWorkerComputable<DTMasterParams, DTWorkerParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> { protected static final Logger LOG = LoggerFactory.getLogger(DTWorker.class); /** * Model configuration loaded from configuration file. */ private ModelConfig modelConfig; /** * Column configuration loaded from configuration file. */ private List<ColumnConfig> columnConfigList; /** * Total number of trees */ private int treeNum; /** * Basic input count for final-select variables or good candidates (if no variables are selected) */ protected int inputCount; /** * Basic categorical input count */ protected int categoricalInputCount; /** * Indicates whether variable selection has been done; if so, selected variables are set to finalSelect = true; if * not, all good candidate variables are used. */ private boolean isAfterVarSelect = true; /** * Input record count, incremented one by one. */ protected long count; /** * Sampled input record count. */ protected long sampleCount; /** * Positive count in training data list, only effective in 0-1 regression or onevsall classification */ protected long positiveTrainCount; /** * Positive count in training data list and being selected in training, only effective in 0-1 regression or * onevsall classification */ protected long positiveSelectedTrainCount; /** * Negative count in training data list, only effective in 0-1 regression or onevsall classification */ protected long negativeTrainCount; /** * Negative count in training data list and being selected, only effective in 0-1 regression or onevsall * classification */ protected long negativeSelectedTrainCount; /** * Positive count in validation data list, only effective in 0-1 regression or onevsall classification */ protected long positiveValidationCount; /** * Negative count in validation data list, only effective in 0-1 regression or onevsall classification */ protected long negativeValidationCount; /** * Training data set kept in memory only, because for GBDT data will be changed in later iterations. */ private volatile MemoryLimitedList<Data> trainingData; /** * Validation data set kept in memory only, because for GBDT data will be changed in later iterations. */ private volatile MemoryLimitedList<Data> validationData; /** * PoissonDistribution which is used for up sampling positive records. */ protected PoissonDistribution upSampleRng = null; /** * Poisson distribution instances for bagging sampling */ private Map<Integer, PoissonDistribution[]> baggingRngMap = new HashMap<Integer, PoissonDistribution[]>(); /** * Construct a bagging random map for different classes. For stratified sampling, this is useful for sampling each * class.
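 * For example (illustrative), in binary classification the map key is the class value, so key 0 keeps the Random for * negative records and key 1 the one for positive records, sampled independently.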
*/ private Map<Integer, Random> baggingRandomMap = new HashMap<Integer, Random>(); /** * Construct a validation random map for different classes. For stratified sampling, this is useful for sampling * each class. */ private Map<Integer, Random> validationRandomMap = new HashMap<Integer, Random>(); /** * Default splitter used to split input record. A single instance is used to avoid creating new instances from * Splitter.on repeatedly. */ protected static final Splitter DEFAULT_SPLITTER = Splitter.on(CommonConstants.DEFAULT_COLUMN_SEPARATOR) .trimResults(); /** * Index map from column index to data input array index for fast location. */ private ConcurrentMap<Integer, Integer> inputIndexMap = new ConcurrentHashMap<Integer, Integer>(); /** * Different {@link Impurity} for node, {@link Entropy} and {@link Gini} are mostly for classification, * {@link Variance} is mostly for regression. */ private Impurity impurity; /** * Whether random forest is running; this is the default algorithm in this worker. */ private boolean isRF = true; /** * Whether gradient boosted decision tree is running; for GBDT, each time a tree is trained, the next tree is * trained on gradient labels from the previous trees. */ private boolean isGBDT = false; /** * Learning rate for GBDT. */ private double learningRate = 0.1d; /** * Different loss strategy for GBDT. */ private Loss loss = null; /** * By default in GBDT, sampling with replacement is disabled, although it sometimes shows good performance together * with GBDT */ private boolean gbdtSampleWithReplacement = false; /** * Trainer id used to tag bagging training job, starting from 0, 1, 2 ... */ private int trainerId = 0; /** * If one vs all method for multiple classification. */ private boolean isOneVsAll = false; /** * Create a thread pool to do gradient computing and test set error computing using multiple threads. */ private ExecutorService threadPool; /** * Worker thread count used for multi-threaded computation of node statistics */ private int workerThreadCount; /** * Indicates if a validation set is specified by users via validationDataPath, instead of random picking */ protected boolean isManualValidation = false; /** * Whether to enable continuous model training based on existing models. */ private boolean isContinuousEnabled; /** * Mapping of (ColumnNum, Map(Category, CategoryIndex)) for categorical features */ private Map<Integer, Map<String, Integer>> columnCategoryIndexMapping; /** * Checkpoint output HDFS file */ private Path checkpointOutput; /** * Trees for fail-over or continuous model training; these are recovered from HDFS and need no backup */ private List<TreeNode> recoverTrees; /** * A flag meaning the current worker is a fail-over task and GBDT predict values need to be recovered. After data is * recovered, this flag should be reset to false */ private boolean isNeedRecoverGBDTPredict = false; /** * If stratified sampling or random sampling */ private boolean isStratifiedSampling = false; /** * If k-fold cross validation */ private boolean isKFoldCV; /** * Drop out rate for gbdt to drop trees in training.
http://xgboost.readthedocs.io/en/latest/tutorials/dart.html */ private double dropOutRate = 0.0; /** * Random object to drop out trees, works with {@link #dropOutRate} */ private Random dropOutRandom = new Random(System.currentTimeMillis() + 5000L); /** * Random object to sample negative records */ private Random sampleNegOnlyRandom = new Random(System.currentTimeMillis() + 1000L); @Override public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException { super.setRecordReader(new GuaguaLineRecordReader(fileSplit)); } protected boolean isUpSampleEnabled() { // only enabled in regression or one-vs-all classification return this.upSampleRng != null && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain() .isOneVsAll())); } @Override public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) { Properties props = context.getProps(); try { SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>(); for(ColumnConfig config: this.columnConfigList) { if(config.isCategorical()) { if(config.getBinCategory() != null) { Map<String, Integer> tmpMap = new HashMap<String, Integer>(); for(int i = 0; i < config.getBinCategory().size(); i++) { List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i)); for(String cval: catVals) { tmpMap.put(cval, i); } } this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap); } } } Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if(kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; } Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight(); if(Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain() .isOneVsAll()))) { // set mean to upSampleWeight - 1 and use sample + 1 to make sure no zero sample value LOG.info("Enable up sampling with weight {}.", upSampleWeight); this.upSampleRng = new PoissonDistribution(upSampleWeight - 1); } this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase( context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING)); this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount(); this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount); // enable shutdown logic context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() { @Override public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) { DTWorker.this.threadPool.shutdownNow(); try { DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } }); this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll(); GridSearch gs = new GridSearch(modelConfig.getTrain().getParams()); Map<String, Object> validParams = this.modelConfig.getTrain().getParams(); if(gs.hasHyperParam()) { validParams = gs.getParams(this.trainerId); LOG.info("Start grid search worker with params:
{}", validParams); } this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString()); double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6")); LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction); double validationRate = this.modelConfig.getValidSetRate(); if(StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) { // fixed 0.6 and 0.4 of max memory for trainingData and validationData this.trainingData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>()); this.validationData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>()); } else { if(Double.compare(validationRate, 0d) != 0) { this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>()); this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>()); } else { this.trainingData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>()); } } int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList); // numerical + categorical = # of all input this.inputCount = inputOutputIndex[0] + inputOutputIndex[1]; // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is // 1, with index of 0,1,2,3 denotes different classes this.isAfterVarSelect = (inputOutputIndex[3] == 1); this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig .getValidationDataSetRawPath())); int numClasses = this.modelConfig.isClassification() ? 
this.modelConfig.getTags().size() : 2; String imStr = validParams.get("Impurity").toString(); int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString()); double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString()); if(imStr.equalsIgnoreCase("entropy")) { impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain); } else if(imStr.equalsIgnoreCase("gini")) { impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain); } else if(imStr.equalsIgnoreCase("friedmanmse")) { impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain); } else { impurity = new Variance(minInstancesPerNode, minInfoGain); } this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm()); this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm()); String lossStr = validParams.get("Loss").toString(); if(lossStr.equalsIgnoreCase("log")) { this.loss = new LogLoss(); } else if(lossStr.equalsIgnoreCase("absolute")) { this.loss = new AbsoluteLoss(); } else if(lossStr.equalsIgnoreCase("halfgradsquared")) { this.loss = new HalfGradSquaredLoss(); } else if(lossStr.equalsIgnoreCase("squared")) { this.loss = new SquaredLoss(); } else { try { this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr)); } catch (ClassNotFoundException e) { LOG.warn("Class not found for {}, using default SquaredLoss", lossStr); this.loss = new SquaredLoss(); } } if(this.isGBDT) { this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString()); Object swrObj = validParams.get("GBTSampleWithReplacement"); if(swrObj != null) { this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString()); } Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE); if(dropoutObj != null) { this.dropOutRate = Double.valueOf(dropoutObj.toString()); } } this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample(); this.checkpointOutput = new Path(context.getProps().getProperty( CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId())); LOG.info( "Worker init params: isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}", isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(), this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate); // for fail over, load existing trees if(!context.isFirstIteration()) { if(this.isGBDT) { // set flag here and recover later in doCompute; this makes sure recovery happens after the load part, which // can load the latest trees in #doCompute isNeedRecoverGBDTPredict = true; } else { // RF: trees are recovered from last master results recoverTrees = context.getLastMasterResult().getTrees(); } } if(context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) { Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); TreeModel existingModel = null; try { existingModel = (TreeModel) CommonUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())); } catch (IOException e) { LOG.error("Error in getting existing model, will ignore and start from scratch", e); } if(existingModel == null) { LOG.warn("No model is found even though continuous model training is enabled."); return; } else {
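// continuous GBDT training: resume from the existing model's trees; per-record predict values are rebuilt later via recoverGBTData in doCompute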
recoverTrees = existingModel.getTrees(); LOG.info("Loading existing {} trees", recoverTrees.size()); } } } /* * (non-Javadoc) * * @see ml.shifu.guagua.worker.AbstractWorkerComputable#doCompute(ml.shifu.guagua.worker.WorkerContext) */ @Override public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> context) { if(context.isFirstIteration()) { return new DTWorkerParams(); } DTMasterParams lastMasterResult = context.getLastMasterResult(); final List<TreeNode> trees = lastMasterResult.getTrees(); final Map<Integer, TreeNode> todoNodes = lastMasterResult.getTodoNodes(); if(todoNodes == null) { return new DTWorkerParams(); } Map<Integer, NodeStats> statistics = initTodoNodeStats(todoNodes); double trainError = 0d, validationError = 0d; double weightedTrainCount = 0d, weightedValidationCount = 0d; // renew random instances when switching to the next tree if(this.isGBDT && !this.gbdtSampleWithReplacement && lastMasterResult.isSwitchToNextTree()) { this.baggingRandomMap = new HashMap<Integer, Random>(); } long start = System.nanoTime(); for(Data data: this.trainingData) { if(this.isRF) { for(TreeNode treeNode: trees) { if(treeNode.getNode().getId() == Node.INVALID_INDEX) { continue; } Node predictNode = predictNodeIndex(treeNode.getNode(), data, true); if(predictNode.getPredict() != null) { // only update when not the initial root node; at that time the tree node has no predict statistics yet float weight = data.subsampleWeights[treeNode.getTreeId()]; if(Float.compare(weight, 0f) == 0) { // oob data, no need to do weighting validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } else { trainError += weight * data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedTrainCount += weight * data.significance; } } } } if(this.isGBDT) { if(this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) { recoverGBTData(context, data.output, data.predict, data, false); trainError += data.significance * loss.computeError(data.predict, data.label); weightedTrainCount += data.significance; } else { if(isNeedRecoverGBDTPredict) { if(this.recoverTrees == null) { this.recoverTrees = recoverCurrentTrees(); } // recover gbdt data for fail over recoverGBTData(context, data.output, data.predict, data, true); } int currTreeIndex = trees.size() - 1; if(lastMasterResult.isSwitchToNextTree()) { if(currTreeIndex >= 1) { Node node = trees.get(currTreeIndex - 1).getNode(); Node predictNode = predictNodeIndex(node, data, false); if(predictNode.getPredict() != null) { double predict = predictNode.getPredict().getPredict(); // first-tree logic: master must keep the first-tree flag even when the second tree's ROOT node is // being sent if(context.getLastMasterResult().isFirstTree()) { data.predict = (float) predict; } else { // random drop boolean drop = (this.dropOutRate > 0.0 && dropOutRandom.nextDouble() < this.dropOutRate); if(!drop) { data.predict += (float) (this.learningRate * predict); } } data.output = -1f * loss.computeGradient(data.predict, data.label); } // if not sampling with replacement in gbdt, renew bagging sample rate in next tree if(!this.gbdtSampleWithReplacement) { Random random = null; int classValue = (int) (data.label + 0.01f); if(this.isStratifiedSampling) { random = baggingRandomMap.get(classValue); if(random == null) { random = new Random(); baggingRandomMap.put(classValue, random); } } else { random = baggingRandomMap.get(0); if(random == null) { random = new Random();
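// non-stratified case: all records share a single Random instance under key 0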
baggingRandomMap.put(0, random); } } if(random.nextDouble() <= modelConfig.getTrain().getBaggingSampleRate()) { data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 1f; } else { data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 0f; } } } } if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) { Node currTree = trees.get(currTreeIndex).getNode(); Node predictNode = predictNodeIndex(currTree, data, true); if(predictNode.getPredict() != null) { trainError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedTrainCount += data.significance; } } else { trainError += data.significance * loss.computeError(data.predict, data.label); weightedTrainCount += data.significance; } } } } LOG.debug("Compute train error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); if(validationData != null) { start = System.nanoTime(); for(Data data: this.validationData) { if(this.isRF) { for(TreeNode treeNode: trees) { if(treeNode.getNode().getId() == Node.INVALID_INDEX) { continue; } Node predictNode = predictNodeIndex(treeNode.getNode(), data, true); if(predictNode.getPredict() != null) { // only update when not the initial root node; at that time the tree node has no predict statistics yet validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } } } if(this.isGBDT) { if(this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) { recoverGBTData(context, data.output, data.predict, data, false); validationError += data.significance * loss.computeError(data.predict, data.label); weightedValidationCount += data.significance; } else { if(isNeedRecoverGBDTPredict) { if(this.recoverTrees == null) { this.recoverTrees = recoverCurrentTrees(); } // recover gbdt data for fail over recoverGBTData(context, data.output, data.predict, data, true); } int currTreeIndex = trees.size() - 1; if(lastMasterResult.isSwitchToNextTree()) { if(currTreeIndex >= 1) { Node node = trees.get(currTreeIndex - 1).getNode(); Node predictNode = predictNodeIndex(node, data, false); if(predictNode.getPredict() != null) { double predict = predictNode.getPredict().getPredict(); if(context.getLastMasterResult().isFirstTree()) { data.predict = (float) predict; } else { data.predict += (float) (this.learningRate * predict); } data.output = -1f * loss.computeGradient(data.predict, data.label); } } } if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) { Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true); if(predictNode.getPredict() != null) { validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } } else { validationError += data.significance * loss.computeError(data.predict, data.label); weightedValidationCount += data.significance; } } } } LOG.debug("Compute val error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); } if(this.isGBDT) { // reset trees to null to save memory this.recoverTrees = null; if(this.isNeedRecoverGBDTPredict) { // no need recover again this.isNeedRecoverGBDTPredict = false; } } start = System.nanoTime(); CompletionService<Map<Integer, NodeStats>> completionService = new ExecutorCompletionService<Map<Integer, NodeStats>>( this.threadPool); int realThreadCount = 0;
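// partition training records into contiguous [low, high] index ranges, one per thread; e.g. (illustrative) 10 records on 3 threads yields [0, 2], [3, 5] and [6, 9], the last thread absorbing the remainder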
LOG.debug("while todo size {}", todoNodes.size()); int realRecords = this.trainingData.size(); int realThreads = this.workerThreadCount > realRecords ? realRecords : this.workerThreadCount; int[] trainLows = new int[realThreads]; int[] trainHighs = new int[realThreads]; int stepCount = realRecords / realThreads; if(realRecords % realThreads != 0) { // move step count to append last gap to avoid last thread worse 2*stepCount-1 stepCount += (realRecords % realThreads) / stepCount; } for(int i = 0; i < realThreads; i++) { trainLows[i] = i * stepCount; if(i != realThreads - 1) { trainHighs[i] = trainLows[i] + stepCount - 1; } else { trainHighs[i] = realRecords - 1; } } for(int i = 0; i < realThreads; i++) { final Map<Integer, TreeNode> localTodoNodes = new HashMap<Integer, TreeNode>(todoNodes); final Map<Integer, NodeStats> localStatistics = initTodoNodeStats(todoNodes); final int startIndex = trainLows[i]; final int endIndex = trainHighs[i]; LOG.info("Thread {} todo size {} stats size {} start index {} end index {}", i, localTodoNodes.size(), localStatistics.size(), startIndex, endIndex); if(localTodoNodes.size() == 0) { continue; } realThreadCount += 1; completionService.submit(new Callable<Map<Integer, NodeStats>>() { @Override public Map<Integer, NodeStats> call() throws Exception { long start = System.nanoTime(); List<Integer> nodeIndexes = new ArrayList<Integer>(trees.size()); for(int j = startIndex; j <= endIndex; j++) { Data data = DTWorker.this.trainingData.get(j); nodeIndexes.clear(); if(DTWorker.this.isRF) { for(TreeNode treeNode: trees) { if(treeNode.getNode().getId() == Node.INVALID_INDEX) { nodeIndexes.add(Node.INVALID_INDEX); } else { Node predictNode = predictNodeIndex(treeNode.getNode(), data, false); nodeIndexes.add(predictNode.getId()); } } } if(DTWorker.this.isGBDT) { int currTreeIndex = trees.size() - 1; Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, false); // update node index nodeIndexes.add(predictNode.getId()); } for(Map.Entry<Integer, TreeNode> entry: localTodoNodes.entrySet()) { // only do statistics on effective data Node todoNode = entry.getValue().getNode(); int treeId = entry.getValue().getTreeId(); int currPredictIndex = 0; if(DTWorker.this.isRF) { currPredictIndex = nodeIndexes.get(entry.getValue().getTreeId()); } if(DTWorker.this.isGBDT) { currPredictIndex = nodeIndexes.get(0); } if(todoNode.getId() == currPredictIndex) { List<Integer> features = entry.getValue().getFeatures(); if(features.isEmpty()) { features = getAllValidFeatures(); } for(Integer columnNum: features) { double[] featuerStatistic = localStatistics.get(entry.getKey()) .getFeatureStatistics().get(columnNum); float weight = data.subsampleWeights[treeId % data.subsampleWeights.length]; if(Float.compare(weight, 0f) != 0) { // only compute weight is not 0 short binIndex = data.inputs[DTWorker.this.inputIndexMap.get(columnNum)]; DTWorker.this.impurity.featureUpdate(featuerStatistic, binIndex, data.output, data.significance, weight); } } } } } LOG.debug("Thread computing stats time is {}ms in thread {}", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start), Thread.currentThread().getName()); return localStatistics; } }); } int rCnt = 0; while(rCnt < realThreadCount) { try { Map<Integer, NodeStats> currNodeStatsmap = completionService.take().get(); if(rCnt == 0) { statistics = currNodeStatsmap; } else { for(Entry<Integer, NodeStats> entry: statistics.entrySet()) { NodeStats resultNodeStats = entry.getValue(); mergeNodeStats(resultNodeStats, 
currNodeStatsmap.get(entry.getKey())); } } } catch (ExecutionException e) { throw new RuntimeException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } rCnt += 1; } LOG.debug("Compute stats time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); LOG.info( "worker count is {}, stats size is {}, weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", count, statistics.size(), weightedTrainCount, weightedValidationCount, trainError, validationError); return new DTWorkerParams(weightedTrainCount, weightedValidationCount, trainError, validationError, statistics); } private void mergeNodeStats(NodeStats resultNodeStats, NodeStats nodeStats) { Map<Integer, double[]> featureStatistics = resultNodeStats.getFeatureStatistics(); for(Entry<Integer, double[]> entry: nodeStats.getFeatureStatistics().entrySet()) { double[] statistics = featureStatistics.get(entry.getKey()); for(int i = 0; i < statistics.length; i++) { statistics[i] += entry.getValue()[i]; } } } private Map<Integer, NodeStats> initTodoNodeStats(Map<Integer, TreeNode> todoNodes) { Map<Integer, NodeStats> statistics = new HashMap<Integer, NodeStats>(todoNodes.size(), 1f); for(Map.Entry<Integer, TreeNode> entry: todoNodes.entrySet()) { List<Integer> features = entry.getValue().getFeatures(); if(features.isEmpty()) { features = getAllValidFeatures(); } Map<Integer, double[]> featureStatistics = new HashMap<Integer, double[]>(features.size(), 1f); for(Integer columnNum: features) { ColumnConfig columnConfig = this.columnConfigList.get(columnNum); if(columnConfig.isNumerical()) { // TODO, how to process null bin int featureStatsSize = columnConfig.getBinBoundary().size() * this.impurity.getStatsSize(); featureStatistics.put(columnNum, new double[featureStatsSize]); } else if(columnConfig.isCategorical()) { // the last one is for invalid value category like ?, *, ...
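// e.g. (illustrative) 4 categories with an impurity whose getStatsSize() is 3 would allocate (4 + 1) * 3 = 15 doubles; the exact per-bin stats layout is defined by each Impurity implementation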
int featureStatsSize = (columnConfig.getBinCategory().size() + 1) * this.impurity.getStatsSize(); featureStatistics.put(columnNum, new double[featureStatsSize]); } } NodeStats nodeStats = new NodeStats(entry.getValue().getTreeId(), entry.getValue().getNode().getId(), featureStatistics); statistics.put(entry.getKey(), nodeStats); } return statistics; } @Override protected void postLoad(WorkerContext<DTMasterParams, DTWorkerParams> context) { // need to switch state for read this.trainingData.switchState(); if(validationData != null) { this.validationData.switchState(); } LOG.info(" - # Records of the Master Data Set: {}.", this.count); LOG.info(" - Bagging Sample Rate: {}.", this.modelConfig.getBaggingSampleRate()); LOG.info(" - Bagging With Replacement: {}.", this.modelConfig.isBaggingWithReplacement()); LOG.info(" - Cross Validation Rate: {}.", this.modelConfig.getValidSetRate()); LOG.info(" - # Records of the Training Set: {}.", this.trainingData.size()); if(modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) { LOG.info(" - # Positive Bagging Selected Records of the Training Set: {}.", this.positiveSelectedTrainCount); LOG.info(" - # Negative Bagging Selected Records of the Training Set: {}.", this.negativeSelectedTrainCount); LOG.info(" - # Positive Raw Records of the Training Set: {}.", this.positiveTrainCount); LOG.info(" - # Negative Raw Records of the Training Set: {}.", this.negativeTrainCount); } if(validationData != null) { LOG.info(" - # Records of the Validation Set: {}.", this.validationData.size()); if(modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) { LOG.info(" - # Positive Records of the Validation Set: {}.", this.positiveValidationCount); LOG.info(" - # Negative Records of the Validation Set: {}.", this.negativeValidationCount); } } } private List<Integer> getAllValidFeatures() { List<Integer> features = new ArrayList<Integer>(); for(ColumnConfig config: columnConfigList) { if(isAfterVarSelect) { if(config.isFinalSelect() && !config.isTarget() && !config.isMeta()) { // only select numerical features with getBinBoundary().size() larger than 1 // or categorical features with getBinCategory().size() larger than 0 if((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) { features.add(config.getColumnNum()); } } } else { if(!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config)) { // only select numerical features with getBinBoundary().size() larger than 1 // or categorical features with getBinCategory().size() larger than 0 if((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) { features.add(config.getColumnNum()); } } } } return features; } /** * 'binBoundary' is an ArrayList in fact, so we can use the get method. ["-Infinity", 1d, 4d, ....] * * @param value * the value to be checked * @param binBoundary * the bin boundary list * @return the index of the bin the value falls into */ public static int getBinIndex(float value, List<Double> binBoundary) { if(binBoundary.size() <= 1) { // a feature with binBoundary.size() <= 1 will not be sent to worker; such a feature is still loaded into // memory, so just return the first bin index to avoid an exception, while actually such a feature isn't used in // GBT/RF.
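// for a normal boundary list like [-Infinity, 1.0, 4.0] (illustrative), the binary search below maps 0.5 to bin 0, 2.0 to bin 1 and 5.0 to bin 2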
return 0; } // the last bin for positive infinity if(value == Float.POSITIVE_INFINITY) { return binBoundary.size() - 1; } // the first bin for negative infinity if(value == Float.NEGATIVE_INFINITY) { return 0; } int low = 0, high = binBoundary.size() - 1; while(low <= high) { int mid = (low + high) >>> 1; double lowThreshold = binBoundary.get(mid); double highThreshold = mid == binBoundary.size() - 1 ? Double.MAX_VALUE : binBoundary.get(mid + 1); if(value >= lowThreshold && value < highThreshold) { return mid; } if(value >= highThreshold) { low = mid + 1; } else { high = mid - 1; } } return -1; } private Node predictNodeIndex(Node node, Data data, boolean isForErr) { Node currNode = node; Split split = currNode.getSplit(); // if it is a leaf if(split == null || (currNode.getLeft() == null && currNode.getRight() == null)) { return currNode; } ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum()); Node nextNode = null; Integer inputIndex = this.inputIndexMap.get(split.getColumnNum()); if(inputIndex == null) { throw new IllegalStateException("InputIndex should not be null: Split is " + split + ", inputIndexMap is " + this.inputIndexMap + ", data is " + data); } short value = 0; if(columnConfig.isNumerical()) { short binIndex = data.inputs[inputIndex]; value = binIndex; double valueToBinLowestValue = columnConfig.getBinBoundary().get(binIndex); if(valueToBinLowestValue < split.getThreshold()) { nextNode = currNode.getLeft(); } else { nextNode = currNode.getRight(); } } else if(columnConfig.isCategorical()) { short indexValue; if(data.inputs[inputIndex] >= 0 && data.inputs[inputIndex] < (short) (columnConfig.getBinCategory().size())) { indexValue = data.inputs[inputIndex]; } else { // for invalid category, set to the last index indexValue = (short) (columnConfig.getBinCategory().size()); } value = indexValue; Set<Short> childCategories = split.getLeftOrRightCategories(); if(split.isLeft()) { if(childCategories.contains(indexValue)) { nextNode = currNode.getLeft(); } else { nextNode = currNode.getRight(); } } else { if(childCategories.contains(indexValue)) { nextNode = currNode.getRight(); } else { nextNode = currNode.getLeft(); } } } if(nextNode == null) { throw new IllegalStateException("NextNode is null, parent id is " + currNode.getId() + "; parent split is " + split + "; left is " + currNode.getLeft() + "; right is " + currNode.getRight() + "; value is " + value); } return predictNodeIndex(nextNode, data, isForErr); } @Override public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<DTMasterParams, DTWorkerParams> context) { this.count += 1; if((this.count) % 5000 == 0) { LOG.info("Read {} records.", this.count); } // hashcode for fixed input split in train and validation long hashcode = 0; short[] inputs = new short[this.inputCount]; float ideal = 0f; float significance = 1f; // use guava Splitter to iterate only once // use NNConstants.NN_DEFAULT_COLUMN_SEPARATOR to replace getModelConfig().getDataSetDelimiter(), super follows // the function in akka mode. int index = 0, inputIndex = 0; for(String input: DEFAULT_SPLITTER.split(currentValue.getWritable().toString())) { if(index == this.columnConfigList.size()) { // do we need to check: if not weighted, directly set to 1f? if the logic is non-weighted at first and then // weighted, how to process???
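// by convention, when a weight column is configured, it arrives as one extra field appended after all columns in columnConfigList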
if(StringUtils.isBlank(modelConfig.getWeightColumnName())) { significance = 1f; break; } // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 1f) significance = input.length() == 0 ? 1f : NumberFormatUtils.getFloat(input, 1f); // if invalid weight, set it to 1f and warn in log if(Float.compare(significance, 0f) < 0) { LOG.warn("Record {} in current worker has weight {}, which is less than 0f and invalid; setting it to 1.", count, significance); significance = 1f; } // the last field is significance, break here break; } else { ColumnConfig columnConfig = this.columnConfigList.get(index); if(columnConfig != null && columnConfig.isTarget()) { ideal = getFloatValue(input); } else { if(!isAfterVarSelect) { // no variable selected; use good candidates that are neither meta nor target if(!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig)) { if(columnConfig.isNumerical()) { float floatValue = getFloatValue(input); // cast is safe as we limit max bin to Short.MAX_VALUE short binIndex = (short) getBinIndex(floatValue, columnConfig.getBinBoundary()); inputs[inputIndex] = binIndex; if(!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) { this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex); } } else if(columnConfig.isCategorical()) { short shortValue = (short) (columnConfig.getBinCategory().size()); if(input.length() == 0) { // empty shortValue = (short) (columnConfig.getBinCategory().size()); } else { Integer categoricalIndex = this.columnCategoryIndexMapping.get( columnConfig.getColumnNum()).get(input); if(categoricalIndex == null) { shortValue = -1; // invalid category; -1 is mapped to the last index below } else { // cast is safe as we limit max bin to Short.MAX_VALUE shortValue = (short) (categoricalIndex.intValue()); } if(shortValue == -1) { // not found shortValue = (short) (columnConfig.getBinCategory().size()); } } inputs[inputIndex] = shortValue; if(!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) { this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex); } } hashcode = hashcode * 31 + input.hashCode(); inputIndex += 1; } } else { // variables are final-selected; meta and target are not included if(columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) { if(columnConfig.isNumerical()) { float floatValue = getFloatValue(input); // cast is safe as we limit max bin to Short.MAX_VALUE short binIndex = (short) getBinIndex(floatValue, columnConfig.getBinBoundary()); inputs[inputIndex] = binIndex; if(!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) { this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex); } } else if(columnConfig.isCategorical()) { // cast is safe as we limit max bin to Short.MAX_VALUE short shortValue = (short) (columnConfig.getBinCategory().size()); if(input.length() == 0) { // empty shortValue = (short) (columnConfig.getBinCategory().size()); } else { Integer categoricalIndex = this.columnCategoryIndexMapping.get( columnConfig.getColumnNum()).get(input); if(categoricalIndex == null) { shortValue = -1; // invalid category; -1 is mapped to the last index below } else { // cast is safe as we limit max bin to Short.MAX_VALUE shortValue = (short) (categoricalIndex.intValue()); } if(shortValue == -1) { // not found shortValue = (short) (columnConfig.getBinCategory().size()); } } inputs[inputIndex] = shortValue; if(!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex); } } hashcode = hashcode * 31 + input.hashCode(); inputIndex += 1; } } } } index += 1; } if(this.isOneVsAll) { // if one vs all, update target value according to index of target ideal = updateOneVsAllTargetValue(ideal); } // sample negative only logic here if(modelConfig.getTrain().getSampleNegOnly()) { if(this.modelConfig.isFixInitialInput()) { // if fixInitialInput, sample out negative records whose hashcode falls in the (1 - sampleRate) range int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId; // here BaggingSampleRate means how much data will be used in training and validation; if it is 0.8, we // should take 1 - 0.8 to compute endHashCode int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue(); if((modelConfig.isRegression() || this.isOneVsAll) // regression or onevsall && (int) (ideal + 0.01d) == 0 // negative record && isInRange(hashcode, startHashCode, endHashCode)) { return; } } else { // if not fixed initial input, for regression or onevsall multiple classification, // sample out negative records randomly if((modelConfig.isRegression() || this.isOneVsAll) // regression or onevsall && (int) (ideal + 0.01d) == 0 // negative record && Double.compare(this.sampleNegOnlyRandom.nextDouble(), this.modelConfig.getBaggingSampleRate()) >= 0) { return; } } } float output = ideal; float predict = ideal; // up sampling logic: just add more weight while the bagging sample rate stays unchanged if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) { // Double.compare(ideal, 1d) == 0 means positive tags; sample + 1 to avoid a sample count of 0 significance = significance * (this.upSampleRng.sample() + 1); } Data data = new Data(inputs, predict, output, output, significance); boolean isValidation = false; if(context.getAttachment() != null && context.getAttachment() instanceof Boolean) { isValidation = (Boolean) context.getAttachment(); } // split into validation and training data set according to validation rate boolean isInTraining = this.addDataPairToDataSet(hashcode, data, isValidation); // do bagging sampling only for training data if(isInTraining) { data.subsampleWeights = sampleWeights(data.label); // for training data, compute the real selected training data according to baggingSampleRate // for gbdt, only the 1st sampling value is used; for rf, the 1st is used here just for counting, no need for all if(isPositive(data.label)) { this.positiveSelectedTrainCount += data.subsampleWeights[0] * 1L; } else { this.negativeSelectedTrainCount += data.subsampleWeights[0] * 1L; } } else { // for validation data, according to bagging sampling logic, we might sample the validation data set as well; // but since validation data is only used to compute validation error, skipping real sampling is ok. } } private float getFloatValue(String input) { // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f) float floatValue = input.length() == 0 ? 0f : NumberFormatUtils.getFloat(input, 0f); // TODO NaN may appear in input data; process it as a missing value according to norm type floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue; return floatValue; } private boolean isPositive(float value) { return Float.compare(1f, value) == 0; } /** * Add to training set or validation set according to validation rate.
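 * Three strategies apply, in order: k-fold CV (records whose hashcode % k equals trainerId go to validation), a * user-specified manual validation set, or a random / fixed-hash split driven by validSetRate.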
* * @param hashcode * the hash code of the data * @param data * data instance * @param isValidation * if it is validation * @return true if the record is added to the training set, false otherwise. */ protected boolean addDataPairToDataSet(long hashcode, Data data, boolean isValidation) { if(this.isKFoldCV) { int k = this.modelConfig.getTrain().getNumKFold(); if(hashcode % k == this.trainerId) { this.validationData.append(data); if(isPositive(data.label)) { this.positiveValidationCount += 1L; } else { this.negativeValidationCount += 1L; } return false; } else { this.trainingData.append(data); if(isPositive(data.label)) { this.positiveTrainCount += 1L; } else { this.negativeTrainCount += 1L; } return true; } } if(this.isManualValidation) { if(isValidation) { this.validationData.append(data); if(isPositive(data.label)) { this.positiveValidationCount += 1L; } else { this.negativeValidationCount += 1L; } return false; } else { this.trainingData.append(data); if(isPositive(data.label)) { this.positiveTrainCount += 1L; } else { this.negativeTrainCount += 1L; } return true; } } else { if(Double.compare(this.modelConfig.getValidSetRate(), 0d) != 0) { int classValue = (int) (data.label + 0.01f); Random random = null; if(this.isStratifiedSampling) { // each class uses one random instance random = validationRandomMap.get(classValue); if(random == null) { random = new Random(); this.validationRandomMap.put(classValue, random); } } else { // all data uses one random instance random = validationRandomMap.get(0); if(random == null) { random = new Random(); this.validationRandomMap.put(0, random); } } if(this.modelConfig.isFixInitialInput()) { // for fixed initial input, if hashcode%100 is in [startHashCode, endHashCode), validation, // otherwise training. The start hashcode differs per job to make sure bagging jobs have // different data. If endHashCode is over 100, then check if hashcode is in [startHashCode, 100) // or [0, endHashCode - 100) int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId; int endHashCode = startHashCode + Double.valueOf(this.modelConfig.getValidSetRate() * 100).intValue(); if(isInRange(hashcode, startHashCode, endHashCode)) { this.validationData.append(data); if(isPositive(data.label)) { this.positiveValidationCount += 1L; } else { this.negativeValidationCount += 1L; } return false; } else { this.trainingData.append(data); if(isPositive(data.label)) { this.positiveTrainCount += 1L; } else { this.negativeTrainCount += 1L; } return true; } } else { // not fixed initial input: if random value >= validRate, training, otherwise validation.
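// e.g. (illustrative) validSetRate 0.2 sends about 80% of records to training and about 20% to validation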
if(random.nextDouble() >= this.modelConfig.getValidSetRate()) { this.trainingData.append(data); if(isPositive(data.label)) { this.positiveTrainCount += 1L; } else { this.negativeTrainCount += 1L; } return true; } else { this.validationData.append(data); if(isPositive(data.label)) { this.positiveValidationCount += 1L; } else { this.negativeValidationCount += 1L; } return false; } } } else { this.trainingData.append(data); if(isPositive(data.label)) { this.positiveTrainCount += 1L; } else { this.negativeTrainCount += 1L; } return true; } } } private boolean isInRange(long hashcode, int startHashCode, int endHashCode) { // check if in [start, end) or, when end is over 100, in [start, 100) or [0, end-100) int hashCodeIn100 = (int) hashcode % 100; if(endHashCode <= 100) { // in range [start, end) return hashCodeIn100 >= startHashCode && hashCodeIn100 < endHashCode; } else { // in range [start, 100) or [0, endHashCode-100) return hashCodeIn100 >= startHashCode || hashCodeIn100 < (endHashCode % 100); } } // isFailoverOrContinuous == true: fail-over task; == false: continuous model training private void recoverGBTData(WorkerContext<DTMasterParams, DTWorkerParams> context, float output, float predict, Data data, boolean isFailoverOrContinuous) { final List<TreeNode> trees = this.recoverTrees; if(trees == null) { return; } if(trees.size() >= 1) { // if isSwitchToNextTree == false, iterate all trees except the current one to get new predict and // output values; if isSwitchToNextTree == true, iterate all trees except the current two trees. // the last tree is a root node; the tree with index size-2 will be handled in the doCompute method // TreeNode lastTree = trees.get(trees.size() - 1); // if fail-over and tree size is over 2, exclude the last tree because it isn't fully built and its predict // value needs no update; for continuous model training, all trees are complete and should all be used for // updating predict int iterLen = isFailoverOrContinuous ?
trees.size() - 1 : trees.size(); for(int i = 0; i < iterLen; i++) { TreeNode currTree = trees.get(i); if(i == 0) { double oldPredict = predictNodeIndex(currTree.getNode(), data, false).getPredict().getPredict(); predict = (float) oldPredict; output = -1f * loss.computeGradient(predict, data.label); } else { // random drop if(this.dropOutRate > 0.0 && dropOutRandom.nextDouble() < this.dropOutRate) { continue; } double oldPredict = predictNodeIndex(currTree.getNode(), data, false).getPredict().getPredict(); predict += (float) (this.learningRate * oldPredict); output = -1f * loss.computeGradient(predict, data.label); } } data.output = output; data.predict = predict; } } private List<TreeNode> recoverCurrentTrees() { FSDataInputStream stream = null; List<TreeNode> trees = null; try { if(!ShifuFileUtils .isFileExists(this.checkpointOutput.toString(), this.modelConfig.getDataSet().getSource())) { return null; } FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()); stream = fs.open(this.checkpointOutput); int treeSize = stream.readInt(); trees = new ArrayList<TreeNode>(treeSize); for(int i = 0; i < treeSize; i++) { TreeNode treeNode = new TreeNode(); treeNode.readFields(stream); trees.add(treeNode); } } catch (IOException e) { throw new GuaguaRuntimeException(e); } finally { org.apache.commons.io.IOUtils.closeQuietly(stream); } return trees; } private float[] sampleWeights(float label) { float[] sampleWeights = null; // for sample-negative-only or kFoldCV, sample rate is 1d double sampleRate = (modelConfig.getTrain().getSampleNegOnly() || this.isKFoldCV) ? 1d : modelConfig.getTrain() .getBaggingSampleRate(); int classValue = (int) (label + 0.01f); if(this.treeNum == 1 || (this.isGBDT && !this.gbdtSampleWithReplacement)) { // if treeNum == 1, or GBDT without replacement, don't use with-replacement sampling; for GBDT, only one tree // is trained at a time sampleWeights = new float[1]; Random random = null; if(this.isStratifiedSampling) { random = baggingRandomMap.get(classValue); if(random == null) { random = new Random(); baggingRandomMap.put(classValue, random); } } else { random = baggingRandomMap.get(0); if(random == null) { random = new Random(); baggingRandomMap.put(0, random); } } if(random.nextDouble() <= sampleRate) { sampleWeights[0] = 1f; } else { sampleWeights[0] = 0f; } } else { // for rf with treeNum > 1, or gbdt with gbdtSampleWithReplacement = true, sample with replacement sampleWeights = new float[this.treeNum]; if(this.isStratifiedSampling) { PoissonDistribution[] rng = this.baggingRngMap.get(classValue); if(rng == null) { rng = new PoissonDistribution[treeNum]; for(int i = 0; i < treeNum; i++) { rng[i] = new PoissonDistribution(sampleRate); } this.baggingRngMap.put(classValue, rng); } for(int i = 0; i < sampleWeights.length; i++) { sampleWeights[i] = rng[i].sample(); } } else { PoissonDistribution[] rng = this.baggingRngMap.get(0); if(rng == null) { rng = new PoissonDistribution[treeNum]; for(int i = 0; i < treeNum; i++) { rng[i] = new PoissonDistribution(sampleRate); } this.baggingRngMap.put(0, rng); } for(int i = 0; i < sampleWeights.length; i++) { sampleWeights[i] = rng[i].sample(); } } } return sampleWeights; } private float updateOneVsAllTargetValue(float ideal) { // if one vs all, set the correlated ideal value according to trainerId, which means in the trainer with id 0, // target 0 is treated as 1 and others as 0. Such target values are set to the index of tags like [0, 1, 2, 3] // compared with ["a", "b", "c", "d"] return Float.compare(ideal, trainerId) == 0 ?
1f : 0f; } static class Data implements Serializable, Bytable { private static final long serialVersionUID = 903201066309036170L; /** * Inputs as bin indexes; the short type is used to compress memory */ short[] inputs; /** * Original label; not changed in GBDT */ float label; /** * Output label; may be changed in GBDT */ volatile float output; volatile float predict; float significance; float[] subsampleWeights = new float[] { 1.0f }; public Data() { this.label = 0; } public Data(short[] inputs, float predict, float output, float label, float significance) { this.inputs = inputs; this.predict = predict; this.output = output; this.label = label; this.significance = significance; } public Data(short[] inputs, float predict, float output, float label, float significance, float[] subsampleWeights) { this.inputs = inputs; this.predict = predict; this.output = output; this.label = label; this.significance = significance; this.subsampleWeights = subsampleWeights; } @Override public void write(DataOutput out) throws IOException { out.writeInt(inputs.length); for(short input: inputs) { out.writeShort(input); } out.writeFloat(output); out.writeFloat(label); out.writeFloat(predict); out.writeFloat(significance); out.writeInt(subsampleWeights.length); for(float sample: subsampleWeights) { out.writeFloat(sample); } } @Override public void readFields(DataInput in) throws IOException { int iLen = in.readInt(); this.inputs = new short[iLen]; for(int i = 0; i < iLen; i++) { this.inputs[i] = in.readShort(); } this.output = in.readFloat(); this.label = in.readFloat(); this.predict = in.readFloat(); this.significance = in.readFloat(); int sLen = in.readInt(); this.subsampleWeights = new float[sLen]; for(int i = 0; i < sLen; i++) { this.subsampleWeights[i] = in.readFloat(); } } /* * (non-Javadoc) * * @see java.lang.Object#toString() */ @Override public String toString() { return "Data [inputs=" + Arrays.toString(inputs) + ", label=" + label + ", output=" + output + ", predict=" + predict + ", significance=" + significance + ", subsampleWeights=" + Arrays.toString(subsampleWeights) + "]"; } } }