package quickml.supervised.parametricModels;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.Instance;
import quickml.supervised.classifier.logisticRegression.GradientDescent;
import java.io.Serializable;
import java.util.*;
/**
 * Mini-batch stochastic gradient descent over a pluggable {@link OptimizableCostFunction}.
 * Created by alexanderhawk on 10/12/15.
 */
public class SGD<T extends Instance> implements GradientDescent<T> {
public static final Logger logger = LoggerFactory.getLogger(SGD.class);
public static final String LEARNING_RATE = "learningRate";
public static final String USE_BOLD_DRIVER = "useBoldDriver";
public static final String MAX_EPOCHS = "maxEpochs";
public static final String MIN_EPOCHS = "minEpochs";
public static final String MINI_BATCH_SIZE = "miniBatchSize";
public static final String COST_CONVERGENCE_THRESHOLD = "costConvergenceThreshold";
public static final String LEARNING_RATE_BOOST_FACTOR = "learningRateBoostFactor";
public static final String LEARNING_RATE_REDUCTION_FACTOR = "learningRateReductionFactor";
public static final String WEIGHT_CONVERGENCE_THRESHOLD = "weightConvergenceThreshold";
public static final String MIN_PREDICTED_PROBABILITY = "minPredictedProbablity";
public static final String OPTIMIZABLE_COST_FUNCTION = "optimizableCostFunction";
//training hyper-params
private int minibatchSize = 1;
private int maxEpochs = 8;
private int minEpochs = 3;
private double weightConvergenceThreshold = 0.001;
private double costConvergenceThreshold = 0.001;
private double learningRate = 10E-5;
private double minPredictedProbablity = 10E-6;
private double learningRateReductionFactor = 0.5;
private double learningRateBoostFactor = 1.07;
private boolean useBoldDriver = false;
private OptimizableCostFunction<T> optimizableCostFunction;
public SGD() {
}
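    /*
     * A minimal usage sketch, assuming a concrete instance type MyInstance and an
     * OptimizableCostFunction<MyInstance> named costFunction (both hypothetical names):
     *
     *   SGD<MyInstance> sgd = new SGD<MyInstance>()
     *           .optimizableCostFunction(costFunction)
     *           .learningRate(1e-4)
     *           .minibatchSize(32)
     *           .maxEpochs(20);
     *   // numRegressors is the number of weights, including the bias term
     *   double[] weights = sgd.minimize(trainingInstances, numRegressors);
     */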
    public void updateBuilderConfig(final Map<String, Serializable> config) {
        // Apply the cost function from the config first (if supplied), so that a freshly constructed
        // SGD does not dereference a null cost function, and so the supplied cost function also
        // receives the config below.
        if (config.containsKey(OPTIMIZABLE_COST_FUNCTION)) {
            optimizableCostFunction((OptimizableCostFunction<T>) config.get(OPTIMIZABLE_COST_FUNCTION));
        }
        if (optimizableCostFunction != null) {
            optimizableCostFunction.updateBuilderConfig(config);
        }
if (config.containsKey(LEARNING_RATE)) {
learningRate((Double) config.get(LEARNING_RATE));
}
if (config.containsKey(USE_BOLD_DRIVER)) {
useBoldDriver((Boolean) config.get(USE_BOLD_DRIVER));
}
if (config.containsKey(MAX_EPOCHS)) {
maxEpochs((Integer) config.get(MAX_EPOCHS));
}
if (config.containsKey(MIN_EPOCHS)) {
minEpochs((Integer) config.get(MIN_EPOCHS));
}
if (config.containsKey(MINI_BATCH_SIZE)) {
minibatchSize((Integer) config.get(MINI_BATCH_SIZE));
}
if (config.containsKey(COST_CONVERGENCE_THRESHOLD)) {
costConvergenceThreshold((Double) config.get(COST_CONVERGENCE_THRESHOLD));
}
if (config.containsKey(LEARNING_RATE_BOOST_FACTOR)) {
learningRateBoostFactor((Double) config.get(LEARNING_RATE_BOOST_FACTOR));
}
if (config.containsKey(LEARNING_RATE_REDUCTION_FACTOR)) {
learningRateReductionFactor((Double) config.get(LEARNING_RATE_REDUCTION_FACTOR));
}
if (config.containsKey(WEIGHT_CONVERGENCE_THRESHOLD)) {
weightConvergenceThreshold((Double) config.get(WEIGHT_CONVERGENCE_THRESHOLD));
}
if (config.containsKey(MIN_PREDICTED_PROBABILITY)) {
minPredictedProbablity((Double) config.get(MIN_PREDICTED_PROBABILITY));
}
}
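    /*
     * A minimal sketch of driving updateBuilderConfig (hypothetical values; only keys present in the
     * map are applied, and value types must match the casts above):
     *
     *   Map<String, Serializable> config = new HashMap<>();
     *   config.put(SGD.LEARNING_RATE, 1e-4);
     *   config.put(SGD.MINI_BATCH_SIZE, 32);
     *   config.put(SGD.USE_BOLD_DRIVER, true);
     *   sgd.updateBuilderConfig(config);
     */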
public OptimizableCostFunction<T> getOptimizableCostFunction() {
return optimizableCostFunction;
}
    public SGD<T> optimizableCostFunction(OptimizableCostFunction<T> optimizableCostFunction) {
        this.optimizableCostFunction = optimizableCostFunction;
        return this;
    }
    public SGD<T> learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }
    public SGD<T> useBoldDriver(boolean useBoldDriver) {
        this.useBoldDriver = useBoldDriver;
        return this;
    }
    public SGD<T> maxEpochs(int maxEpochs) {
        this.maxEpochs = maxEpochs;
        return this;
    }
    public SGD<T> minPredictedProbablity(double minPredictedProbablity) {
        this.minPredictedProbablity = minPredictedProbablity;
        return this;
    }
    public SGD<T> weightConvergenceThreshold(double weightConvergenceThreshold) {
        this.weightConvergenceThreshold = weightConvergenceThreshold;
        return this;
    }
    public SGD<T> learningRateReductionFactor(double learningRateReductionFactor) {
        this.learningRateReductionFactor = learningRateReductionFactor;
        return this;
    }
    public SGD<T> learningRateBoostFactor(double learningRateBoostFactor) {
        this.learningRateBoostFactor = learningRateBoostFactor;
        return this;
    }
    public SGD<T> costConvergenceThreshold(double costConvergenceThreshold) {
        this.costConvergenceThreshold = costConvergenceThreshold;
        return this;
    }
    public SGD<T> minEpochs(int minEpochs) {
        this.minEpochs = minEpochs;
        return this;
    }
    public SGD<T> minibatchSize(int minibatchSize) {
        this.minibatchSize = minibatchSize;
        return this;
    }
    /**
     * Minimizes the configured cost function (e.g. a cross-entropy loss) with mini-batch stochastic
     * gradient descent. {@code numRegressors} is the number of weights, including the bias term.
     * The instance list is shuffled in place between epochs, and minimization may stop early (but
     * never before minEpochs epochs) once both the weights and the cost have converged.
     */
    @Override
    public double[] minimize(final List<T> instances, int numRegressors) {
double[] weights = initializeWeights(numRegressors);
double previousCostFunctionValue = 0;
double costFunctionValue = optimizableCostFunction.computeCost(instances, weights, minPredictedProbablity);
for (int epoch = 0; epoch < maxEpochs; epoch++) {
logCostFunctionValueAtRegularIntervals(previousCostFunctionValue, costFunctionValue, epoch);
double[] weightsAtPreviousEpoch = Arrays.copyOf(weights, weights.length);
for (int miniBatchStartIndex = 0; miniBatchStartIndex < instances.size(); miniBatchStartIndex += minibatchSize) {
final double[] fixedWeights = Arrays.copyOf(weights, weights.length);
final double[] gradient = new double[weights.length];
                int currentMiniBatchSize = getCurrentMiniBatchSize(minibatchSize, instances.size(), miniBatchStartIndex);
                // subList takes an exclusive end index, not a batch length
                List<T> instancesForBatch = instances.subList(miniBatchStartIndex, miniBatchStartIndex + currentMiniBatchSize);
optimizableCostFunction.updateGradient(instancesForBatch, fixedWeights, gradient);
for (int k = 0; k < weights.length; k++) {
weights[k] = weights[k] - gradient[k] * learningRate;
}
}
previousCostFunctionValue = costFunctionValue;
costFunctionValue = optimizableCostFunction.computeCost(instances, weights, minPredictedProbablity);
            if (ceaseMinimization(weights, previousCostFunctionValue, costFunctionValue, epoch, weightsAtPreviousEpoch)) {
logger.info("breaking after {} epochs with cost {}", epoch + 1, costFunctionValue);
break;
}
adjustLearningRateIfNecessary(previousCostFunctionValue, costFunctionValue);
            Collections.shuffle(instances); // reshuffle so the next epoch sees different mini-batches
}
optimizableCostFunction.shutdown();
return weights;
}
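    /**
     * Size of the mini-batch starting at {@code miniBatchStartIndex}: the last batch of an epoch may
     * be smaller than {@code minibatchSize} (e.g. 103 instances with a batch size of 10 leave a final
     * batch of 3 starting at index 100).
     */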
public static int getCurrentMiniBatchSize(int minibatchSize, int totalNumInstances, int miniBatchStartIndex) {
return Math.min(minibatchSize, totalNumInstances - miniBatchStartIndex);
}
    private boolean ceaseMinimization(double[] weights, double previousCostFunctionValue, double costFunctionValue, int epoch, double[] weightsAtPreviousEpoch) {
        return epoch > minEpochs && weightsConverged(weights, weightsAtPreviousEpoch, weightConvergenceThreshold)
                && costsConverged(previousCostFunctionValue, costFunctionValue, costConvergenceThreshold);
}
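    /**
     * Logs the cost roughly ten times over the course of training (every epoch when maxEpochs < 10).
     */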
private void logCostFunctionValueAtRegularIntervals(double previousCostFunctionValue, double costFunctionValue, int i) {
if (maxEpochs < 10 || i % (maxEpochs / 10) == 0) {
logger.info("cost {}, prevCost {}, learning rate {}, before epoch {}", costFunctionValue, previousCostFunctionValue, learningRate, i);
}
}
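    /**
     * "Bold driver" learning-rate schedule: if the last epoch lowered the cost, the learning rate is
     * boosted by learningRateBoostFactor; otherwise it is cut by learningRateReductionFactor.
     * Only applied when useBoldDriver is set.
     */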
private void adjustLearningRateIfNecessary(double previousCost, double currentCost) {
if (useBoldDriver) {
if (previousCost > currentCost) {
learningRate = learningRate * learningRateBoostFactor;
} else {
learningRate = learningRate * learningRateReductionFactor;
}
}
}
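    /**
     * Relative convergence test on the weights: returns true when the l2 norm of the change in the
     * weight vector, divided by the l2 norm of the current weights, falls below the threshold.
     */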
    public static boolean weightsConverged(double[] weights, double[] newWeights, double weightConvergenceThreshold) {
        double sumOfSquaredDifferences = 0;
        double normSquared = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sumOfSquaredDifferences += (weights[i] - newWeights[i]) * (weights[i] - newWeights[i]);
            normSquared += weights[i] * weights[i];
        }
        return Math.sqrt(sumOfSquaredDifferences / normSquared) < weightConvergenceThreshold;
}
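    /**
     * Relative convergence test on the cost: returns true when
     * |presentCost - previousCost| / presentCost falls below the threshold.
     */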
public static boolean costsConverged(double previousCost, double presentCost, double costConvergenceThreshold) {
return Math.abs(presentCost - previousCost) / presentCost < costConvergenceThreshold;
}
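    /**
     * Random initialization of the weight vector: each component is drawn uniformly from
     * [-0.5, 0.5). Features are presumed normalized, so this scale is a reasonable starting point.
     */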
private double[] initializeWeights(int numFeatures) {
        double[] weights = new double[numFeatures]; // features are presumed normalized
        Random random = new Random();
        for (int i = 0; i < numFeatures; i++) {
            weights[i] = random.nextDouble() - 0.5; // a random number in [-0.5, 0.5)
}
return weights;
}
}