/* * Copyright [2013-2014] eBay Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.lr; import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.atomic.AtomicBoolean; import ml.shifu.guagua.GuaguaRuntimeException; import ml.shifu.guagua.master.AbstractMasterComputable; import ml.shifu.guagua.master.MasterContext; import ml.shifu.guagua.util.NumberFormatUtils; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.ConvergeJudger; import ml.shifu.shifu.core.LR; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.core.dtrain.RegulationLevel; import ml.shifu.shifu.core.dtrain.Weight; import ml.shifu.shifu.core.dtrain.gs.GridSearch; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.CommonUtils; import org.apache.hadoop.fs.Path; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link LogisticRegressionMaster} defines logic to update global <a * href=http://en.wikipedia.org/wiki/Logistic_regression >logistic regression</a> model. * * <p> * At first iteration, master builds a random model then send to all workers to start computing. This is to make all * workers use the same model at the starting time. * * <p> * At other iterations, master works: * <ul> * <li>1. Accumulate all gradients from workers.</li> * <li>2. Update global models by using accumulated gradients.</li> * <li>3. Send new global model to workers by returning model parameters.</li> * </ul> * * <p> * L1 and l2 regulations are supported by configuration: RegularizedConstant in model params of ModelConfig.json. */ public class LogisticRegressionMaster extends AbstractMasterComputable<LogisticRegressionParams, LogisticRegressionParams> { private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionMaster.class); /** * Input column number without bias */ private int inputNum; /** * This is the model weights in LR which will be updated each iteration TODO, if master is failed, how to recovery */ private double[] weights; /** * Learning rate configured by user in params */ private double learningRate = 1.0d; /** * Regulation parameter for l1 or l2 */ private double regularizedConstant = 0.0d; /** * To calculate weights according to last weights and accumulated gradients */ private Weight weightCalculator = null; /** * Model configuration loaded from configuration file. */ private ModelConfig modelConfig; /** * Column Config list read from HDFS */ private List<ColumnConfig> columnConfigList; /** * Convergence threshold setting by user in ModelConfig.json. */ private double convergenceThreshold; /** * Convergence judger instance for convergence checking. */ private ConvergeJudger judger = new ConvergeJudger(); /** * Propagation type for lr model setting: Q, B, R, C */ private String propagation = "Q"; /** * Whether some configurations are initialized */ private AtomicBoolean isInitialized = new AtomicBoolean(false); /** * Whether to enable continuous model training based on existing models. */ private boolean isContinuousEnabled = false; /** * Validation tolerance which is for early stop, by default it is 0d which means early stop is not enabled. */ private double validationTolerance = 0d; /** * The best validation error for error computing */ private double bestValidationError = Double.MAX_VALUE; /** * Valid params specially for grid search */ private Map<String, Object> validParams; @Override public void init(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { loadConfigFiles(context.getProps()); int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); GridSearch gs = new GridSearch(modelConfig.getTrain().getParams()); validParams = this.modelConfig.getTrain().getParams(); if(gs.hasHyperParam()) { validParams = gs.getParams(trainerId); LOG.info("Start grid search master with params: {}", validParams); } this.learningRate = Double.valueOf(this.validParams.get(CommonConstants.LR_LEARNING_RATE).toString()); int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList); this.inputNum = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; Object vtObj = this.validParams.get("ValidationTolerance"); if(vtObj != null) { try { validationTolerance = Double.parseDouble(vtObj.toString()); LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance); } catch (NumberFormatException ee) { validationTolerance = 0d; LOG.warn( "Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj); } } else { LOG.info("Validation by tolerance isn't enabled."); } Double threshold = this.modelConfig.getTrain().getConvergenceThreshold(); this.convergenceThreshold = threshold == null ? 0d : threshold.doubleValue(); LOG.info("Convergence threshold in master is :{}", this.convergenceThreshold); Object pObject = validParams.get(CommonConstants.PROPAGATION); this.propagation = pObject == null ? "Q" : (String) pObject; Object rconstant = validParams.get(CommonConstants.LR_REGULARIZED_CONSTANT); this.regularizedConstant = NumberFormatUtils.getDouble(rconstant == null ? "" : rconstant.toString(), 0d); this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase( context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING)); LOG.info("continuousEnabled: {}", this.isContinuousEnabled); // not initialized and not first iteration, should be fault tolerence, recover state in LogisticRegressionMaster if(!context.isFirstIteration()) { LogisticRegressionParams lastMasterResult = context.getMasterResult(); if(lastMasterResult != null && lastMasterResult.getParameters() != null) { // recover state in current master computable and return to workers this.weights = lastMasterResult.getParameters(); } else { // no weights, restarted from the very beginning, this may not happen this.weights = initWeights().getParameters(); } } } private LogisticRegressionParams initModelParams(LR loadModel) { LogisticRegressionParams params = new LogisticRegressionParams(); params.setTrainError(0); params.setTestError(0); // prevent null point this.weights = loadModel.getWeights(); params.setParameters(this.weights); return params; } private LogisticRegressionParams initOrRecoverParams( MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { LOG.info("read from existing model"); LogisticRegressionParams params = null; // read existing model weights try { Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); LR existingModel = (LR) CommonUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())); if(existingModel == null) { params = initWeights(); LOG.info("Starting to train model from scratch."); } else { params = initModelParams(existingModel); LOG.info("Starting to train model from existing model {}.", modelPath); } } catch (IOException e) { throw new GuaguaRuntimeException(e); } return params; } @Override public LogisticRegressionParams doCompute(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { if(isInitialized.compareAndSet(false, true)) { // not initialized and not first iteration, should be fault tolerance, recover state in // LogisticRegressionMaster if(!context.isFirstIteration()) { LogisticRegressionParams lastMasterResult = context.getMasterResult(); if(lastMasterResult != null && lastMasterResult.getParameters() != null) { // recover state in current master computable and return to workers this.weights = lastMasterResult.getParameters(); return lastMasterResult; } else { // no weights, restarted from the very beginning, this may not happen return initWeights(); } } } if(context.isFirstIteration()) { if(this.isContinuousEnabled) { return initOrRecoverParams(context); } else { return initWeights(); } } else { // append bias double[] gradients = new double[this.inputNum + 1]; double trainError = 0.0d, testError = 0d; long trainSize = 0, testSize = 0; for(LogisticRegressionParams param: context.getWorkerResults()) { if(param != null) { for(int i = 0; i < gradients.length; i++) { gradients[i] += param.getParameters()[i]; } trainError += param.getTrainError(); testError += param.getTestError(); trainSize += param.getTrainSize(); testSize += param.getTestSize(); } } if(this.weightCalculator == null) { this.weightCalculator = new Weight(weights.length, trainSize, learningRate, this.propagation, this.regularizedConstant, RegulationLevel.to(this.validParams .get(CommonConstants.REG_LEVEL_KEY)), 0d); } else { this.weightCalculator.setNumTrainSize(trainSize); } double[] oldWeights = Arrays.copyOf(this.weights, this.weights.length); this.weights = this.weightCalculator.calculateWeights(this.weights, gradients); double finalTrainError = trainError / trainSize; double finalTestError = testError / testSize; LOG.info("Iteration {} with train error {}, test error {}", context.getCurrentIteration(), finalTrainError, finalTestError); LogisticRegressionParams lrParams = new LogisticRegressionParams(weights, finalTrainError, finalTestError, trainSize, testSize); boolean vtTriggered = false; // if validationTolerance == 0d, means vt check is not enabled if(validationTolerance > 0d) { double weightSumSquare = 0d; double diffWeightSumSquare = 0d; for(int i = 0; i < weights.length; i++) { weightSumSquare += Math.pow(weights[i], 2); diffWeightSumSquare += Math.pow(weights[i] - oldWeights[i], 2); } if(Math.pow(diffWeightSumSquare, 0.5) < this.validationTolerance * Math.max(Math.pow(weightSumSquare, 0.5), 1d)) { LOG.info("Debug: diffWeightSumSquare {}, weightSumSquare {}, validationTolerance {}", Math.pow(diffWeightSumSquare, 0.5), Math.pow(weightSumSquare, 0.5), validationTolerance); vtTriggered = true; } } if(finalTestError < this.bestValidationError) { this.bestValidationError = finalTestError; } if(judger.judge(finalTrainError + finalTestError / 2, convergenceThreshold) || vtTriggered) { LOG.info("LRMaster compute iteration {} converged !", context.getCurrentIteration()); lrParams.setHalt(true); } else { LOG.debug("LRMaster compute iteration {} not converged yet !", context.getCurrentIteration()); } return lrParams; } } private LogisticRegressionParams initWeights() { weights = new double[this.inputNum + 1]; for(int i = 0; i < weights.length; i++) { weights[i] = nextDouble(-1, 1); } return new LogisticRegressionParams(weights); } private void loadConfigFiles(final Properties props) { try { SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } } public final double nextDouble(final double min, final double max) { final double range = max - min; return (range * Math.random()) + min; } }