package quickml.supervised.regressionModel.LinearRegression2; /** * Created by alexanderhawk on 10/12/15. */ import com.google.common.collect.Lists; import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionBlock; import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon; import org.ejml.alg.dense.linsol.chol.LinearSolverChol; import org.ejml.data.D1Matrix64F; import org.ejml.data.DenseMatrix64F; import quickml.data.instances.RegressionInstance; import quickml.data.instances.SparseRegressionInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.classifier.logisticRegression.InstanceTransformerUtils; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Map; import org.ejml.alg.dense.linsol.*; import static org.ejml.ops.CommonOps.*; /** * Created by alexanderhawk on 10/9/15. */ public class SimpleRidgeRegressionBuilder<T extends RegressionInstance> implements PredictiveModelBuilder<LinearModel, RegressionInstance> { public static final String MIN_OBSERVATIONS_OF_ATTRIBUTE= "minObservationsOfAttribute"; public static final String RIDGE_REGULARIZATION_CONSTANT = "ridgeRegularizationConstant"; public static final String USE_BIAS = "useBias"; private int minObservationsOfAttribute; private double ridgeRegularizationConstant; private boolean useBias = true; public SimpleRidgeRegressionBuilder<T> minObservationsOfAttribute(int minObservationsOfAttribute) { this.minObservationsOfAttribute = minObservationsOfAttribute; return this; } public SimpleRidgeRegressionBuilder<T> ridgeRegularizationConstant(double ridgeConstant) { this.ridgeRegularizationConstant = ridgeConstant; return this; } public SimpleRidgeRegressionBuilder<T> useBias(boolean useBias) { this.useBias = useBias; return this; } @Override public LinearModel buildPredictiveModel(Iterable<RegressionInstance> trainingData) { List<RegressionInstance> trainingDataList = Lists.newArrayList(trainingData); Map<String, Integer> nameToIndexMap = InstanceTransformerUtils.populateNameToIndexMap(trainingDataList, useBias); int numVariables = nameToIndexMap.size(); double[][] data = new double[trainingDataList.size()][numVariables]; double[][] responseArray = new double[trainingDataList.size()][1]; for (int row = 0; row < trainingDataList.size(); row++) { RegressionInstance regressionInstance = trainingDataList.get(row); data[row] = SparseRegressionInstance.getArrayOfValues(regressionInstance, nameToIndexMap, useBias); responseArray[row][0] = regressionInstance.getLabel(); } DenseMatrix64F dataMatrix = new DenseMatrix64F(data); DenseMatrix64F dataMatrixTranspose = getTranspose(dataMatrix); DenseMatrix64F symmetricMatrix=getSymmetricMatrix(numVariables, dataMatrix, dataMatrixTranspose); DenseMatrix64F response = new DenseMatrix64F(responseArray); // transpose(response); LinearSolverChol linearSolverChol = new LinearSolverChol(new CholeskyDecompositionBlock(numVariables));// new CholeskyDecompositionCommon(true)); linearSolverChol.setA(symmetricMatrix); DenseMatrix64F dataMatrixTransposeTimesResponse = getDataMatrixTransposeTimesResponse(numVariables, dataMatrixTranspose, response); DenseMatrix64F coefficients = new DenseMatrix64F(numVariables); linearSolverChol.solve(dataMatrixTransposeTimesResponse, coefficients); return new LinearModel(coefficients.getData(), nameToIndexMap, useBias); } private DenseMatrix64F getDataMatrixTransposeTimesResponse(int numVariables, DenseMatrix64F dataMatrixTranspose, DenseMatrix64F response) { DenseMatrix64F multipliedResponse = new DenseMatrix64F(numVariables,1); mult(dataMatrixTranspose, response, multipliedResponse); return multipliedResponse; } private DenseMatrix64F getSymmetricMatrix(int numVariables, DenseMatrix64F dataMatrix, DenseMatrix64F dataMatrixTranspose) { DenseMatrix64F symmetricMatrix = new DenseMatrix64F(numVariables, numVariables); mult(dataMatrixTranspose, dataMatrix, symmetricMatrix); for (int i = 0; i<dataMatrix.getNumCols(); i++) { double diagonalElement = ridgeRegularizationConstant + symmetricMatrix.get(i, i); symmetricMatrix.set(i, i, diagonalElement); } return symmetricMatrix; } private DenseMatrix64F getTranspose(DenseMatrix64F dataMatrix) { DenseMatrix64F dataMatrixTranspose = dataMatrix.copy(); transpose(dataMatrixTranspose); return dataMatrixTranspose; } @Override public void updateBuilderConfig(final Map<String, Serializable> config) { if (config.containsKey(MIN_OBSERVATIONS_OF_ATTRIBUTE)) { minObservationsOfAttribute((Integer) config.get(MIN_OBSERVATIONS_OF_ATTRIBUTE)); } if (config.containsKey(RIDGE_REGULARIZATION_CONSTANT)) { ridgeRegularizationConstant((Double) config.get(RIDGE_REGULARIZATION_CONSTANT)); } if (config.containsKey(USE_BIAS)) { useBias((Boolean) config.get(USE_BIAS)); } } }