/** * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.io.DataOutputStream; import java.io.IOException; import java.util.*; import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicBoolean; import java.util.zip.GZIPOutputStream; import ml.shifu.guagua.master.BasicMasterInterceptor; import ml.shifu.guagua.master.MasterContext; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.core.dtrain.gs.GridSearch; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.collections.CollectionUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link DTOutput} is used to write the model output and error info to file system. */ public class DTOutput extends BasicMasterInterceptor<DTMasterParams, DTWorkerParams> { private static final Logger LOG = LoggerFactory.getLogger(DTOutput.class); /** * Model Config read from HDFS */ private ModelConfig modelConfig; /** * Id for current guagua job, starting from 0, 1, 2 */ private String trainerId; /** * Tmp model folder to save tmp models */ private String tmpModelsFolder; /** * A flag: whether params initialized. */ private AtomicBoolean isInit = new AtomicBoolean(false); /** * Progress output stream which is used to write progress to that HDFS file. Should be closed in * {@link #postApplication(MasterContext)}. */ private FSDataOutputStream progressOutput = null; /** * If for random forest running, this is default for such master. */ private boolean isRF = true; /** * If gradient boost decision tree, for GBDT, each time a tree is trained, next train is trained by gradient label * from previous tree. */ private boolean isGBDT = false; /** * If for grid search, store validation error besides model files. */ private boolean isGsMode; /** * Valid training parameters including grid search */ private Map<String, Object> validParams; /** * ColumnConfig list reference */ private List<ColumnConfig> columnConfigList; /** * input count */ private int inputCount; /** * Number of trees for both RF and GBDT */ private Integer treeNum; /** * If k-fold cross validation */ private boolean isKFoldCV; @Override public void preApplication(MasterContext<DTMasterParams, DTWorkerParams> context) { init(context); } @Override public void postIteration(final MasterContext<DTMasterParams, DTWorkerParams> context) { long start = System.currentTimeMillis(); // save tmp to hdfs according to raw trainer logic final int tmpModelFactor = DTrainUtils.tmpModelFactor(context.getTotalIteration()); if(isRF) { if(context.getCurrentIteration() % (tmpModelFactor * 2) == 0) { Thread tmpModelPersistThread = new Thread(new Runnable() { @Override public void run() { // save model results for continue model training, if current job is failed, then next running // we can start from this point to save time. // another case for master recovery, if master is failed, read such checkpoint model Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelToFileSystem(context.getMasterResult().getTrees(), out); saveTmpModelToHDFS(context.getCurrentIteration(), context.getMasterResult().getTrees()); } }, "saveTmpNNToHDFS thread"); tmpModelPersistThread.setDaemon(true); tmpModelPersistThread.start(); } } else if(isGBDT) { // for gbdt, only store trees are all built well if(this.treeNum >= 10 && context.getMasterResult().isSwitchToNextTree() && (context.getMasterResult().getTmpTrees().size() - 1) % (this.treeNum / 10) == 0) { final List<TreeNode> trees = context.getMasterResult().getTmpTrees(); if(trees.size() > 1) { Thread tmpModelPersistThread = new Thread(new Runnable() { @Override public void run() { List<TreeNode> subTrees = trees.subList(0, trees.size() - 1); // save model results for continue model training, if current job is failed, then next // running we can start from this point to save time. // another case for master recovery, if master is failed, read such checkpoint model Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelToFileSystem(subTrees, out); // last one is newest one with only ROOT node, should be excluded saveTmpModelToHDFS(subTrees.size(), subTrees); } }, "saveTmpNNToHDFS thread"); tmpModelPersistThread.setDaemon(true); tmpModelPersistThread.start(); } } } updateProgressLog(context); LOG.debug("DT output post iteration time is {}ms", (System.currentTimeMillis() - start)); } @SuppressWarnings("deprecation") private void updateProgressLog(final MasterContext<DTMasterParams, DTWorkerParams> context) { int currentIteration = context.getCurrentIteration(); if(context.isFirstIteration()) { // first iteration is used for training preparation return; } double trainError = context.getMasterResult().getTrainError() / context.getMasterResult().getTrainCount(); double validationError = context.getMasterResult().getValidationCount() == 0d ? 0d : context.getMasterResult() .getValidationError() / context.getMasterResult().getValidationCount(); String info = ""; if(this.isGBDT) { int treeSize = 0; if(context.getMasterResult().isSwitchToNextTree() || context.getMasterResult().isHalt()) { treeSize = context.getMasterResult().isSwitchToNextTree() ? (context.getMasterResult().getTmpTrees() .size() - 1) : (context.getMasterResult().getTmpTrees().size()); info = new StringBuilder(200) .append("Trainer ") .append(this.trainerId) .append(" Iteration #") .append(currentIteration - 1) .append(" Training Error: ") .append((Double.isNaN(trainError) || trainError == 0d) ? "N/A" : String.format("%.10f", trainError)).append(" Validation Error: ") .append(validationError == 0d ? "N/A" : String.format("%.10f", validationError)) .append("; Tree ").append(treeSize).append(" is finished. \n").toString(); } else { int nextDepth = context.getMasterResult().getTreeDepth().get(0); info = new StringBuilder(200) .append("Trainer ") .append(this.trainerId) .append(" Iteration #") .append(currentIteration - 1) .append(" Training Error: ") .append((Double.isNaN(trainError) || trainError == 0d) ? "N/A" : String.format("%.10f", trainError)).append(" Validation Error: ") .append(validationError == 0d ? "N/A" : String.format("%.10f", validationError)) .append("; will work on depth ").append(nextDepth).append(". \n").toString(); } } if(this.isRF) { if(trainError != 0d) { List<Integer> treeDepth = context.getMasterResult().getTreeDepth(); if(treeDepth.size() == 0) { info = new StringBuilder(200).append("Trainer ").append(this.trainerId).append(" Iteration #") .append(currentIteration - 1).append(" Training Error: ") .append(trainError == 0d ? "N/A" : String.format("%.10f", trainError)) .append(" Validation Error: ") .append(validationError == 0d ? "N/A" : String.format("%.10f", validationError)) .append("\n").toString(); } else { info = new StringBuilder(200).append("Trainer ").append(this.trainerId).append(" Iteration #") .append(currentIteration - 1).append(" Training Error: ") .append(trainError == 0d ? "N/A" : String.format("%.10f", trainError)) .append(" Validation Error: ") .append(validationError == 0d ? "N/A" : String.format("%.10f", validationError)) .append("; will work on depth ").append(toListString(treeDepth)).append(". \n").toString(); } } } if(info.length() > 0) { try { LOG.debug("Writing progress results to {} {}", context.getCurrentIteration(), info.toString()); this.progressOutput.write(info.getBytes("UTF-8")); this.progressOutput.flush(); this.progressOutput.sync(); } catch (IOException e) { LOG.error("Error in write progress log:", e); } } } /** * Show -1 as N/A which means not work on such iteration */ private String toListString(List<Integer> list) { Iterator<Integer> i = list.iterator(); if(!i.hasNext()) { return "[]"; } StringBuilder sb = new StringBuilder(); sb.append('['); for(;;) { Integer e = i.next(); sb.append(e == null || e == -1 ? "N/A" : e); if(!i.hasNext()) { return sb.append(']').toString(); } sb.append(", "); } } @Override public void postApplication(MasterContext<DTMasterParams, DTWorkerParams> context) { List<TreeNode> trees = context.getMasterResult().getTrees(); if(this.isGBDT) { trees = context.getMasterResult().getTmpTrees(); } if(LOG.isDebugEnabled()) { LOG.debug("final trees", trees.toString()); } Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelToFileSystem(trees, out); if(this.isGsMode || this.isKFoldCV) { Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR)); writeValErrorToFileSystem(context.getMasterResult().getValidationError() / context.getMasterResult().getValidationCount(), valErrOutput); } IOUtils.closeStream(this.progressOutput); } private void writeValErrorToFileSystem(double valError, Path out) { FSDataOutputStream fos = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing valerror to {}", out); fos.write((valError + "").getBytes("UTF-8")); } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } } private void writeModelToFileSystem(List<TreeNode> trees, Path out) { DataOutputStream fos = null; try { fos = new DataOutputStream(new GZIPOutputStream(FileSystem.get(new Configuration()).create(out))); LOG.info("Writing {} trees to {}.", trees.size(), out); // version fos.writeInt(CommonConstants.TREE_FORMAT_VERSION); fos.writeUTF(modelConfig.getAlgorithm()); fos.writeUTF(this.validParams.get("Loss").toString()); fos.writeBoolean(this.modelConfig.isClassification()); fos.writeBoolean(this.modelConfig.getTrain().isOneVsAll()); fos.writeInt(this.inputCount); Map<Integer, String> columnIndexNameMapping = new HashMap<Integer, String>(); Map<Integer, List<String>> columnIndexCategoricalListMapping = new HashMap<Integer, List<String>>(); Map<Integer, Double> numericalMeanMapping = new HashMap<Integer, Double>(); for(ColumnConfig columnConfig: this.columnConfigList) { if(columnConfig.isFinalSelect()) { columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName()); } if(columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) { columnIndexCategoricalListMapping.put(columnConfig.getColumnNum(), columnConfig.getBinCategory()); } if(columnConfig.isNumerical() && columnConfig.getMean() != null) { numericalMeanMapping.put(columnConfig.getColumnNum(), columnConfig.getMean()); } } if(columnIndexNameMapping.size() == 0) { for(ColumnConfig columnConfig: this.columnConfigList) { if(CommonUtils.isGoodCandidate(columnConfig)) { columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName()); } } } // serialize numericalMeanMapping fos.writeInt(numericalMeanMapping.size()); for(Entry<Integer, Double> entry: numericalMeanMapping.entrySet()) { fos.writeInt(entry.getKey()); // for some feature, it is null mean value, it is not selected, just set to 0d to avoid NPE fos.writeDouble(entry.getValue() == null ? 0d : entry.getValue()); } // serialize columnIndexNameMapping fos.writeInt(columnIndexNameMapping.size()); for(Entry<Integer, String> entry: columnIndexNameMapping.entrySet()) { fos.writeInt(entry.getKey()); fos.writeUTF(entry.getValue()); } // serialize columnIndexCategoricalListMapping fos.writeInt(columnIndexCategoricalListMapping.size()); for(Entry<Integer, List<String>> entry: columnIndexCategoricalListMapping.entrySet()) { List<String> categories = entry.getValue(); if(categories != null) { fos.writeInt(entry.getKey()); fos.writeInt(categories.size()); for(String category: categories) { fos.writeUTF(category); } } } Map<Integer, Integer> columnMapping = getColumnMapping(); fos.writeInt(columnMapping.size()); for(Entry<Integer, Integer> entry: columnMapping.entrySet()) { fos.writeInt(entry.getKey()); fos.writeInt(entry.getValue()); } int treeLength = trees.size(); fos.writeInt(treeLength); for(TreeNode treeNode: trees) { treeNode.write(fos); } } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } } private Map<Integer, Integer> getColumnMapping() { Map<Integer, Integer> columnMapping = new HashMap<Integer, Integer>(columnConfigList.size(), 1f); int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(columnConfigList); boolean isAfterVarSelect = inputOutputIndex[3] == 1 ? true : false; int index = 0; for(int i = 0; i < columnConfigList.size(); i++) { ColumnConfig columnConfig = columnConfigList.get(i); if(!isAfterVarSelect) { if(!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig)) { columnMapping.put(columnConfig.getColumnNum(), index); index += 1; } } else { if(columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) { columnMapping.put(columnConfig.getColumnNum(), index); index += 1; } } } return columnMapping; } /** * Save tmp model to HDFS. */ private void saveTmpModelToHDFS(int iteration, List<TreeNode> trees) { Path out = new Path(DTrainUtils.getTmpModelName(this.tmpModelsFolder, this.trainerId, iteration, modelConfig .getTrain().getAlgorithm().toLowerCase())); writeModelToFileSystem(trees, out); } private void init(MasterContext<DTMasterParams, DTWorkerParams> context) { if(isInit.compareAndSet(false, true)) { loadConfigFiles(context.getProps()); this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID); GridSearch gs = new GridSearch(modelConfig.getTrain().getParams()); this.isGsMode = gs.hasHyperParam(); this.validParams = modelConfig.getParams(); if(isGsMode) { this.validParams = gs.getParams(Integer.parseInt(trainerId)); } Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if(kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; } this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER); this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm()); this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm()); int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList); // numerical + categorical = # of all input this.inputCount = inputOutputIndex[0] + inputOutputIndex[1]; try { Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE)); // if the progressLog already exists, that because the master failed, and fail-over // we need to append the log, so that client console can get refreshed. Or console will appear stuck. if(ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) { this.progressOutput = FileSystem.get(new Configuration()).append(progressLog); } else { this.progressOutput = FileSystem.get(new Configuration()).create(progressLog); } } catch (IOException e) { LOG.error("Error in create progress log:", e); } this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());; } } /** * Load all configurations for modelConfig and columnConfigList from source type. Use null check to make sure model * config and column config loaded once. */ private void loadConfigFiles(final Properties props) { try { SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } } }