/**
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.nn;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.master.AbstractMasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.ConvergeJudger;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.RegulationLevel;
import ml.shifu.shifu.core.dtrain.Weight;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.fs.Path;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link NNMaster} is used to accumulate all workers' NN parameters.
*
* <p>
 * We accumulate all gradients from workers to calculate the new model weights, then send those weights back to the
 * workers. Each worker uses the weights to update its local model and trains for another iteration.
*
* <p>
 * This logic follows the Encog multi-core implementation.
*
* <p>
* Make sure workers and master use the same initialization weights.
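 *
 * <p>
 * Conceptually, the per-iteration update is (a simplified sketch; the exact rule, including propagation type,
 * regularization and dropout handling, lives in {@link Weight}):
 *
 * <pre>
 * gradients = sum of gradients from all workers
 * weights   = weightCalculator.calculateWeights(lastWeights, gradients)
 * </pre>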
*/
public class NNMaster extends AbstractMasterComputable<NNParams, NNParams> {
private static final Logger LOG = LoggerFactory.getLogger(NNMaster.class);
/**
     * Global master NN parameters instance, used to update model weights from the accumulated gradients.
*/
private NNParams globalNNParams = new NNParams();
/**
* Model configuration loaded from configuration file.
*/
private ModelConfig modelConfig;
/**
     * Calculates new weights from the last weights and the accumulated gradients.
*/
private Weight weightCalculator = null;
/**
* Column configuration loaded from configuration file.
*/
private List<ColumnConfig> columnConfigList;
/**
     * Propagation type for the Encog neural network model setting: Q, B, R or C.
*/
private String propagation = "Q";
/**
* Raw learning rate set by model configuration.
*/
private Double rawLearningRate = 0.1d;
/**
     * Actual learning rate used to train the NN model, decayed from {@link #rawLearningRate} iteration by iteration.
*/
private Double learningRate = 0.1d;
/**
     * L1 or L2 regularization constant.
*/
private double regularizedConstant = 0.0d;
/**
     * Learning decay setting to decrease the learning rate iteration by iteration. Common values range from 0 to 0.1.
*/
private double learningDecay = 0d;
/**
* Whether to enable continuous model training based on existing models.
*/
private boolean isContinuousEnabled = false;
/**
* Convergence threshold setting.
*/
private double convergenceThreshold = 0d;
/**
* Convergence judger instance for convergence checking.
*/
private ConvergeJudger judger = new ConvergeJudger();
/**
* Valid params specially for grid search
*/
private Map<String, Object> validParams;
/**
     * Validation tolerance used for early stopping; the default 0d means early stopping is not enabled.
*/
private double validationTolerance = 0d;
/**
     * The best (lowest) validation error seen so far.
*/
private double bestValidationError = Double.MAX_VALUE;
/**
     * Dropout rate in [0, 1]; defaults to 0.
*/
private double dropoutRate = 0d;
/**
     * Cache of all features (by column index) for searching.
*/
private List<Integer> allFeatures;
/**
     * Cache of the subset features (by column index) for searching.
*/
private List<Integer> subFeatures;
/**
     * Whether variables have been final-selected; if not, good candidate variables are selected instead.
*/
private boolean isAfterVarSelect;
@Override
public NNParams doCompute(MasterContext<NNParams, NNParams> context) {
if(context.isFirstIteration()) {
            // In the first iteration, we not only initialize the whole context but also return the weights so that
            // all workers and the master start from the same weights.
NNParams params = null;
if(this.isContinuousEnabled) {
params = initOrRecoverParams(context);
} else {
// first iteration is used to set initial weights
params = initWeights();
LOG.info("Starting to train model from scratch.");
}
// should be set here to make sure master and workers use the same weights
this.globalNNParams.setWeights(params.getWeights());
            // for continuous model training, this could be optimized by returning null and having workers load the
            // model weights from HDFS directly.
return params;
}
if(context.getWorkerResults() == null) {
throw new IllegalArgumentException("workers' results are null.");
}
double totalTestError = 0;
double totalTrainError = 0;
int size = 0;
        // before accumulating, reset gradients and train size
this.globalNNParams.reset();
long totalCount = 0L;
int totalWorkerCount = 0;
for(NNParams nn: context.getWorkerResults()) {
totalTestError += nn.getTestError();
totalTrainError += nn.getTrainError();
this.globalNNParams.accumulateGradients(nn.getGradients());
this.globalNNParams.accumulateTrainSize(nn.getTrainSize());
totalCount += nn.getCount();
// original worker count before combinable
totalWorkerCount += nn.getWrCount();
size++;
}
LOG.debug("ELM gradients debug for 0 gradient {}", this.globalNNParams.getGradients()[0]);
LOG.debug("Total Count is {}. totalWorkerCount is {}", totalCount, totalWorkerCount);
        // worker result size is 0; throw an exception because this shouldn't happen
if(size == 0) {
throw new IllegalArgumentException("workers' results are empty.");
}
        // initialize the weight calculator on the first training iteration
if(this.weightCalculator == null) {
this.learningRate = this.rawLearningRate;
this.weightCalculator = new Weight(this.globalNNParams.getGradients().length,
this.globalNNParams.getTrainSize(), learningRate, propagation, this.regularizedConstant,
RegulationLevel.to(this.validParams.get(CommonConstants.REG_LEVEL_KEY)), this.dropoutRate);
} else {
this.learningRate = this.learningRate * (1.0d - this.learningDecay);
            // alternative without the learning decay parameter: use sqrt(iteration number) to decrease the rate
            // this.learningRate = this.learningRate / Math.sqrt(context.getCurrentIteration() -1);
this.weightCalculator.setLearningRate(this.learningRate);
this.weightCalculator.setNumTrainSize(this.globalNNParams.getTrainSize());
}
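        // snapshot the current weights so the validation-tolerance check below can measure the weight delta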
double[] oldWeights = Arrays.copyOf(this.globalNNParams.getWeights(), this.globalNNParams.getWeights().length);
        // use the last weights and the accumulated gradients to calculate the new weights
double[] weights = this.weightCalculator.calculateWeights(this.globalNNParams.getWeights(),
this.globalNNParams.getGradients());
this.globalNNParams.setWeights(weights);
        // average errors over the original (pre-combine) worker count
double currentTestError = totalTestError / totalWorkerCount;
double currentTrainError = totalTrainError / totalWorkerCount;
boolean vtTriggered = false;
        // validationTolerance == 0d means the tolerance check is not enabled
if(validationTolerance > 0d) {
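            // early stop when ||newWeights - oldWeights||_2 < validationTolerance * max(||newWeights||_2, 1),
            // i.e. when the relative weight change of this iteration becomes negligible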
double weightSumSquare = 0d;
double diffWeightSumSquare = 0d;
for(int i = 0; i < weights.length; i++) {
weightSumSquare += Math.pow(weights[i], 2);
diffWeightSumSquare += Math.pow(weights[i] - oldWeights[i], 2);
}
if(Math.pow(diffWeightSumSquare, 0.5) < this.validationTolerance
* Math.max(Math.pow(weightSumSquare, 0.5), 1d)) {
LOG.info("Debug: diffWeightSumSquare {}, weightSumSquare {}, validationTolerance {}",
Math.pow(diffWeightSumSquare, 0.5), Math.pow(weightSumSquare, 0.5), validationTolerance);
vtTriggered = true;
}
}
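        // track the lowest validation error observed so far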
if(currentTestError < this.bestValidationError) {
this.bestValidationError = currentTestError;
}
LOG.info("NNMaster compute iteration {} ( avg train error {}, avg validation error {} )", new Object[] {
context.getCurrentIteration(), currentTrainError, currentTestError });
NNParams params = new NNParams();
params.setTrainError(currentTrainError);
params.setTestError(currentTestError);
        // prevent null pointer
params.setGradients(new double[0]);
params.setWeights(weights);
LOG.debug("master result {} in iteration {}", params, context.getCurrentIteration());
// Convergence judging part
double avgErr = (currentTrainError + currentTestError) / 2;
LOG.info("NNMaster compute iteration {} average error: {}, threshold: {}", context.getCurrentIteration(),
avgErr, convergenceThreshold);
if(judger.judge(avgErr, convergenceThreshold) || vtTriggered) {
LOG.info("NNMaster compute iteration {} converged !", context.getCurrentIteration());
params.setHalt(true);
} else {
LOG.debug("NNMaster compute iteration {} not converged yet !", context.getCurrentIteration());
}
return params;
}
private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
// read existing model weights
NNParams params = null;
try {
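            // GUAGUA_OUTPUT is the output path under which a previously trained model may have been persisted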
Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
BasicFloatNetwork existingModel = (BasicFloatNetwork) CommonUtils.loadModel(modelConfig, modelPath,
ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
if(existingModel == null) {
params = initWeights();
LOG.info("Starting to train model from scratch.");
} else {
params = initModelParams(existingModel);
LOG.info("Starting to train model from existing model {}.", modelPath);
}
} catch (IOException e) {
throw new GuaguaRuntimeException(e);
}
return params;
}
private NNParams initModelParams(BasicNetwork loadModel) {
NNParams params = new NNParams();
params.setTrainError(0);
params.setTestError(0);
        // prevent null pointer
params.setGradients(new double[0]);
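        // continuous training starts from the flat weight vector of the existing model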
params.setWeights(loadModel.getFlat().getWeights());
return params;
}
@SuppressWarnings({ "unchecked" })
private NNParams initWeights() {
NNParams params = new NNParams();
int[] inputAndOutput = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
@SuppressWarnings("unused")
int inputNodeCount = inputAndOutput[0] == 0 ? inputAndOutput[2] : inputAndOutput[0];
        // for regression and one-vs-all classification the output node count is the candidate output count (1);
        // otherwise it is the number of tags
int outputNodeCount = modelConfig.isRegression() ? inputAndOutput[1]
: (modelConfig.getTrain().isOneVsAll() ? inputAndOutput[1] : modelConfig.getTags().size());
int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);
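        // build a network with the same topology as the workers' networks so that the flat weight vectors align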
BasicNetwork network = DTrainUtils.generateNetwork(this.subFeatures.size(), outputNodeCount, numLayers,
actFunc, hiddenNodeList);
params.setTrainError(0);
params.setTestError(0);
        // prevent null pointer
params.setGradients(new double[0]);
params.setWeights(network.getFlat().getWeights());
return params;
}
@Override
public void init(MasterContext<NNParams, NNParams> context) {
Properties props = context.getProps();
try {
SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE,
SourceType.HDFS.toString()));
this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
sourceType);
this.columnConfigList = CommonUtils.loadColumnConfigList(
props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
} catch (IOException e) {
throw new RuntimeException(e);
}
int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
GridSearch gs = new GridSearch(modelConfig.getTrain().getParams());
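        // when hyper-parameters form a grid, each trainer id maps to one parameter combination of that grid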
validParams = this.modelConfig.getTrain().getParams();
if(gs.hasHyperParam()) {
validParams = gs.getParams(trainerId);
LOG.info("Start grid search master with params: {}", validParams);
}
Object vtObj = validParams.get("ValidationTolerance");
if(vtObj != null) {
try {
validationTolerance = Double.parseDouble(vtObj.toString());
LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
} catch (NumberFormatException ee) {
validationTolerance = 0d;
                LOG.warn(
                        "Validation by tolerance isn't enabled because of a non-numerical ValidationTolerance value: {}.",
                        vtObj);
}
} else {
LOG.info("Validation by tolerance isn't enabled.");
}
Object pObject = validParams.get(CommonConstants.PROPAGATION);
this.propagation = pObject == null ? "Q" : (String) pObject;
this.rawLearningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
Object dropoutRateObj = validParams.get(CommonConstants.DROPOUT_RATE);
if(dropoutRateObj != null) {
this.dropoutRate = Double.valueOf(dropoutRateObj.toString());
}
LOG.info("dropoutRate in master is :{}", this.dropoutRate);
Object learningDecayO = validParams.get("LearningDecay");
if(learningDecayO != null) {
this.learningDecay = Double.valueOf(learningDecayO.toString());
}
LOG.info("learningDecay in master is :{}", learningDecay);
Double threshold = this.modelConfig.getTrain().getConvergenceThreshold();
this.convergenceThreshold = threshold == null ? 0d : threshold.doubleValue();
LOG.info("Convergence threshold in master is :{}", this.convergenceThreshold);
this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(
context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
Object rconstant = validParams.get(CommonConstants.LR_REGULARIZED_CONSTANT);
this.regularizedConstant = NumberFormatUtils.getDouble(rconstant == null ? "" : rconstant.toString(), 0d);
        // check whether variables have been final-selected
int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
this.isAfterVarSelect = (inputOutputIndex[3] == 1);
        // cache the full feature list for feature sampling
this.allFeatures = CommonUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
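        // SHIFU_NN_FEATURE_SUBSET is a comma-separated list of column indexes; blank means all features are used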
String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
if(StringUtils.isBlank(subsetStr)) {
this.subFeatures = this.allFeatures;
} else {
String[] splits = subsetStr.split(",");
this.subFeatures = new ArrayList<Integer>(splits.length);
for(String split: splits) {
this.subFeatures.add(Integer.parseInt(split));
}
}
        // recover master state (globalNNParams) here: when init is called for a non-first iteration (e.g. after
        // master failover), first recover the weights from the last master result kept by Guagua
if(!context.isFirstIteration()) {
NNParams params = context.getMasterResult();
if(params != null && params.getWeights() != null) {
this.globalNNParams.setWeights(params.getWeights());
} else {
// else read from checkpoint
params = initOrRecoverParams(context);
this.globalNNParams.setWeights(params.getWeights());
}
}
}
}