package quickml.supervised.parametricModels;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import org.javatuples.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.SparseRegressionInstance;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static quickml.MathUtils.cappedlogBase2;
import static quickml.MathUtils.sigmoid;
/**
* Created by alexanderhawk on 4/1/16.
*/
/**
 * Logistic-regression cross-entropy cost function with optional ridge (L2) and lasso (L1)
 * regularization. Gradient computation is parallelized across a fixed thread pool, either by
 * accumulating sparse per-worker hash maps (good when each worker touches few features) or
 * dense per-worker arrays.
 * <p>
 * Not thread-safe for concurrent configuration; configure via the fluent setters or
 * {@link #updateBuilderConfig(Map)} before use. Call {@link #shutdown()} when done.
 */
public class OptimizableCostFunctionImp implements OptimizableCostFunction<SparseRegressionInstance> {
    // Pool size is fixed at construction time; changing executorThreadCount afterwards only
    // controls how many partitions a minibatch is split into, never the pool size itself.
    private int executorThreadCount = Runtime.getRuntime().availableProcessors();
    private ExecutorService executorService = Executors.newFixedThreadPool(executorThreadCount);

    // Keys recognized by updateBuilderConfig. The string values are part of the external
    // configuration contract and must not be changed.
    public static final String EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER = "expectedFractionOfFeaturesToUpdatePerWorker";
    public static final String EXECUTOR_THREAD_COUNT = "executorThreadCount";
    public static final String MIN_INSTANCES_FOR_PARELLIZATION = "minInstancesForParrellization";
    public static final String SPARSE_PARELLIZATION = "sparseParallelization";
    public static final String OPTIMIZABLE_COST_FUNCTION = "optimizableCostFunction";
    public static final String MAX_GRADIENT_NORM = "maxGradientNorm";
    public static final String RIDGE = "ridge";
    public static final String LASSO = "lasso";

    double ridge = 0;  // L2 regularization strength
    double lasso = 0;  // L1 regularization strength
    // Sizing hint for per-worker sparse gradient maps: expected fraction of all features a
    // single worker will touch.
    private double expectedFractionOfFeaturesToUpdatePerWorker = 1.0;
    private int minInstancesForParrellization = 100;
    private boolean sparseParallelization = true;
    private double maxGradientNorm = Double.MAX_VALUE;
    private static final Logger logger = LoggerFactory.getLogger(OptimizableCostFunctionImp.class);

    /**
     * Sets the number of partitions a minibatch is split into, capped at the machine's core
     * count (the underlying pool was sized to the core count at construction, so more
     * partitions than cores would only queue).
     * Fix: the original compared against the *current* field value, so once lowered the count
     * could never be raised again, even back up to the core count.
     */
    public OptimizableCostFunctionImp executorThreadCount(int executorThreadCount) {
        if (executorThreadCount <= Runtime.getRuntime().availableProcessors()) {
            this.executorThreadCount = executorThreadCount;
        } else {
            logger.warn("can't use more executors than cores");
        }
        return this;
    }

    /** Sets the maximum L2 norm the per-minibatch gradient is clipped to. */
    public OptimizableCostFunctionImp maxGradientNorm(double maxGradientNorm) {
        this.maxGradientNorm = maxGradientNorm;
        return this;
    }

    /** Minibatches smaller than this are processed on a single worker. */
    public OptimizableCostFunctionImp minInstancesForParrellization(int minInstancesForParrellization) {
        this.minInstancesForParrellization = minInstancesForParrellization;
        return this;
    }

    /** Sizing hint for the per-worker sparse gradient maps. */
    public OptimizableCostFunctionImp expectedFractionOfFeaturesToUpdatePerWorker(double expectedFractionOfFeaturesToUpdatePerWorker) {
        this.expectedFractionOfFeaturesToUpdatePerWorker = expectedFractionOfFeaturesToUpdatePerWorker;
        return this;
    }

    /** Chooses sparse (hash map) vs. dense (array) per-worker gradient accumulation. */
    public OptimizableCostFunctionImp sparseParallelization(boolean sparseParallelization) {
        this.sparseParallelization = sparseParallelization;
        return this;
    }

    /** Sets the ridge (L2) regularization constant. */
    public OptimizableCostFunctionImp ridgeRegularizationConstant(final double ridgeRegularizationConstant) {
        this.ridge = ridgeRegularizationConstant;
        return this;
    }

    /** Sets the lasso (L1) regularization constant. (Parameter renamed from the original's
     *  copy-pasted "ridgeRegularizationConstant".) */
    public OptimizableCostFunctionImp lassoRegularizationConstant(final double lassoRegularizationConstant) {
        this.lasso = lassoRegularizationConstant;
        return this;
    }

