/** * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.nn; import java.io.IOException; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import ml.shifu.guagua.master.BasicMasterInterceptor; import ml.shifu.guagua.master.MasterContext; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork; import ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork; import ml.shifu.shifu.core.dtrain.gs.GridSearch; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IOUtils; import org.encog.neural.networks.BasicNetwork; import org.encog.persist.EncogDirectoryPersistence; import org.encog.persist.PersistorRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link NNOutput} is used to write the model output to file system. */ public class NNOutput extends BasicMasterInterceptor<NNParams, NNParams> { private static final Logger LOG = LoggerFactory.getLogger(NNOutput.class); private static final double EPSILON = 0.0000001; /** * Model Config read from HDFS */ private ModelConfig modelConfig; /** * Column Config list read from HDFS */ private List<ColumnConfig> columnConfigList; /** * network */ private BasicNetwork network; /** * Trainer id in bagging jobs */ private String trainerId; /** * Save model to tmp hdfs folder */ private String tmpModelsFolder; /** * Whether the training is dry training. */ private boolean isDry; /** * A flag: whether params initialized. */ private AtomicBoolean isInit = new AtomicBoolean(false); /** * The minimum test error during model training */ private double minTestError = Double.MAX_VALUE; /** * The best weights that we meet */ private double[] optimizedWeights = null; /** * Progress output stream which is used to write progress to that HDFS file. Should be closed in * {@link #postApplication(MasterContext)}. */ private FSDataOutputStream progressOutput = null; /** * If for gs mode, write valid error to output */ private GridSearch gridSearch; /** * Valid model parameters which is unified for both normal training and grid search. */ private Map<String, Object> validParams; private boolean isKFoldCV; /** * Cache subset features with feature index for searching */ protected Set<Integer> subFeatures; @Override public void preApplication(MasterContext<NNParams, NNParams> context) { init(context); } @Override public void postIteration(final MasterContext<NNParams, NNParams> context) { if(this.isDry) { // for dry mode, we don't save models files. return; } double currentError = ((modelConfig.getTrain().getValidSetRate() < EPSILON) ? context.getMasterResult() .getTrainError() : context.getMasterResult().getTestError()); // save the weights according the error decreasing if(currentError < this.minTestError) { this.minTestError = currentError; this.optimizedWeights = context.getMasterResult().getWeights(); } // save tmp to hdfs according to raw trainer logic final int tmpModelFactor = DTrainUtils.tmpModelFactor(context.getTotalIteration()); if(context.getCurrentIteration() % tmpModelFactor == 0) { Thread tmpNNThread = new Thread(new Runnable() { @Override public void run() { saveTmpNNToHDFS(context.getCurrentIteration(), context.getMasterResult().getWeights()); // save model results for continue model training, if current job is failed, then next running // we can start from this point to save time. // another case for master recovery, if master is failed, read such checkpoint model Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelWeightsToFileSystem(optimizedWeights, out); } }, "saveTmpNNToHDFS thread"); tmpNNThread.setDaemon(true); tmpNNThread.start(); } updateProgressLog(context); } @SuppressWarnings("deprecation") private void updateProgressLog(final MasterContext<NNParams, NNParams> context) { int currentIteration = context.getCurrentIteration(); if(context.isFirstIteration()) { // first iteration is used for training preparation return; } String progress = new StringBuilder(200).append(" Trainer ").append(this.trainerId).append(" Epoch #") .append(currentIteration - 1).append(" Training Error:") .append(String.format("%.10f", context.getMasterResult().getTrainError())).append(" Validation Error:") .append(String.format("%.10f", context.getMasterResult().getTestError())).append("\n").toString(); try { LOG.debug("Writing progress results to {} {}", context.getCurrentIteration(), progress.toString()); this.progressOutput.write(progress.getBytes("UTF-8")); this.progressOutput.flush(); this.progressOutput.sync(); } catch (IOException e) { LOG.error("Error in write progress log:", e); } } @Override public void postApplication(MasterContext<NNParams, NNParams> context) { IOUtils.closeStream(this.progressOutput); // for dry mode, we don't save models files. if(this.isDry) { return; } if(optimizedWeights != null) { Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelWeightsToFileSystem(optimizedWeights, out); } if(this.gridSearch.hasHyperParam() || this.isKFoldCV) { Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR)); writeValErrorToFileSystem(context.getMasterResult().getTestError(), valErrOutput); } } private void writeValErrorToFileSystem(double valError, Path out) { FSDataOutputStream fos = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing valerror to {}", out); fos.write((valError + "").getBytes("UTF-8")); } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } } /** * Save tmp nn model to HDFS. */ private void saveTmpNNToHDFS(int iteration, double[] weights) { Path out = new Path(DTrainUtils.getTmpModelName(this.tmpModelsFolder, this.trainerId, iteration, modelConfig .getTrain().getAlgorithm().toLowerCase())); writeModelWeightsToFileSystem(weights, out); } private void init(MasterContext<NNParams, NNParams> context) { this.isDry = Boolean.TRUE.toString().equals(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN)); if(this.isDry) { return; } if(isInit.compareAndSet(false, true)) { loadConfigFiles(context.getProps()); this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID); this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER); gridSearch = new GridSearch(modelConfig.getTrain().getParams()); validParams = this.modelConfig.getTrain().getParams(); if(gridSearch.hasHyperParam()) { validParams = gridSearch.getParams(Integer.parseInt(trainerId)); LOG.info("Start grid search in nn output with params: {}", validParams); } Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if(kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; } initNetwork(context); } try { Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE)); // if the progressLog already exists, that because the master failed, and fail-over // we need to append the log, so that client console can get refreshed. Or console will appear stuck. if(ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) { this.progressOutput = FileSystem.get(new Configuration()).append(progressLog); } else { this.progressOutput = FileSystem.get(new Configuration()).create(progressLog); } } catch (IOException e) { LOG.error("Error in create progress log:", e); } } /** * Load all configurations for modelConfig and columnConfigList from source type. Use null check to make sure model * config and column config loaded once. */ private void loadConfigFiles(final Properties props) { try { SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } } @SuppressWarnings("unchecked") private void initNetwork(MasterContext<NNParams, NNParams> context) { int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList); @SuppressWarnings("unused") int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; // if is one vs all classification, outputNodeCount is set to 1 int outputNodeCount = modelConfig.isRegression() ? inputOutputIndex[1] : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : modelConfig.getTags().size()); int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS); List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC); List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES); boolean isAfterVarSelect = inputOutputIndex[0] != 0; // cache all feature list for sampling features List<Integer> allFeatures = CommonUtils.getAllFeatureList(columnConfigList, isAfterVarSelect); String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET); if(StringUtils.isBlank(subsetStr)) { this.subFeatures = new HashSet<Integer>(allFeatures); } else { String[] splits = subsetStr.split(","); this.subFeatures = new HashSet<Integer>(); for(String split: splits) { this.subFeatures.add(Integer.parseInt(split)); } } this.network = DTrainUtils.generateNetwork(this.subFeatures.size(), outputNodeCount, numLayers, actFunc, hiddenNodeList, false); ((BasicFloatNetwork) this.network).setFeatureSet(this.subFeatures); PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork()); } private void writeModelWeightsToFileSystem(double[] weights, Path out) { FSDataOutputStream fos = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing results to {}", out); this.network.getFlat().setWeights(weights); if(out != null) { EncogDirectoryPersistence.saveObject(fos, this.network); } } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } } public ModelConfig getModelConfig() { return modelConfig; } }