package quickml.supervised.classifier.logisticRegression;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import org.javatuples.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static quickml.MathUtils.cappedlogBase2;
import static quickml.MathUtils.sigmoid;
/**
 * Stochastic gradient descent for sparse logistic regression. Minimizes the cross-entropy
 * loss with optional ridge (L2) and lasso (L1) regularization, mini-batching, multi-threaded
 * gradient computation, and an optional bold-driver learning-rate schedule.
 *
 * Created by alexanderhawk on 10/12/15.
 */
public class SparseSGD implements GradientDescent<SparseClassifierInstance> {
private int executorThreadCount = Runtime.getRuntime().availableProcessors();
private ExecutorService executorService;
public static final String RIDGE = "ridge";
public static final String LASSO = "lasso";
public static final Logger logger = LoggerFactory.getLogger(SparseSGD.class);
public static final String LEARNING_RATE = "learningRate";
public static final String USE_BOLD_DRIVER = "useBoldDriver";
public static final String MAX_EPOCHS = "maxEpochs";
public static final String MIN_EPOCHS = "minEpochs";
public static final String MINI_BATCH_SIZE = "miniBatchSize";
public static final String COST_CONVERGENCE_THRESHOLD = "costConvergenceThreshold";
public static final String LEARNING_RATE_BOOST_FACTOR = "learningRateBoostFactor";
public static final String LEARNING_RATE_REDUCTION_FACTOR = "learningRateReductionFactor";
public static final String MAX_GRADIENT_NORM = "maxGradientNorm";
public static final String WEIGHT_CONVERGENCE_THRESHOLD = "weightConvergenceThreshold";
    public static final String MIN_PREDICTED_PROBABILITY = "minPredictedProbability";
public static final String EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER = "expectedFractionOfFeaturesToUpdatePerWorker";
public static final String EXECUTOR_THREAD_COUNT = "executorThreadCount";
    public static final String MIN_INSTANCES_FOR_PARALLELIZATION = "minInstancesForParallelization";
    public static final String SPARSE_PARALLELIZATION = "sparseParallelization";
//model hyper-params
double ridge = 0;
double lasso = 0;
//training hyper-params
private int minibatchSize = 1;
private int maxEpochs = 8;
private int minEpochs = 3;
private double weightConvergenceThreshold = 0.001;
private double costConvergenceThreshold = 0.001;
private double learningRate = 10E-5;
private double maxGradientNorm = Double.MAX_VALUE;
    private double minPredictedProbability = 10E-6;
private double learningRateReductionFactor = 0.5;
private double learningRateBoostFactor = 1.07;
private boolean useBoldDriver = false;
private double expectedFractionOfFeaturesToUpdatePerWorker = 1.0;
    private int minInstancesForParallelization = 100;
private boolean sparseParallelization = true;
public SparseSGD() {
}
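    /*
     * Illustrative usage sketch (hypothetical values; building the list of
     * SparseClassifierInstance objects is assumed to happen elsewhere):
     *
     *   SparseSGD sgd = new SparseSGD()
     *           .learningRate(1e-4)
     *           .minibatchSize(32)
     *           .maxEpochs(20)
     *           .ridgeRegularizationConstant(0.01);
     *   double[] weights = sgd.minimize(instances, numRegressors);
     */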
public void updateBuilderConfig(final Map<String, Serializable> config) {
        if (config.containsKey(LASSO)) {
            lassoRegularizationConstant((Double) config.get(LASSO));
        }
        if (config.containsKey(RIDGE)) {
            ridgeRegularizationConstant((Double) config.get(RIDGE));
        }
if (config.containsKey(EXECUTOR_THREAD_COUNT)) {
executorThreadCount((Integer) config.get(EXECUTOR_THREAD_COUNT));
}
if (config.containsKey(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER)) {
expectedFractionOfFeaturesToUpdatePerWorker((Double) config.get(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER));
}
if (config.containsKey(LEARNING_RATE)) {
learningRate((Double) config.get(LEARNING_RATE));
}
if (config.containsKey(USE_BOLD_DRIVER)) {
useBoldDriver((Boolean) config.get(USE_BOLD_DRIVER));
}
if (config.containsKey(MAX_EPOCHS)) {
maxEpochs((Integer) config.get(MAX_EPOCHS));
}
if (config.containsKey(MIN_EPOCHS)) {
minEpochs((Integer) config.get(MIN_EPOCHS));
}
if (config.containsKey(MINI_BATCH_SIZE)) {
minibatchSize((Integer) config.get(MINI_BATCH_SIZE));
}
if (config.containsKey(COST_CONVERGENCE_THRESHOLD)) {
costConvergenceThreshold((Double) config.get(COST_CONVERGENCE_THRESHOLD));
}
if (config.containsKey(LEARNING_RATE_BOOST_FACTOR)) {
learningRateBoostFactor((Double) config.get(LEARNING_RATE_BOOST_FACTOR));
}
if (config.containsKey(LEARNING_RATE_REDUCTION_FACTOR)) {
learningRateReductionFactor((Double) config.get(LEARNING_RATE_REDUCTION_FACTOR));
}
if (config.containsKey(MAX_GRADIENT_NORM)) {
maxGradientNorm((Double) config.get(MAX_GRADIENT_NORM));
}
if (config.containsKey(WEIGHT_CONVERGENCE_THRESHOLD)) {
weightConvergenceThreshold((Double) config.get(WEIGHT_CONVERGENCE_THRESHOLD));
}
        if (config.containsKey(MIN_PREDICTED_PROBABILITY)) {
            minPredictedProbability((Double) config.get(MIN_PREDICTED_PROBABILITY));
        }
        if (config.containsKey(MIN_INSTANCES_FOR_PARALLELIZATION)) {
            minInstancesForParallelization((Integer) config.get(MIN_INSTANCES_FOR_PARALLELIZATION));
        }
        if (config.containsKey(SPARSE_PARALLELIZATION)) {
            sparseParallelization((Boolean) config.get(SPARSE_PARALLELIZATION));
}
}
public SparseSGD sparseParallelization(boolean sparseParallelization) {
this.sparseParallelization = sparseParallelization;
return this;
}
public SparseSGD executorThreadCount(int executorThreadCount) {
        if (executorThreadCount <= this.executorThreadCount) {
            this.executorThreadCount = executorThreadCount;
        } else {
            logger.warn("cannot use more executor threads than available cores; keeping {}", this.executorThreadCount);
        }
return this;
}
    public SparseSGD minInstancesForParallelization(int minInstancesForParallelization) {
        this.minInstancesForParallelization = minInstancesForParallelization;
return this;
}
public SparseSGD expectedFractionOfFeaturesToUpdatePerWorker(double expectedFractionOfFeaturesToUpdatePerWorker) {
this.expectedFractionOfFeaturesToUpdatePerWorker = expectedFractionOfFeaturesToUpdatePerWorker;
return this;
}
public SparseSGD learningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}
public SparseSGD useBoldDriver(boolean useBoldDriver) {
this.useBoldDriver = useBoldDriver;
return this;
}
public SparseSGD maxEpochs(int maxEpochs) {
this.maxEpochs = maxEpochs;
return this;
}
    public SparseSGD minPredictedProbability(double minPredictedProbability) {
        this.minPredictedProbability = minPredictedProbability;
return this;
}
public SparseSGD weightConvergenceThreshold(double weightConvergenceThreshold) {
this.weightConvergenceThreshold = weightConvergenceThreshold;
return this;
}
public SparseSGD maxGradientNorm(double maxGradientNorm) {
this.maxGradientNorm = maxGradientNorm;
return this;
}
public SparseSGD learningRateReductionFactor(double learningRateReductionFactor) {
this.learningRateReductionFactor = learningRateReductionFactor;
return this;
}
public SparseSGD learningRateBoostFactor(double learningRateBoostFactor) {
this.learningRateBoostFactor = learningRateBoostFactor;
return this;
}
public SparseSGD costConvergenceThreshold(double costConvergenceThreshold) {
this.costConvergenceThreshold = costConvergenceThreshold;
return this;
}
public SparseSGD minEpochs(int minEpochs) {
this.minEpochs = minEpochs;
return this;
}
public SparseSGD minibatchSize(int minibatchSize) {
this.minibatchSize = minibatchSize;
return this;
}
public SparseSGD ridgeRegularizationConstant(final double ridgeRegularizationConstant) {
this.ridge = ridgeRegularizationConstant;
return this;
}
    public SparseSGD lassoRegularizationConstant(final double lassoRegularizationConstant) {
        this.lasso = lassoRegularizationConstant;
return this;
}
    /**
     * Minimizes the cross-entropy loss function with mini-batch stochastic gradient descent.
     * numRegressors includes the bias term.
     */
    @Override
    public double[] minimize(final List<SparseClassifierInstance> sparseClassifierInstances, int numRegressors) {
        executorService = Executors.newFixedThreadPool(executorThreadCount);
        double[] weights = initializeWeights(numRegressors);
        double previousCostFunctionValue = 0;
        double costFunctionValue = computeCrossEntropyCostFunction(sparseClassifierInstances, weights, minPredictedProbability, ridge, lasso);
for (int epoch = 0; epoch < maxEpochs; epoch++) {
logCostFunctionValueAtRegularIntervals(previousCostFunctionValue, costFunctionValue, epoch);
double[] weightsAtPreviousEpoch = Arrays.copyOf(weights, weights.length);
for (int miniBatchStartIndex = 0; miniBatchStartIndex < sparseClassifierInstances.size(); miniBatchStartIndex += minibatchSize) {
final double[] fixedWeights = Arrays.copyOf(weights, weights.length);
final double[] gradient = new double[weights.length];
int currentMiniBatchSize = getCurrentMiniBatchSize(minibatchSize, sparseClassifierInstances.size(), miniBatchStartIndex);
                final int[] threadStartAndStopIndices = getThreadStartIndices(miniBatchStartIndex, currentMiniBatchSize, executorThreadCount, minInstancesForParallelization);
int actualNumThreads = threadStartAndStopIndices.length - 1;
if (sparseParallelization) {
                    sparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);
} else {
nonSparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);
}
addRegularizationComponentOfTheGradient(weights, gradient, ridge, lasso);
normalizeTheGradient(currentMiniBatchSize, maxGradientNorm, gradient);
for (int k = 0; k < weights.length; k++) {
weights[k] = weights[k] - gradient[k] * learningRate;
}
}
previousCostFunctionValue = costFunctionValue;
            costFunctionValue = computeCrossEntropyCostFunction(sparseClassifierInstances, weights, minPredictedProbability, ridge, lasso);
            if (ceaseMinimization(weights, previousCostFunctionValue, costFunctionValue, epoch, weightsAtPreviousEpoch)) {
logger.info("breaking after {} epochs with cost {}", epoch + 1, costFunctionValue);
break;
}
adjustLearningRateIfNecessary(previousCostFunctionValue, costFunctionValue);
Collections.shuffle(sparseClassifierInstances);
}
executorService.shutdown();
return weights;
}
private void sparseCalculationOfGradient(final List<? extends SparseClassifierInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
List<Future<Int2DoubleOpenHashMap>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
for (int i = 0; i < actualNumThreads; i++) {
final int index = i;
contributionsToTheGradient.add(executorService.submit(new Callable<Int2DoubleOpenHashMap>() {
@Override
public Int2DoubleOpenHashMap call() throws Exception {
try {
Int2DoubleOpenHashMap sparseWorkerContributionToTheGradient = getSparseWorkerContributionToTheGradient(sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]), fixedWeights, expectedFractionOfFeaturesToUpdatePerWorker);
return sparseWorkerContributionToTheGradient;
                    } catch (IllegalArgumentException e) {
                        logger.error("failed to compute a worker's contribution to the gradient", e);
                        throw new RuntimeException(e);
}
}
}));
}
sparseReductionToTheGradient(gradient, contributionsToTheGradient);
}
private void nonSparseCalculationOfGradient(final List<? extends SparseClassifierInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
List<Future<double[]>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
for (int i = 0; i < actualNumThreads; i++) {
final int index = i;
contributionsToTheGradient.add(executorService.submit(new Callable<double[]>() {
@Override
public double[] call() throws Exception {
return getWorkerContributionToTheGradient(sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]), fixedWeights);
}
}));
}
reductionToTheGradient(gradient, contributionsToTheGradient);
}
public static int getCurrentMiniBatchSize(int minibatchSize, int totalNumInstances, int miniBatchStartIndex) {
return Math.min(minibatchSize, totalNumInstances - miniBatchStartIndex);
}
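    /*
     * Partitions the mini-batch [miniBatchStartIndex, miniBatchStartIndex + currentMiniBatchSize)
     * among worker threads, returning the boundary indices: entry i is worker i's start index and
     * the final entry is the exclusive stop index of the whole batch. Batches smaller than
     * minInstancesForParallelization get a single partition; otherwise the remainder of
     * currentMiniBatchSize / executorThreadCount is spread one extra instance at a time over the
     * last workers, e.g. 10 instances over 4 threads yields partitions of sizes 2, 2, 3, 3.
     */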
    public static int[] getThreadStartIndices(int miniBatchStartIndex, int currentMiniBatchSize, int executorThreadCount, int minInstancesForParallelization) {
        int actualNumThreads = executorThreadCount;
        if (currentMiniBatchSize < minInstancesForParallelization) {
int[] threadStartIndices = new int[2];
threadStartIndices[0] = miniBatchStartIndex;
threadStartIndices[1] = miniBatchStartIndex + currentMiniBatchSize;
return threadStartIndices;
} else if (currentMiniBatchSize <= executorThreadCount) {
actualNumThreads = currentMiniBatchSize;
int[] threadStartIndices = new int[actualNumThreads+1];
for (int i = 0; i < actualNumThreads; i++) {
threadStartIndices[i] = miniBatchStartIndex + i;
}
            threadStartIndices[actualNumThreads] = miniBatchStartIndex + actualNumThreads; // could be set inside the loop, but we follow the convention of assigning the final stop index outside it
return threadStartIndices;
}
int[] threadStartIndices = new int[executorThreadCount + 1];
int lowerSamplesPerThread = currentMiniBatchSize / executorThreadCount;
int upperSamplesPerThread = currentMiniBatchSize / executorThreadCount + 1;
int remainder = currentMiniBatchSize % executorThreadCount;
int currentStartIndex = miniBatchStartIndex;
for (int i = 0; i < executorThreadCount; i++) {
threadStartIndices[i] = currentStartIndex;
if (i >= executorThreadCount - remainder) {
currentStartIndex += upperSamplesPerThread;
} else {
currentStartIndex += lowerSamplesPerThread;
}
}
threadStartIndices[executorThreadCount] = miniBatchStartIndex + currentMiniBatchSize;
return threadStartIndices;
}
    private boolean ceaseMinimization(double[] weights, double previousCostFunctionValue, double costFunctionValue, int epoch, double[] weightsAtPreviousEpoch) {
        return epoch > minEpochs && weightsConverged(weights, weightsAtPreviousEpoch, weightConvergenceThreshold)
                && costsConverged(previousCostFunctionValue, costFunctionValue, costConvergenceThreshold);
}
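    /*
     * Adds the gradient of the regularization terms, d/dw_i [(ridge / 2) * w_i^2 + lasso * |w_i|]
     * = ridge * w_i + lasso * sign(w_i), to every weight except the bias term at index 0.
     */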
public static void addRegularizationComponentOfTheGradient(double[] weights, double[] gradient, double ridge, double lasso) {
for (int i = 1; i < weights.length; i++) {//start at 1 to skip the bias term
double lassoDerivative = lasso;
if (weights[i] < 0.0) {
lassoDerivative *= -1;
}
gradient[i] += ridge * weights[i] + lassoDerivative;
}
}
public static Int2DoubleOpenHashMap getSparseWorkerContributionToTheGradient(List<? extends SparseClassifierInstance> instances, double[] weights, double expectedFractionOfFeaturesToUpdate) {
Int2DoubleOpenHashMap contributionsToTheGradient = new Int2DoubleOpenHashMap((int) (expectedFractionOfFeaturesToUpdate * weights.length));
contributionsToTheGradient.defaultReturnValue(0.0);
for (SparseClassifierInstance instance : instances) {
sparseUpdateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
}
return contributionsToTheGradient;
}
public static void sparseReductionToTheGradient(double[] gradient, List<Future<Int2DoubleOpenHashMap>> contributions) {
for (Future<Int2DoubleOpenHashMap> contribution : contributions) {
addSparseContribution(gradient, contribution);
}
}
public static void addSparseContribution(double[] gradient, Future<Int2DoubleOpenHashMap> contributionFuture) {
try {
Int2DoubleOpenHashMap contribution = contributionFuture.get();
for (Int2DoubleMap.Entry entry : contribution.int2DoubleEntrySet()) {
gradient[entry.getKey()] += entry.getValue();
}
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
public static double[] getWorkerContributionToTheGradient(List<? extends SparseClassifierInstance> instances, double[] weights) {
double[] contributionsToTheGradient = new double[weights.length];
for (SparseClassifierInstance instance : instances) {
updateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
}
return contributionsToTheGradient;
}
public static void reductionToTheGradient(double[] gradient, List<Future<double[]>> contributions) {
for (Future<double[]> contribution : contributions) {
addContribution(gradient, contribution);
}
}
public static void addContribution(double[] gradient, Future<double[]> contributionFuture) {
try {
double[] contribution = contributionFuture.get();
for (int i = 0; i< gradient.length; i++) {
gradient[i] += contribution[i];
}
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
    private void logCostFunctionValueAtRegularIntervals(double previousCostFunctionValue, double costFunctionValue, int epoch) {
        if (maxEpochs < 10 || epoch % (maxEpochs / 10) == 0) {
            logger.info("cost {}, prevCost {}, learning rate {}, before epoch {}", costFunctionValue, previousCostFunctionValue, learningRate, epoch);
}
}
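    /*
     * The "bold driver" heuristic: if the last epoch decreased the cost, grow the learning rate by
     * learningRateBoostFactor; if it increased the cost, shrink the rate by learningRateReductionFactor.
     */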
private void adjustLearningRateIfNecessary(double previousCost, double currentCost) {
if (useBoldDriver) {
if (previousCost > currentCost) {
learningRate = learningRate * learningRateBoostFactor;
} else {
learningRate = learningRate * learningRateReductionFactor;
}
}
}
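    /*
     * Declares weight convergence when the norm of the difference between the two weight vectors,
     * relative to the norm of the first, falls below weightConvergenceThreshold.
     */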
public static boolean weightsConverged(double[] weights, double[] newWeights, double weightConvergenceThreshold) {
        double sumOfSquaredDifferences = 0.0;
        double normSquared = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sumOfSquaredDifferences += (weights[i] - newWeights[i]) * (weights[i] - newWeights[i]);
            normSquared += weights[i] * weights[i];
        }
        return Math.sqrt(sumOfSquaredDifferences / normSquared) < weightConvergenceThreshold;
}
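    /*
     * Declares cost convergence when the relative change in cost between epochs,
     * |presentCost - previousCost| / presentCost, falls below costConvergenceThreshold.
     */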
public static boolean costsConverged(double previousCost, double presentCost, double costConvergenceThreshold) {
return Math.abs(presentCost - previousCost) / presentCost < costConvergenceThreshold;
}
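    /*
     * Computes the regularized mean cross-entropy:
     * J(w) = (1/N) * [-sum_i (y_i * log2(p_i) + (1 - y_i) * log2(1 - p_i)) + regularizationCost],
     * where p_i = sigmoid(w . x_i) and each predicted probability is floored at
     * minPredictedProbability before the log is taken.
     */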
    public static double computeCrossEntropyCostFunction(List<? extends SparseClassifierInstance> instances, double[] weights, double minPredictedProbability, double ridge, double lasso) {
        double cost = 0.0;
        for (SparseClassifierInstance instance : instances) {
            if ((double) instance.getLabel() == 1.0) {
                cost += -cappedlogBase2(probabilityOfThePositiveClass(weights, instance), minPredictedProbability);
            } else if ((double) instance.getLabel() == 0.0) {
                cost += -cappedlogBase2(probabilityOfTheNegativeClass(weights, instance), minPredictedProbability);
}
}
cost += getRegularizationCost(weights, ridge, lasso);
cost /= instances.size();
return cost;
}
public static double probabilityOfTheNegativeClass(double[] weights, SparseClassifierInstance instance) {
return 1.0 - probabilityOfThePositiveClass(weights, instance);
}
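    /*
     * The logistic model: P(y = 1 | x) = sigmoid(w . x) = 1 / (1 + exp(-w . x)), where the
     * bias term is assumed to be handled inside SparseClassifierInstance.dotProduct.
     */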
public static double probabilityOfThePositiveClass(double[] weights, SparseClassifierInstance instance) {
return sigmoid(instance.dotProduct(weights));
}
public static double getRegularizationCost(double[] weights, double ridge, double lasso) {
double cost = 0;
for (int i = 0; i < weights.length; i++) {
cost += weights[i] * weights[i] * ridge / 2.0 + Math.abs(weights[i]) * lasso;
}
return cost;
}
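    /*
     * Converts the summed per-instance contributions into an average gradient by dividing by the
     * mini-batch size, then optionally caps the gradient's L2 norm.
     */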
public static void normalizeTheGradient(int minibatchSize, double maxGradientNorm, double[] gradient) {
        for (int i = 0; i < gradient.length; i++) { // average over the mini-batch; the bias gradient at index 0 is averaged too
gradient[i] /= minibatchSize;
}
if (maxGradientNorm != Double.MAX_VALUE) {
applyMaxGradientNorm(maxGradientNorm, gradient);
}
}
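    /*
     * Gradient clipping: when ||g|| exceeds maxGradientNorm, rescale g by maxGradientNorm / ||g||,
     * preserving the update direction while capping its magnitude.
     */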
public static void applyMaxGradientNorm(double maxGradientNorm, double[] gradient) {
double gradientSumOfSquares = 0;
for (double g : gradient) {
gradientSumOfSquares += Math.pow(g, 2);
}
double gradientNorm = Math.sqrt(gradientSumOfSquares);
        if (gradientNorm > maxGradientNorm) {
            double scale = maxGradientNorm / gradientNorm;
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] *= scale;
}
}
}
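    /*
     * Per-instance gradient of the cross-entropy loss under the sigmoid model: for feature j,
     * the contribution is -(y - p) * x_j, so only the instance's non-zero (sparse) features need
     * to be touched. Note the cost is reported in log base 2 while this gradient uses the
     * natural-log form; the two differ only by a constant factor.
     */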
    static void sparseUpdateUnnormalizedGradientForInstance(double[] weights, Int2DoubleOpenHashMap contributionsToTheGradient,
                                                            SparseClassifierInstance instance) {
        double positiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient.addTo(featureIndex, gradientContributionOfAFeatureValue((Double) instance.getLabel(), positiveClassProbability, values[i]));
        }
}
    private static double gradientContributionOfAFeatureValue(double label, double positiveClassProbability, double value) {
        return -(label - positiveClassProbability) * value;
}
    static void updateUnnormalizedGradientForInstance(double[] weights, double[] contributionsToTheGradient,
                                                      SparseClassifierInstance instance) {
        double positiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient[featureIndex] += gradientContributionOfAFeatureValue((Double) instance.getLabel(), positiveClassProbability, values[i]);
        }
}
private double[] initializeWeights(int numFeatures) {
        double[] weights = new double[numFeatures]; // presumes normalized features
        Random random = new Random();
        for (int i = 0; i < numFeatures; i++) {
            weights[i] = random.nextDouble() - 0.5; // a uniform random number in [-0.5, 0.5)
}
return weights;
}
}