    /** Returns 1 - P(y=1 | instance) under the logistic model defined by {@code weights}. */
    public static double probabilityOfTheNegativeClass(double[] weights, SparseRegressionInstance instance) {
        return 1.0 - probabilityOfThePositiveClass(weights, instance);
    }

    /** Returns sigmoid(w · x), the modeled probability of the positive class. */
    public static double probabilityOfThePositiveClass(double[] weights, SparseRegressionInstance instance) {
        return sigmoid(instance.dotProduct(weights));
    }

    /**
     * Applies any recognized configuration entries. Unrecognized keys are ignored.
     * Fix: the original routed the LASSO key to the ridge constant and the RIDGE key to the
     * lasso constant; the keys now set the constants they name.
     */
    public void updateBuilderConfig(final Map<String, Serializable> config) {
        if (config.containsKey(SPARSE_PARELLIZATION)) {
            sparseParallelization((Boolean) config.get(SPARSE_PARELLIZATION));
        }
        if (config.containsKey(MIN_INSTANCES_FOR_PARELLIZATION)) {
            minInstancesForParrellization((Integer) config.get(MIN_INSTANCES_FOR_PARELLIZATION));
        }
        if (config.containsKey(EXECUTOR_THREAD_COUNT)) {
            executorThreadCount((Integer) config.get(EXECUTOR_THREAD_COUNT));
        }
        if (config.containsKey(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER)) {
            expectedFractionOfFeaturesToUpdatePerWorker((Double) config.get(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER));
        }
        if (config.containsKey(MAX_GRADIENT_NORM)) {
            maxGradientNorm((Double) config.get(MAX_GRADIENT_NORM));
        }
        if (config.containsKey(LASSO)) {
            lassoRegularizationConstant((Double) config.get(LASSO));
        }
        if (config.containsKey(RIDGE)) {
            ridgeRegularizationConstant((Double) config.get(RIDGE));
        }
    }

    /**
     * Computes the mean base-2 cross-entropy over {@code instances}, plus the regularization
     * penalty, averaged over the minibatch size. Predicted probabilities are floored at
     * {@code minPredictedProbablity} inside cappedlogBase2 to keep the log finite.
     *
     * @return the averaged cost; 0.0 for an empty minibatch (avoids 0/0 -> NaN).
     */
    @Override
    public double computeCost(List<? extends SparseRegressionInstance> instances, double[] weights, double minPredictedProbablity) {
        if (instances.isEmpty()) {
            return 0.0;
        }
        double cost = 0.0;
        for (SparseRegressionInstance instance : instances) {
            double label = (double) instance.getLabel();
            if (label == 1.0) {
                cost += -cappedlogBase2(probabilityOfThePositiveClass(weights, instance), minPredictedProbablity);
            } else if (label == 0.0) {
                cost += -cappedlogBase2(probabilityOfTheNegativeClass(weights, instance), minPredictedProbablity);
            }
            // labels other than 0/1 contribute nothing, matching prior behavior
        }
        cost += getRegularizationCost(weights, ridge, lasso);
        cost /= instances.size();
        return cost;
    }

