package org.deeplearning4j.nn.conf.layers;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.learning.*;
import org.nd4j.linalg.learning.config.*;
import java.util.HashMap;
import java.util.Map;
/**
* Created by Alex on 22/02/2017.
*/
@Slf4j
public class LayerValidation {

    /**
     * Validate the updater configuration, setting default updater values if necessary.
     * <p>
     * Null-tolerant overload: each boxed "global" hyperparameter that is {@code null} is
     * translated to {@code Double.NaN} (the internal "not set" sentinel) before delegating
     * to the primitive overload.
     *
     * @param layerName        layer name, used only in warning messages
     * @param layer            layer configuration to validate/mutate; must not be null
     * @param learningRate     global learning rate, or null if not set
     * @param momentum         global momentum (NESTEROVS), or null if not set
     * @param momentumSchedule global momentum schedule (NESTEROVS), or null if not set
     * @param adamMeanDecay    global Adam/AdaMax beta1, or null if not set
     * @param adamVarDecay     global Adam/AdaMax beta2, or null if not set
     * @param rho              global AdaDelta rho, or null if not set
     * @param rmsDecay         global RMSProp decay rate, or null if not set
     * @param epsilon          global updater epsilon, or null if not set
     */
    public static void updaterValidation(String layerName, Layer layer, Double learningRate, Double momentum,
                    Map<Integer, Double> momentumSchedule, Double adamMeanDecay, Double adamVarDecay, Double rho,
                    Double rmsDecay, Double epsilon) {
        updaterValidation(layerName, layer, learningRate == null ? Double.NaN : learningRate,
                        momentum == null ? Double.NaN : momentum, momentumSchedule,
                        adamMeanDecay == null ? Double.NaN : adamMeanDecay,
                        adamVarDecay == null ? Double.NaN : adamVarDecay, rho == null ? Double.NaN : rho,
                        rmsDecay == null ? Double.NaN : rmsDecay, epsilon == null ? Double.NaN : epsilon);
    }

