package com.matrobot.gha.insights.ml; /** * Gradient descent for multiple variables * * @author Krzysztof Langner */ public class GradientDescentLinear{ private double alpha = 1; private double[] coefficients; /** * Override this function for different regression model (e.g logistic regression) * @param input * @return */ public double predict(double[] input) { return getLinearRegression(input); } protected double getLinearRegression(double[] input) { double sum = 0; for(int i = 0; i < coefficients.length && i < input.length+1; i++){ double x = (i==0)? 1: input[i-1]; sum += coefficients[i]*x; } return sum; } public void setAlpha(double alpha){ this.alpha = alpha; } public void train(Dataset dataset) { // long time = System.currentTimeMillis(); double[] tempCoeffs = new double[dataset.getFeatureCount()+1]; double maxGradient = 10; coefficients = new double[dataset.getFeatureCount()+1]; for(int i = 0; i < coefficients.length; i++){ coefficients[i] = 0; } double oldCost = calculateCost(dataset); while(maxGradient > .001){ maxGradient = 0; for(int i = 0; i < coefficients.length; i++){ double sum = 0; for(Sample sample : dataset.getData()){ double h = predict(sample.features); double x = (i==0)? 1 : sample.features[i-1]; sum += (h-sample.output)*x; } double gradient = sum/dataset.size(); tempCoeffs[i] = coefficients[i] - alpha*gradient; maxGradient = Math.max(maxGradient, Math.abs(gradient)); } for(int i = 0; i < coefficients.length; i++){ coefficients[i] = tempCoeffs[i]; } double newCost = calculateCost(dataset); if(newCost > oldCost){ System.out.println("Cost function incresing. Probably alpha too big"); // break; } else{ oldCost = newCost; } } // System.out.println("Learning time: " + (System.currentTimeMillis()-time)/1000); } /** * Calculate cost function on given inputs. * Override this function for different regression model (e.g logistic regression) */ protected double calculateCost(Dataset dataset) { double sum = 0; for(Sample sample : dataset.getData()){ double h = predict(sample.features); sum += Math.pow((h-sample.output), 2); } return sum; } public void printModel() { for(int i = 0; i < coefficients.length; i++){ System.out.println(i + ": " + coefficients[i]); } } }