    /**
     * Accumulates the averaged, regularized, norm-clipped gradient of the cost over the
     * minibatch into {@code gradient} (which the caller is expected to supply zeroed).
     */
    @Override
    public void updateGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, final double[] gradient) {
        int currentMiniBatchSize = sparseClassifierInstances.size();
        final int[] threadStartAndStopIndices = getThreadStartIndices(currentMiniBatchSize, executorThreadCount, minInstancesForParrellization);
        int actualNumThreads = threadStartAndStopIndices.length - 1;
        if (sparseParallelization) {
            sparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);
        } else {
            nonSparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);
        }
        addRegularizationComponentOfTheGradient(fixedWeights, gradient, ridge, lasso);
        normalizeTheGradient(currentMiniBatchSize, maxGradientNorm, gradient);
    }

    /**
     * Fans the minibatch out to workers that each build a sparse (feature index -> partial
     * derivative) map, then reduces the maps into {@code gradient}.
     * Fix: the original overwrote the configured expectedFractionOfFeaturesToUpdatePerWorker
     * with 1.0 inside every worker (an unsynchronized field write from worker threads that
     * also made the setter and config key dead); the configured value is now read once, on
     * the submitting thread, into a final local.
     */
    private void sparseCalculationOfGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
        final double expectedFraction = expectedFractionOfFeaturesToUpdatePerWorker;
        List<Future<Int2DoubleOpenHashMap>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
        for (int i = 0; i < actualNumThreads; i++) {
            final int index = i;
            contributionsToTheGradient.add(executorService.submit(new Callable<Int2DoubleOpenHashMap>() {
                @Override
                public Int2DoubleOpenHashMap call() throws Exception {
                    try {
                        return getSparseWorkerContributionToTheGradient(
                                sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]),
                                fixedWeights, expectedFraction);
                    } catch (IllegalArgumentException e) {
                        logger.error("failed to compute sparse gradient contribution for partition {}", index, e);
                        throw new RuntimeException(e);
                    }
                }
            }));
        }
        sparseReductionToTheGradient(gradient, contributionsToTheGradient);
    }

    /**
     * Fans the minibatch out to workers that each build a dense partial-gradient array, then
     * reduces the arrays into {@code gradient}.
     */
    private void nonSparseCalculationOfGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
        List<Future<double[]>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
        for (int i = 0; i < actualNumThreads; i++) {
            final int index = i;
            contributionsToTheGradient.add(executorService.submit(new Callable<double[]>() {
                @Override
                public double[] call() throws Exception {
                    return getWorkerContributionToTheGradient(
                            sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]),
                            fixedWeights);
                }
            }));
        }
        reductionToTheGradient(gradient, contributionsToTheGradient);
    }

    /**
     * Adds one instance's unnormalized cross-entropy gradient contribution,
     * -(label - p) * x_i per present feature, into the sparse accumulator.
     */
    static void sparseUpdateUnnormalizedGradientForInstance(double[] weights, Int2DoubleOpenHashMap contributionsToTheGradient,
                                                            SparseRegressionInstance instance) {
        double postiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient.addTo(featureIndex, gradientContributionOfAFeatureValue((Double) instance.getLabel(), postiveClassProbability, values[i]));
        }
    }

    /** Partial derivative of the cross-entropy w.r.t. one weight: -(y - p) * x. */
    private static double gradientContributionOfAFeatureValue(double label, double postiveClassProbability, double value) {
        return -(label - postiveClassProbability) * value;
    }

    /** Dense counterpart of {@link #sparseUpdateUnnormalizedGradientForInstance}. */
    static void updateUnnormalizedGradientForInstance(double[] weights, double[] contributionsToTheGradient,
                                                      SparseRegressionInstance instance) {
        double postiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient[featureIndex] += gradientContributionOfAFeatureValue((Double) instance.getLabel(), postiveClassProbability, values[i]);
        }
    }

    /**
     * Builds one worker's sparse gradient contribution over its slice of the minibatch.
     * {@code expectedFractionOfFeaturesToUpdate} only sizes the hash map up front.
     */
    public Int2DoubleOpenHashMap getSparseWorkerContributionToTheGradient(List<? extends SparseRegressionInstance> instances, double[] weights, double expectedFractionOfFeaturesToUpdate) {
        Int2DoubleOpenHashMap contributionsToTheGradient = new Int2DoubleOpenHashMap((int) (expectedFractionOfFeaturesToUpdate * weights.length));
        contributionsToTheGradient.defaultReturnValue(0.0);
        for (SparseRegressionInstance instance : instances) {
            sparseUpdateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
        }
        return contributionsToTheGradient;
    }

    /** Sums every worker's sparse contribution into {@code gradient}, blocking on each future. */
    public static void sparseReductionToTheGradient(double[] gradient, List<Future<Int2DoubleOpenHashMap>> contributions) {
        for (Future<Int2DoubleOpenHashMap> contribution : contributions) {
            addSparseContribution(gradient, contribution);
        }
    }

    /** Adds one worker's sparse contribution; wraps any failure of the future as unchecked. */
    public static void addSparseContribution(double[] gradient, Future<Int2DoubleOpenHashMap> contributionFuture) {
        try {
            Int2DoubleOpenHashMap contribution = contributionFuture.get();
            for (Int2DoubleMap.Entry entry : contribution.int2DoubleEntrySet()) {
                gradient[entry.getKey()] += entry.getValue();
            }
        } catch (Exception e) {
            logger.error("failed to reduce sparse gradient contribution", e);
            throw new RuntimeException(e);
        }
    }

    /** Builds one worker's dense gradient contribution over its slice of the minibatch. */
    public static double[] getWorkerContributionToTheGradient(List<? extends SparseRegressionInstance> instances, double[] weights) {
        double[] contributionsToTheGradient = new double[weights.length];
        for (SparseRegressionInstance instance : instances) {
            updateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
        }
        return contributionsToTheGradient;
    }

    /** Sums every worker's dense contribution into {@code gradient}, blocking on each future. */
    public static void reductionToTheGradient(double[] gradient, List<Future<double[]>> contributions) {
        for (Future<double[]> contribution : contributions) {
            addContribution(gradient, contribution);
        }
    }

    /**
     * Adds the derivative of the regularization penalty, ridge * w_i + lasso * sign(w_i),
     * to every component except index 0 (the bias term is not regularized).
     */
    public static void addRegularizationComponentOfTheGradient(double[] weights, double[] gradient, double ridge, double lasso) {
        for (int i = 1; i < weights.length; i++) {//start at 1 to skip the bias term
            double lassoDerivative = lasso;
            if (weights[i] < 0.0) {
                lassoDerivative *= -1;
            }
            gradient[i] += ridge * weights[i] + lassoDerivative;
        }
    }

    /** Adds one worker's dense contribution; wraps any failure of the future as unchecked. */
    public static void addContribution(double[] gradient, Future<double[]> contributionFuture) {
        try {
            double[] contribution = contributionFuture.get();
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] += contribution[i];
            }
        } catch (Exception e) {
            logger.error("failed to reduce gradient contribution", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the regularization penalty: sum over non-bias weights of
     * ridge * w^2 / 2 + lasso * |w|.
     * Fix: the original included weights[0], while the gradient explicitly skips the bias
     * term; cost and gradient now agree.
     */
    public static double getRegularizationCost(double[] weights, double ridge, double lasso) {
        double cost = 0;
        for (int i = 1; i < weights.length; i++) {//start at 1 to skip the bias term, matching the gradient
            cost += weights[i] * weights[i] * ridge / 2.0 + Math.abs(weights[i]) * lasso;
        }
        return cost;
    }

    /**
     * Divides every gradient component by the minibatch size (the cost is averaged, so the
     * gradient must be too), then clips to {@code maxGradientNorm} if one is configured.
     * Fix: the original started at index 1, leaving the bias component unaveraged.
     */
    public static void normalizeTheGradient(int minibatchSize, double maxGradientNorm, double[] gradient) {
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] /= minibatchSize;
        }
        if (maxGradientNorm != Double.MAX_VALUE) {
            applyMaxGradientNorm(maxGradientNorm, gradient);
        }
    }

    /**
     * Rescales the gradient so its L2 norm does not exceed {@code maxGradientNorm}.
     * Fix: the original divided by sqrt(norm / maxNorm), which leaves the clipped norm at
     * sqrt(norm * maxNorm) — still above the cap. Standard clipping scales by
     * maxNorm / norm, which lands exactly on the cap.
     */
    public static void applyMaxGradientNorm(double maxGradientNorm, double[] gradient) {
        double gradientSumOfSquares = 0;
        for (double g : gradient) {
            gradientSumOfSquares += g * g;
        }
        double gradientNorm = Math.sqrt(gradientSumOfSquares);
        if (gradientNorm > maxGradientNorm) {
            double scale = maxGradientNorm / gradientNorm;
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] *= scale;
            }
        }
    }

    /**
     * Partitions {@code [0, numInstances)} into contiguous worker slices and returns the
     * boundary indices: slice k is [result[k], result[k+1]).
     * <p>
     * Fix: the original compared the thread count (not the instance count) against
     * {@code minInstancesForParallelization} — always true with default settings — and then
     * used the thread count as a stop index, so only the first ~threadCount instances of
     * every minibatch were ever processed. It also contained a self-assignment.
     *
     * @param numInstances                   size of the minibatch
     * @param maxThreads                     upper bound on the number of slices (call sites pass executorThreadCount)
     * @param minInstancesForParallelization below this size everything goes to a single slice
     */
    public int[] getThreadStartIndices(int numInstances, int maxThreads, int minInstancesForParallelization) {
        if (numInstances < minInstancesForParallelization) {
            // Too small to be worth fanning out: one slice covering the whole minibatch.
            return new int[]{0, numInstances};
        }
        if (numInstances <= maxThreads) {
            // Fewer instances than workers: one instance per slice.
            int[] threadStartIndices = new int[numInstances + 1];
            for (int i = 0; i <= numInstances; i++) {
                threadStartIndices[i] = i;
            }
            return threadStartIndices;
        }
        // Even split; the last `remainder` slices take one extra instance so the sizes sum
        // exactly to numInstances.
        int[] threadStartIndices = new int[maxThreads + 1];
        int lowerSamplesPerThread = numInstances / maxThreads;
        int upperSamplesPerThread = lowerSamplesPerThread + 1;
        int remainder = numInstances % maxThreads;
        int currentStartIndex = 0;
        for (int i = 0; i < maxThreads; i++) {
            threadStartIndices[i] = currentStartIndex;
            if (i >= maxThreads - remainder) {
                currentStartIndex += upperSamplesPerThread;
            } else {
                currentStartIndex += lowerSamplesPerThread;
            }
        }
        threadStartIndices[maxThreads] = numInstances;
        return threadStartIndices;
    }

    /** Stops the worker pool; the instance cannot compute gradients afterwards. */
    @Override
    public void shutdown() {
        executorService.shutdown();
    }
}