/**
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain;
import java.util.Random;
/**
* {@link Weight} is used to update NN weights according to the configured propagation option. The update logic is
* adapted from Encog.
*
* <p>
* We would like to reuse the code from Encog directly, but unfortunately those methods are private.
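*
* <p>
* A minimal usage sketch; the weight count, training size and parameter values below are illustrative assumptions,
* not values prescribed by this class:
*
* <pre>
* int weightCount = 100;
* double[] weights = new double[weightCount]; // flat weight array of the network
* double[] gradients = new double[weightCount]; // accumulated gradients, one per weight
* Weight updater = new Weight(weightCount, 10000d, 0.1d, DTrainUtils.BACK_PROPAGATION, 0d, RegulationLevel.NONE, 0d);
* weights = updater.calculateWeights(weights, gradients);
* </pre>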
*/
public class Weight {
/**
* The zero tolerance to use.
*/
private static final double ZERO_TOLERANCE = 1e-17;
private double learningRate;
private String algorithm;
// for quick propagation
private double decay = 0.0001d;
private double[] lastDelta = null;
private double[] lastGradient = null;
private double outputEpsilon = 0.35;
private double eps = 0.0;
private double shrink = 0.0;
// for back propagation
private double momentum = 0.5;
// for resilient propagation
private double[] updateValues = null;
private static final double DEFAULT_INITIAL_UPDATE = 0.1;
private static final double DEFAULT_MAX_STEP = 50;
/**
* L1 or L2 regularization parameter.
*/
private double reg;
/**
* Number of training records.
*/
private double numTrainSize;
/**
* Regularization level.
*/
private RegulationLevel rl = RegulationLevel.NONE;
/**
* Dropout rate.
*/
private double dropoutRate = 0d;
/**
* Random generator used for dropout.
*/
private Random random;
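/**
* Constructs a {@link Weight} updater.
*
* @param numWeight
*            number of weights in the flat network
* @param numTrainSize
*            number of training records
* @param rate
*            the learning rate
* @param algorithm
*            propagation algorithm name; one of the algorithm constants defined in {@link DTrainUtils}
* @param reg
*            L1 or L2 regularization parameter
* @param rl
*            regularization level; {@code null} is treated as {@link RegulationLevel#NONE}
* @param dropoutRate
*            probability that a weight update is skipped in each call to {@link #calculateWeights(double[], double[])}
*/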
public Weight(int numWeight, double numTrainSize, double rate, String algorithm, double reg, RegulationLevel rl,
double dropoutRate) {
this.dropoutRate = dropoutRate;
this.random = new Random();
this.lastDelta = new double[numWeight];
this.lastGradient = new double[numWeight];
this.numTrainSize = numTrainSize;
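// derived quick propagation terms: eps scales the linear step, shrink bounds the quadratic step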
this.eps = this.outputEpsilon / numTrainSize;
this.shrink = rate / (1.0 + rate);
this.learningRate = rate;
this.algorithm = algorithm;
this.updateValues = new double[numWeight];
for(int i = 0; i < this.updateValues.length; i++) {
this.updateValues[i] = DEFAULT_INITIAL_UPDATE;
this.lastDelta[i] = 0;
}
this.reg = reg;
if(rl != null) {
this.rl = rl;
}
}
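/**
* Updates all weights in place according to the configured propagation algorithm, applying dropout and L1/L2
* regularization where configured.
*
* @param weights
*            the current weights; updated in place and also returned
* @param gradients
*            the gradient for each weight; must have the same length as {@code weights}
* @return the updated {@code weights} array
*/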
public double[] calculateWeights(double[] weights, double[] gradients) {
for(int i = 0; i < gradients.length; i++) {
if(this.random.nextDouble() < this.dropoutRate) {
// this weight is dropped out for this iteration; skip its update and continue with the next one
continue;
}
switch(this.rl) {
case NONE:
weights[i] += updateWeight(i, weights, gradients);
break;
case L1:
if(Double.compare(this.reg, 0d) == 0) {
weights[i] += updateWeight(i, weights, gradients);
} else {
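// L1: soft-threshold the raw delta by reg / numTrainSize before applying it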
double shrinkValue = this.reg / getNumTrainSize();
double delta = updateWeight(i, weights, gradients);
weights[i] += Math.signum(delta) * Math.max(0.0, Math.abs(delta) - shrinkValue);
}
break;
case L2:
default:
weights[i] += (updateWeight(i, weights, gradients) - this.reg * weights[i] / getNumTrainSize());
break;
}
}
return weights;
}
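/**
* Dispatches to the per-algorithm update and returns the raw (un-regularized) delta for the weight at
* {@code index}; returns 0.0 if the algorithm name is not recognized.
*/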
private double updateWeight(int index, double[] weights, double[] gradients) {
if(this.algorithm.equalsIgnoreCase(DTrainUtils.BACK_PROPAGATION)) {
return updateWeightBP(index, weights, gradients);
} else if(this.algorithm.equalsIgnoreCase(DTrainUtils.QUICK_PROPAGATION)) {
return updateWeightQBP(index, weights, gradients);
} else if(this.algorithm.equalsIgnoreCase(DTrainUtils.MANHATTAN_PROPAGATION)) {
return updateWeightMHP(index, weights, gradients);
} else if(this.algorithm.equalsIgnoreCase(DTrainUtils.SCALEDCONJUGATEGRADIENT)) {
return updateWeightSCG(index, weights, gradients);
} else if(this.algorithm.equalsIgnoreCase(DTrainUtils.RESILIENTPROPAGATION)) {
return updateWeightRLP(index, weights, gradients);
}
return 0.0;
}
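/**
* Back propagation with momentum: {@code delta = gradient * learningRate + lastDelta * momentum}.
*/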
private double updateWeightBP(int index, double[] weights, double[] gradients) {
double delta = (gradients[index] * this.getLearningRate()) + (this.lastDelta[index] * this.momentum);
this.lastDelta[index] = delta;
return delta;
}
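/**
* Quick propagation (Fahlman). Depending on the previous delta {@code d}, the current decayed slope {@code s} and
* the previous slope {@code p}, the step is a linear term {@code -eps * s}, a maximum-size step
* {@code learningRate * d}, or the quadratic estimate {@code d * s / (p - s)}.
*/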
private double updateWeightQBP(int index, double[] weights, double[] gradients) {
final double w = weights[index];
final double d = this.lastDelta[index];
final double s = -gradients[index] + this.decay * w;
final double p = -lastGradient[index];
double nextStep = 0.0;
// The step must always be in direction opposite to the slope.
if(d < 0.0) {
// If last step was negative...
if(s > 0.0) {
// Add in linear term if current slope is still positive.
nextStep -= this.eps * s;
}
// If current slope is close to or larger than prev slope...
if(s >= (this.shrink * p)) {
// Take maximum size negative step.
nextStep += this.getLearningRate() * d;
} else {
// Else, use quadratic estimate.
nextStep += d * s / (p - s);
}
} else if(d > 0.0) {
// If last step was positive...
if(s < 0.0) {
// Add in linear term if current slope is still negative.
nextStep -= this.eps * s;
}
// If current slope is close to or more negative than prev slope...
if(s <= (this.shrink * p)) {
// Take maximum size negative step.
nextStep += this.getLearningRate() * d;
} else {
// Else, use quadratic estimate.
nextStep += d * s / (p - s);
}
} else {
// Last step was zero, so use only linear term.
nextStep -= this.eps * s;
}
// update global data arrays
this.lastDelta[index] = nextStep;
this.lastGradient[index] = gradients[index];
return nextStep;
}
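/**
* Manhattan propagation: ignores the gradient magnitude and steps by the learning rate in the direction of the
* gradient sign; gradients smaller than {@link #ZERO_TOLERANCE} in absolute value produce no update.
*/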
private double updateWeightMHP(int index, double[] weights, double[] gradients) {
if(Math.abs(gradients[index]) < ZERO_TOLERANCE) {
return 0;
} else if(gradients[index] > 0) {
return this.getLearningRate();
} else {
return -this.getLearningRate();
}
}
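/**
* Scaled conjugate gradient is not supported in distributed NN training; this method always throws.
*/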
private double updateWeightSCG(int index, double[] weights, double[] gradients) {
throw new RuntimeException("SCG propagation is not supported in distributed NN computing.");
}
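/**
* Resilient propagation (RPROP): each weight keeps its own step size, which grows by {@code POSITIVE_ETA} while
* the gradient keeps its sign and shrinks by {@code NEGATIVE_ETA} when the sign flips, in which case the previous
* step is also reverted and the stored gradient reset to zero.
*/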
private double updateWeightRLP(int index, double[] weights, double[] gradients) {
// multiply the current and previous gradient and take the sign: we want to see whether the gradient has
// changed its sign
final int change = DTrainUtils.sign(gradients[index] * lastGradient[index]);
double weightChange = 0;
// if the gradient has retained its sign, then we increase the delta so that it will converge faster
if(change > 0) {
double delta = this.updateValues[index] * DTrainUtils.POSITIVE_ETA;
delta = Math.min(delta, DEFAULT_MAX_STEP);
weightChange = DTrainUtils.sign(gradients[index]) * delta;
this.updateValues[index] = delta;
lastGradient[index] = gradients[index];
} else if(change < 0) {
// if change<0, then the sign has changed, and the last delta was too big
double delta = this.updateValues[index] * DTrainUtils.NEGATIVE_ETA;
delta = Math.max(delta, DTrainUtils.DELTA_MIN);
this.updateValues[index] = delta;
weightChange = -this.lastDelta[index];
// set the previous gradient to zero so that there will be no adjustment the next iteration
lastGradient[index] = 0;
} else if(change == 0) {
// if change==0 then there is no change to the delta
final double delta = this.updateValues[index];
weightChange = DTrainUtils.sign(gradients[index]) * delta;
lastGradient[index] = gradients[index];
}
this.lastDelta[index] = weightChange;
// apply the weight change, if any
return weightChange;
}
/**
* @return the learningRate
*/
public double getLearningRate() {
return learningRate;
}
/**
* @param learningRate
* the learningRate to set
*/
public void setLearningRate(double learningRate) {
this.learningRate = learningRate;
}
/**
* @return the numTrainSize
*/
public double getNumTrainSize() {
return numTrainSize;
}
/**
* @param numTrainSize
* the numTrainSize to set
*/
public void setNumTrainSize(double numTrainSize) {
this.numTrainSize = numTrainSize;
this.eps = this.outputEpsilon / numTrainSize;
}
}