package quickml.supervised.regressionModel.LinearRegression;

import com.google.common.collect.Iterables;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.javatuples.Pair;
import quickml.data.instances.Instance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.data.instances.RidgeInstance;

import java.io.Serializable;
import java.util.Map;
import java.util.Objects;

/**
 * Created by alexanderhawk on 8/14/14.
 *
 * <p>Builds a {@link RidgeLinearModel} by solving the regularized normal equations
 * {@code coefficients = (X^t * X + regularizationConstant * I)^-1 * X^t * labels},
 * where {@code X} is the data matrix assembled from the training instances.
 *
 * <p>When the bias term is enabled, a leading column of ones is added to {@code X}. The
 * regularization constant is applied to every diagonal entry, so the bias coefficient is
 * penalized together with the attribute coefficients.
 *
 * <p>A header must be supplied through {@link #header(String[])} before calling
 * {@link #buildPredictiveModel(Iterable)}; its length determines the number of attribute columns.
 */
public class RidgeLinearModelBuilder implements PredictiveModelBuilder<RidgeLinearModel, RidgeInstance> {

    /** Configuration key for the regularization constant (a numeric value). */
    public static final String REGULARIZATION_CONSTANT = "regularizationConstant";

    /** Configuration key for the bias-term switch (a {@link Boolean} value). */
    public static final String INCLUDE_BIAS_TERM = "includeBiasTerm";

    private double regularizationConstant = 0;
    private boolean includeBiasTerm = false;
    private String[] header;

    public RidgeLinearModelBuilder() {
    }

    /**
     * Applies the settings present in {@code cfg}; keys that are absent leave the current value untouched.
     *
     * @param cfg configuration map, read for {@link #REGULARIZATION_CONSTANT} and {@link #INCLUDE_BIAS_TERM}
     */
    @Override
    public void updateBuilderConfig(Map<String, Serializable> cfg) {
        if (cfg.containsKey(REGULARIZATION_CONSTANT)) {
            // Accept any numeric type (Integer, Float, ...) instead of failing on everything but Double.
            regularizationConstant(((Number) cfg.get(REGULARIZATION_CONSTANT)).doubleValue());
        }
        if (cfg.containsKey(INCLUDE_BIAS_TERM)) {
            includeBiasTerm((Boolean) cfg.get(INCLUDE_BIAS_TERM));
        }
    }

    /**
     * Sets the constant added to each diagonal entry of {@code X^t * X}.
     *
     * @param regularizationConstant the ridge penalty
     * @return this builder
     */
    public RidgeLinearModelBuilder regularizationConstant(double regularizationConstant) {
        this.regularizationConstant = regularizationConstant;
        return this;
    }

    /**
     * Chooses whether a constant column of ones is added in front of the attribute columns.
     *
     * @param includeBiasTerm {@code true} to fit a bias coefficient
     * @return this builder
     */
    public RidgeLinearModelBuilder includeBiasTerm(boolean includeBiasTerm) {
        this.includeBiasTerm = includeBiasTerm;
        return this;
    }

    /**
     * Sets the attribute names; the array length fixes the number of attribute columns.
     *
     * @param header attribute names, passed on unchanged to the built model
     * @return this builder
     */
    public RidgeLinearModelBuilder header(String[] header) {
        this.header = header;
        return this;
    }

