/* * Copyright [2013-2014] eBay Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.lr; import java.io.IOException; import java.io.PrintWriter; import java.util.Arrays; import java.util.Properties; import java.util.concurrent.atomic.AtomicBoolean; import ml.shifu.guagua.master.BasicMasterInterceptor; import ml.shifu.guagua.master.MasterContext; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.core.dtrain.gs.GridSearch; //import ml.shifu.shifu.core.dtrain.nn.NNConstants; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.CommonUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link LogisticRegressionOutput} is used to write the final model output to file system. */ public class LogisticRegressionOutput extends BasicMasterInterceptor<LogisticRegressionParams, LogisticRegressionParams> { private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionOutput.class); private static final double EPSILON = 0.0000001; private String trainerId; private String tmpModelsFolder; /** * Model Config read from HDFS */ private ModelConfig modelConfig; /** * Whether the training is dry training. */ private boolean isDry; /** * A flag: whether params initialized. */ private AtomicBoolean isInit = new AtomicBoolean(false); /** * The minimum test error during model training */ private double minTestError = Double.MAX_VALUE; /** * The best weights that we meet */ private double[] optimizedWeights = null; /** * Progress output stream which is used to write progress to that HDFS file. Should be closed in * {@link #postApplication(MasterContext)}. */ private FSDataOutputStream progressOutput = null; /** * If current mode is cross validation */ private boolean isKFoldCV; /** * If current mode is grid search */ private boolean isGsMode; @Override public void preApplication(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { init(context); } @Override public void postIteration(final MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { if(this.isDry) { // for dry mode, we don't save models files. return; } double currentError = ((modelConfig.getTrain().getValidSetRate() < EPSILON) ? context.getMasterResult() .getTrainError() : context.getMasterResult().getTestError()); // save the weights according the error decreasing if(currentError < this.minTestError) { this.minTestError = currentError; this.optimizedWeights = context.getMasterResult().getParameters(); } // save tmp to hdfs according to raw trainer logic if(context.getCurrentIteration() % DTrainUtils.tmpModelFactor(context.getTotalIteration()) == 0) { Thread tmpNNThread = new Thread(new Runnable() { @Override public void run() { saveTmpModelToHDFS(context.getCurrentIteration(), context.getMasterResult().getParameters()); // save model results for continue model training, if current job is failed, then next running // we can start from this point to save time. // another case for master recovery, if master is failed, read such checkpoint model Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelWeightsToFileSystem(optimizedWeights, out); } }, "saveTmpModelToHDFS thread"); tmpNNThread.setDaemon(true); tmpNNThread.start(); } updateProgressLog(context); } @SuppressWarnings("deprecation") private void updateProgressLog(final MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { int currentIteration = context.getCurrentIteration(); if(currentIteration == 1) { // first iteration is used for training preparation return; } String progress = new StringBuilder(200).append(" Trainer ").append(this.trainerId).append(" Epoch #") .append(currentIteration - 1).append(" Training Error:") .append(context.getMasterResult().getTrainError()).append(" Validation Error:") .append(context.getMasterResult().getTestError()).append("\n").toString(); try { LOG.debug("Writing progress results to {} {}", context.getCurrentIteration(), progress.toString()); this.progressOutput.write(progress.getBytes("UTF-8")); this.progressOutput.flush(); this.progressOutput.sync(); } catch (IOException e) { LOG.error("Error in write progress log:", e); } } @Override public void postApplication(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { IOUtils.closeStream(this.progressOutput); // for dry mode, we don't save models files. if(this.isDry) { return; } if(optimizedWeights == null) { optimizedWeights = context.getMasterResult().getParameters(); } Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelWeightsToFileSystem(optimizedWeights, out); if(this.isKFoldCV || this.isGsMode) { Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR)); writeValErrorToFileSystem(context.getMasterResult().getTestError(), valErrOutput); } IOUtils.closeStream(this.progressOutput); } private void writeValErrorToFileSystem(double valError, Path out) { FSDataOutputStream fos = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing valerror to {}", out); fos.write((valError + "").getBytes("UTF-8")); } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } } /** * Save tmp nn model to HDFS. */ private void saveTmpModelToHDFS(int iteration, double[] weights) { Path out = new Path(DTrainUtils.getTmpModelName(this.tmpModelsFolder, this.trainerId, iteration, modelConfig .getTrain().getAlgorithm().toLowerCase())); writeModelWeightsToFileSystem(weights, out); } private void init(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { this.isDry = Boolean.TRUE.toString().equals(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN)); if(this.isDry) { return; } if(isInit.compareAndSet(false, true)) { loadConfigFiles(context.getProps()); this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID); this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER); Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if(kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; } GridSearch gs = new GridSearch(modelConfig.getTrain().getParams()); this.isGsMode = gs.hasHyperParam(); } try { Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE)); // if the progressLog already exists, that because the master failed, and fail-over // we need to append the log, so that client console can get refreshed. Or console will appear stuck. if(ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) { this.progressOutput = FileSystem.get(new Configuration()).append(progressLog); } else { this.progressOutput = FileSystem.get(new Configuration()).create(progressLog); } } catch (IOException e) { LOG.error("Error in create progress log:", e); } } /** * Load all configurations for modelConfig and columnConfigList from source type. Use null check to make sure model * config and column config loaded once. */ private void loadConfigFiles(final Properties props) { try { SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } } private void writeModelWeightsToFileSystem(double[] weights, Path out) { if(weights == null || weights.length <= 0) { return; } FSDataOutputStream fos = null; PrintWriter pw = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing results to {}", out); if(out != null) { pw = new PrintWriter(fos); pw.println(Arrays.toString(weights)); } } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(pw); } } }