    /**
     * Validate the updater configuration, setting default updater values if necessary.
     * <p>
     * Reconciles the legacy per-hyperparameter configuration (momentum, rho, epsilon, etc. set
     * via deprecated builder methods, either on the layer or globally) with the {@link IUpdater}
     * instance that now carries the true values. {@code Double.NaN} means "not set". Layer-level
     * values take precedence over the global (method-argument) values; if neither is set, the
     * IUpdater's own (default or user-configured) value is kept.
     *
     * @param layerName layer name, used only in warning messages
     * @param layer     layer configuration to validate/mutate; must not be null
     */
    public static void updaterValidation(String layerName, Layer layer, double learningRate, double momentum,
                    Map<Integer, Double> momentumSchedule, double adamMeanDecay, double adamVarDecay, double rho,
                    double rmsDecay, double epsilon) {

        //Warn when a hyperparameter is set but the chosen updater will never read it
        if ((!Double.isNaN(momentum) || !Double.isNaN(layer.getMomentum())) && layer.getUpdater() != Updater.NESTEROVS)
            log.warn("Layer \"" + layerName
                            + "\" momentum has been set but will not be applied unless the updater is set to NESTEROVS.");
        if ((momentumSchedule != null || layer.getMomentumSchedule() != null)
                        && layer.getUpdater() != Updater.NESTEROVS)
            log.warn("Layer \"" + layerName
                            + "\" momentum schedule has been set but will not be applied unless the updater is set to NESTEROVS.");
        if ((!Double.isNaN(adamVarDecay) || (!Double.isNaN(layer.getAdamVarDecay())))
                        && layer.getUpdater() != Updater.ADAM)
            log.warn("Layer \"" + layerName
                            + "\" adamVarDecay is set but will not be applied unless the updater is set to Adam.");
        if ((!Double.isNaN(adamMeanDecay) || !Double.isNaN(layer.getAdamMeanDecay()))
                        && layer.getUpdater() != Updater.ADAM)
            log.warn("Layer \"" + layerName
                            + "\" adamMeanDecay is set but will not be applied unless the updater is set to Adam.");
        if ((!Double.isNaN(rho) || !Double.isNaN(layer.getRho())) && layer.getUpdater() != Updater.ADADELTA)
            log.warn("Layer \"" + layerName
                            + "\" rho is set but will not be applied unless the updater is set to ADADELTA.");
        if ((!Double.isNaN(rmsDecay) || (!Double.isNaN(layer.getRmsDecay()))) && layer.getUpdater() != Updater.RMSPROP)
            log.warn("Layer \"" + layerName
                            + "\" rmsdecay is set but will not be applied unless the updater is set to RMSPROP.");

        //Set values from old (deprecated) .epsilon(), .momentum(), etc methods to the built-in updaters
        //Note that there are *layer* versions (available via the layer) and *global* versions (via the method args)
        //The layer versions take precedence over the global versions. If neither are set, we use whatever is set
        // on the IUpdater instance, which may be the default, or may be user-configured
        //Note that default values for all other parameters are set by default in the Sgd/Adam/whatever classes
        //Hence we don't need to set them here
        //Finally: we'll also set the (updater enumeration field to something sane) to avoid updater=SGD,
        // iupdater=Adam() type situations. Though the updater field isn't used, we don't want to confuse users
        IUpdater u = layer.getIUpdater();
        if (!Double.isNaN(layer.getLearningRate())) {
            //Note that for LRs, if user specifies .learningRate(x).updater(Updater.SGD) (for example), we need to set the
            // LR in the Sgd object. We can do this using the schedules method, which also works for custom updaters
            //Local layer LR set
            u.applySchedules(0, layer.getLearningRate());
        } else if (!Double.isNaN(learningRate)) {
            //Global LR set
            u.applySchedules(0, learningRate);
        }

        if (u instanceof Sgd) {
            layer.setUpdater(Updater.SGD);
        } else if (u instanceof Adam) {
            Adam a = (Adam) u;
            if (!Double.isNaN(layer.getEpsilon())) {
                //user has done legacy .epsilon(...) on the layer itself
                a.setEpsilon(layer.getEpsilon());
            } else if (!Double.isNaN(epsilon)) {
                //user has done legacy .epsilon(...) on MultiLayerNetwork or ComputationGraph
                a.setEpsilon(epsilon);
            }
            if (!Double.isNaN(layer.getAdamMeanDecay())) {
                a.setBeta1(layer.getAdamMeanDecay());
            } else if (!Double.isNaN(adamMeanDecay)) {
                a.setBeta1(adamMeanDecay);
            }
            if (!Double.isNaN(layer.getAdamVarDecay())) {
                a.setBeta2(layer.getAdamVarDecay());
            } else if (!Double.isNaN(adamVarDecay)) {
                a.setBeta2(adamVarDecay);
            }
            layer.setUpdater(Updater.ADAM);
        } else if (u instanceof AdaDelta) {
            AdaDelta a = (AdaDelta) u;
            if (!Double.isNaN(layer.getRho())) {
                a.setRho(layer.getRho());
            } else if (!Double.isNaN(rho)) {
                a.setRho(rho);
            }
            if (!Double.isNaN(layer.getEpsilon())) {
                a.setEpsilon(layer.getEpsilon());
            } else if (!Double.isNaN(epsilon)) {
                a.setEpsilon(epsilon);
            }
            layer.setUpdater(Updater.ADADELTA);
        } else if (u instanceof Nesterovs) {
            Nesterovs n = (Nesterovs) u;
            if (!Double.isNaN(layer.getMomentum())) {
                n.setMomentum(layer.getMomentum());
            } else if (!Double.isNaN(momentum)) {
                n.setMomentum(momentum);
            }
            if (layer.getMomentumSchedule() != null && !layer.getMomentumSchedule().isEmpty()) {
                n.setMomentumSchedule(layer.getMomentumSchedule());
            } else if (momentumSchedule != null && !momentumSchedule.isEmpty()) {
                n.setMomentumSchedule(momentumSchedule);
            }
            layer.setUpdater(Updater.NESTEROVS);
        } else if (u instanceof AdaGrad) {
            AdaGrad a = (AdaGrad) u;
            if (!Double.isNaN(layer.getEpsilon())) {
                a.setEpsilon(layer.getEpsilon());
            } else if (!Double.isNaN(epsilon)) {
                a.setEpsilon(epsilon);
            }
            layer.setUpdater(Updater.ADAGRAD);
        } else if (u instanceof RmsProp) {
            RmsProp r = (RmsProp) u;
            if (!Double.isNaN(layer.getEpsilon())) {
                r.setEpsilon(layer.getEpsilon());
            } else if (!Double.isNaN(epsilon)) {
                r.setEpsilon(epsilon);
            }
            if (!Double.isNaN(layer.getRmsDecay())) {
                r.setRmsDecay(layer.getRmsDecay());
            } else if (!Double.isNaN(rmsDecay)) {
                r.setRmsDecay(rmsDecay);
            }
            layer.setUpdater(Updater.RMSPROP);
        } else if (u instanceof AdaMax) {
            AdaMax a = (AdaMax) u;
            if (!Double.isNaN(layer.getEpsilon())) {
                a.setEpsilon(layer.getEpsilon());
            } else if (!Double.isNaN(epsilon)) {
                a.setEpsilon(epsilon);
            }
            if (!Double.isNaN(layer.getAdamMeanDecay())) {
                a.setBeta1(layer.getAdamMeanDecay());
            } else if (!Double.isNaN(adamMeanDecay)) {
                a.setBeta1(adamMeanDecay);
            }
            if (!Double.isNaN(layer.getAdamVarDecay())) {
                a.setBeta2(layer.getAdamVarDecay());
            } else if (!Double.isNaN(adamVarDecay)) {
                a.setBeta2(adamVarDecay);
            }
            layer.setUpdater(Updater.ADAMAX);
        } else if (u instanceof NoOp) {
            layer.setUpdater(Updater.NONE);
        } else {
            //Probably a custom updater
            layer.setUpdater(null);
        }

        //Finally: Let's set the legacy momentum, epsilon, rmsDecay fields on the layer
        //At this point, it's purely cosmetic, to avoid NaNs etc there that might confuse users
        //The *true* values are now in the IUpdater instances
        if (layer.getUpdater() != null) { //May be null with custom updaters etc
            switch (layer.getUpdater()) {
                case NESTEROVS:
                    if (Double.isNaN(momentum) && Double.isNaN(layer.getMomentum())) {
                        layer.setMomentum(Nesterovs.DEFAULT_NESTEROV_MOMENTUM);
                    } else if (Double.isNaN(layer.getMomentum()))
                        layer.setMomentum(momentum);
                    if (momentumSchedule != null && layer.getMomentumSchedule() == null)
                        layer.setMomentumSchedule(momentumSchedule);
                    else if (momentumSchedule == null && layer.getMomentumSchedule() == null)
                        layer.setMomentumSchedule(new HashMap<Integer, Double>());
                    break;
                case ADAM:
                    if (Double.isNaN(adamMeanDecay) && Double.isNaN(layer.getAdamMeanDecay())) {
                        layer.setAdamMeanDecay(Adam.DEFAULT_ADAM_BETA1_MEAN_DECAY);
                    } else if (Double.isNaN(layer.getAdamMeanDecay()))
                        layer.setAdamMeanDecay(adamMeanDecay);
                    if (Double.isNaN(adamVarDecay) && Double.isNaN(layer.getAdamVarDecay())) {
                        layer.setAdamVarDecay(Adam.DEFAULT_ADAM_BETA2_VAR_DECAY);
                    } else if (Double.isNaN(layer.getAdamVarDecay()))
                        layer.setAdamVarDecay(adamVarDecay);
                    if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(Adam.DEFAULT_ADAM_EPSILON);
                    } else if (Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(epsilon);
                    }
                    break;
                case ADADELTA:
                    if (Double.isNaN(rho) && Double.isNaN(layer.getRho())) {
                        layer.setRho(AdaDelta.DEFAULT_ADADELTA_RHO);
                    } else if (Double.isNaN(layer.getRho())) {
                        layer.setRho(rho);
                    }
                    if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(AdaDelta.DEFAULT_ADADELTA_EPSILON);
                    } else if (Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(epsilon);
                    }
                    break;
                case ADAGRAD:
                    if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(AdaGrad.DEFAULT_ADAGRAD_EPSILON);
                    } else if (Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(epsilon);
                    }
                    break;
                case RMSPROP:
                    if (Double.isNaN(rmsDecay) && Double.isNaN(layer.getRmsDecay())) {
                        layer.setRmsDecay(RmsProp.DEFAULT_RMSPROP_RMSDECAY);
                    } else if (Double.isNaN(layer.getRmsDecay()))
                        layer.setRmsDecay(rmsDecay);
                    if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(RmsProp.DEFAULT_RMSPROP_EPSILON);
                    } else if (Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(epsilon);
                    }
                    break;
                case ADAMAX:
                    if (Double.isNaN(adamMeanDecay) && Double.isNaN(layer.getAdamMeanDecay())) {
                        layer.setAdamMeanDecay(AdaMax.DEFAULT_ADAMAX_BETA1_MEAN_DECAY);
                    } else if (Double.isNaN(layer.getAdamMeanDecay()))
                        layer.setAdamMeanDecay(adamMeanDecay);
                    if (Double.isNaN(adamVarDecay) && Double.isNaN(layer.getAdamVarDecay())) {
                        layer.setAdamVarDecay(AdaMax.DEFAULT_ADAMAX_BETA2_VAR_DECAY);
                    } else if (Double.isNaN(layer.getAdamVarDecay()))
                        layer.setAdamVarDecay(adamVarDecay);
                    if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(AdaMax.DEFAULT_ADAMAX_EPSILON);
                    } else if (Double.isNaN(layer.getEpsilon())) {
                        layer.setEpsilon(epsilon);
                    }
                    break; //Fix: was missing; harmless while ADAMAX is the last case, but fragile
            }
        }
    }

    /**
     * Validate the general (regularization, dropout, weight init) configuration, setting
     * defaults if necessary.
     * <p>
     * Null-tolerant overload delegating to the primitive version. NOTE(review): {@code dropOut}
     * maps to 0.0 here while every other null argument maps to NaN ("not set"), so the
     * "dropout rate has not been added" warning in the primitive overload can never fire via
     * this overload - presumably intentional, but worth confirming.
     *
     * @param layerName layer name, used only in warning messages
     * @param layer     layer configuration to validate/mutate; may be null (validation skipped)
     */
    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, Double dropOut, Double l2, Double l2Bias, Double l1, Double l1Bias,
                    Distribution dist) {
        generalValidation(layerName, layer, useRegularization, useDropConnect, dropOut == null ? 0.0 : dropOut,
                        l2 == null ? Double.NaN : l2, l2Bias == null ? Double.NaN : l2Bias,
                        l1 == null ? Double.NaN : l1, l1Bias == null ? Double.NaN : l1Bias, dist);
    }

    /**
     * Validate the general (regularization, dropout, weight init) configuration, setting
     * defaults if necessary. {@code Double.NaN} means "not set"; layer-level values take
     * precedence over the global (method-argument) values.
     *
     * @param layerName layer name, used only in warning messages
     * @param layer     layer configuration to validate/mutate; may be null (validation skipped)
     */
    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, double dropOut, double l2, double l2Bias, double l1, double l1Bias,
                    Distribution dist) {
        //Fix: the regularization warning below already treated layer as possibly null, but layer
        // was dereferenced unconditionally in the two dropConnect checks -> possible NPE. Guard them.
        if (useDropConnect && (Double.isNaN(dropOut) && (layer == null || Double.isNaN(layer.getDropOut()))))
            log.warn("Layer \"" + layerName
                            + "\" dropConnect is set to true but dropout rate has not been added to configuration.");
        if (useDropConnect && layer != null && layer.getDropOut() == 0.0)
            //Fix: message was missing the closing quote after the layer name
            log.warn("Layer \"" + layerName + "\" dropConnect is set to true but dropout rate is set to 0.0");

        if (useRegularization && (Double.isNaN(l1) && layer != null && Double.isNaN(layer.getL1()) && Double.isNaN(l2)
                        && Double.isNaN(layer.getL2()) && Double.isNaN(l2Bias) && Double.isNaN(l1Bias)
                        && (Double.isNaN(dropOut) || dropOut == 0.0)
                        && (Double.isNaN(layer.getDropOut()) || layer.getDropOut() == 0.0)))
            log.warn("Layer \"" + layerName
                            + "\" regularization is set to true but l1, l2 or dropout has not been added to configuration.");

        if (layer != null) {
            if (useRegularization) {
                //Global l1/l2 values apply only where the layer has no value of its own
                if (!Double.isNaN(l1) && Double.isNaN(layer.getL1())) {
                    layer.setL1(l1);
                }
                if (!Double.isNaN(l2) && Double.isNaN(layer.getL2())) {
                    layer.setL2(l2);
                }
                if (!Double.isNaN(l1Bias) && Double.isNaN(layer.getL1Bias())) {
                    layer.setL1Bias(l1Bias);
                }
                if (!Double.isNaN(l2Bias) && Double.isNaN(layer.getL2Bias())) {
                    layer.setL2Bias(l2Bias);
                }
            } else if (!useRegularization && ((!Double.isNaN(l1) && l1 > 0.0)
                            || (!Double.isNaN(layer.getL1()) && layer.getL1() > 0.0) || (!Double.isNaN(l2) && l2 > 0.0)
                            || (!Double.isNaN(layer.getL2()) && layer.getL2() > 0.0)
                            || (!Double.isNaN(l1Bias) && l1Bias > 0.0)
                            || (!Double.isNaN(layer.getL1Bias()) && layer.getL1Bias() > 0.0)
                            || (!Double.isNaN(l2Bias) && l2Bias > 0.0)
                            || (!Double.isNaN(layer.getL2Bias()) && layer.getL2Bias() > 0.0))) {
                log.warn("Layer \"" + layerName
                                + "\" l1 or l2 has been added to configuration but useRegularization is set to false.");
            }

            //Replace remaining NaN sentinels with 0.0 so users don't see NaN in the config
            if (Double.isNaN(l2) && Double.isNaN(layer.getL2())) {
                layer.setL2(0.0);
            }
            if (Double.isNaN(l1) && Double.isNaN(layer.getL1())) {
                layer.setL1(0.0);
            }
            if (Double.isNaN(l2Bias) && Double.isNaN(layer.getL2Bias())) {
                layer.setL2Bias(0.0);
            }
            if (Double.isNaN(l1Bias) && Double.isNaN(layer.getL1Bias())) {
                layer.setL1Bias(0.0);
            }

            if (layer.getWeightInit() == WeightInit.DISTRIBUTION) {
                if (dist != null && layer.getDist() == null)
                    layer.setDist(dist);
                else if (dist == null && layer.getDist() == null) {
                    //Default to N(0,1) when the user requested DISTRIBUTION init but gave none
                    layer.setDist(new NormalDistribution(0, 1));
                    log.warn("Layer \"" + layerName
                                    + "\" distribution is automatically set to normalize distribution with mean 0 and variance 1.");
                }
            } else if ((dist != null || layer.getDist() != null)) {
                log.warn("Layer \"" + layerName
                                + "\" distribution is set but will not be applied unless weight init is set to WeighInit.DISTRIBUTION.");
            }
        }
    }
}