    /**
     * Fits the ridge regression coefficients for the given training data.
     *
     * @param trainingData instances providing a {@code double[]} of attributes and a numeric label
     * @return a model holding the fitted coefficients, the header and the bias-term flag
     * @throws NullPointerException     if no header has been set
     * @throws IllegalArgumentException if {@code trainingData} is empty or an instance has more
     *                                  attributes than the header has entries
     */
    @Override
    public RidgeLinearModel buildPredictiveModel(Iterable<RidgeInstance> trainingData) {
        //compute modelCoefficients = (X^t * X + regularizationConstant*IdentityMatrix)^-1 * X^t * labels, where X is the data matrix
        Objects.requireNonNull(header, "header must be set before building a RidgeLinearModel");
        int columnsInDataMatrix = includeBiasTerm ? header.length + 1 : header.length;

        Pair<RealMatrix, double[]> dataMatrixLabelsPair = createDataMatrixLabelsPair(trainingData, columnsInDataMatrix);
        RealMatrix dataMatrix = dataMatrixLabelsPair.getValue0();
        double[] labels = dataMatrixLabelsPair.getValue1();
        RealMatrix dataMatrixTranspose = dataMatrix.transpose();

        RealMatrix identityMatrixTimesRegularizationConstant =
                getIdentityMatrixTimesRegularizationConstant(columnsInDataMatrix);
        RealMatrix dataMatrixTransposeTimesDataMatrix = dataMatrixTranspose.multiply(dataMatrix);
        RealMatrix matrixToInvert = dataMatrixTransposeTimesDataMatrix.add(identityMatrixTimesRegularizationConstant);
        RealMatrix invertedMatrix = new SingularValueDecomposition(matrixToInvert).getSolver().getInverse();

        // Apply X^t to the labels first, then the inverse: same product as (inverse * X^t) * labels,
        // but without materializing the (columns x rows) intermediate matrix.
        double[] modelCoefficients = invertedMatrix.operate(dataMatrixTranspose.operate(labels));
        return new RidgeLinearModel(modelCoefficients, header, includeBiasTerm);
    }

    /** Debugging helper: prints the matrix to standard output, one row per line. */
    private void printMatrix(RealMatrix matrix) {
        for (int i = 0; i < matrix.getRowDimension(); i++) {
            for (int j = 0; j < matrix.getColumnDimension(); j++) {
                System.out.print(matrix.getEntry(i, j) + " ");
            }
            System.out.print("\n");
        }
    }

    /** Returns a {@code dimension x dimension} diagonal matrix with the regularization constant on the diagonal. */
    private RealMatrix getIdentityMatrixTimesRegularizationConstant(int dimension) {
        RealMatrix identityMatrixTimesRegularizationConstant = new DiagonalMatrix(dimension);
        for (int i = 0; i < dimension; i++) {
            identityMatrixTimesRegularizationConstant.setEntry(i, i, regularizationConstant);
        }
        return identityMatrixTimesRegularizationConstant;
    }

    /**
     * Copies the training data into a dense data matrix (one row per instance) and a label array.
     *
     * @param trainingData        instances to copy
     * @param columnsInDataMatrix number of matrix columns, including the bias column when enabled
     * @return the data matrix paired with the labels, in iteration order
     */
    private Pair<RealMatrix, double[]> createDataMatrixLabelsPair(
            Iterable<? extends Instance<double[], Serializable>> trainingData, int columnsInDataMatrix) {
        // Iterables.size may walk the whole iterable, so compute it only once.
        int rowCount = Iterables.size(trainingData);
        if (rowCount == 0) {
            throw new IllegalArgumentException("trainingData must contain at least one instance");
        }
        RealMatrix dataMatrix = new Array2DRowRealMatrix(rowCount, columnsInDataMatrix);
        double[] labels = new double[rowCount];
        // Attributes start at column 1 when column 0 is reserved for the bias term.
        int firstAttributeColumn = includeBiasTerm ? 1 : 0;

        int row = 0;
        for (Instance<double[], Serializable> instance : trainingData) {
            labels[row] = ((Number) instance.getLabel()).doubleValue();
            double[] attributes = instance.getAttributes();
            if (attributes.length + firstAttributeColumn > columnsInDataMatrix) {
                throw new IllegalArgumentException("instance " + row + " has " + attributes.length
                        + " attributes but the header declares only "
                        + (columnsInDataMatrix - firstAttributeColumn));
            }
            if (includeBiasTerm) {
                dataMatrix.setEntry(row, 0, 1.0);
            }
            for (int i = 0; i < attributes.length; i++) {
                dataMatrix.setEntry(row, i + firstAttributeColumn, attributes[i]);
            }
            row++;
        }
        return new Pair<>(dataMatrix, labels);
    }
}