package quickml.supervised.regressionModel;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.data.instances.Instance;
import quickml.data.instances.RegressionInstance;
import quickml.data.instances.RidgeInstance;
import quickml.supervised.regressionModel.LinearRegression.RidgeLinearModel;
import quickml.supervised.regressionModel.LinearRegression.RidgeLinearModelBuilder;
import quickml.supervised.regressionModel.LinearRegression2.LinearModel;
import quickml.supervised.regressionModel.LinearRegression2.SimpleRidgeRegressionBuilder;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/**
 * Tests for the two ridge-regression builders ({@link RidgeLinearModelBuilder} and
 * {@link SimpleRidgeRegressionBuilder}) on a small one-feature data set.
 *
 * <p>Each test trains on the full data set and asserts that the in-sample RMSE is no
 * worse than a reference RMSE (plus a small epsilon).
 *
 * <p>Created by alexanderhawk on 8/15/14.
 */
public class RidgeRegressionBuilderTest {

    private static final Logger logger = LoggerFactory.getLogger(RidgeRegressionBuilderTest.class);

    /**
     * Reference sum of squared errors; presumably produced by an equivalent Python
     * ridge-regression run on the same data — TODO confirm provenance.
     */
    private static final double PYTHON_SSE = 212.32;

    private double regularizationConstant = 0.0;
    RidgeLinearModelBuilder ridgeLinearModelBuilder;
    SimpleRidgeRegressionBuilder<RegressionInstance> simpleRidgeRegressionBuilder =
            new SimpleRidgeRegressionBuilder<>()
                    .useBias(true)
                    .ridgeRegularizationConstant(regularizationConstant);
    String[] header = {"temperature"};
    private List<RidgeInstance> trainingData;
    private List<RegressionInstance> regTrainingData;

    /**
     * Builds the ridge model builder and the training data: 15 (temperature, label)
     * pairs, stored both as {@link RidgeInstance}s (raw double[] features) and as
     * {@link RegressionInstance}s (AttributesMap features).
     */
    @Before
    public void setUp() {
        ridgeLinearModelBuilder = new RidgeLinearModelBuilder()
                .header(header)
                .includeBiasTerm(true)
                .regularizationConstant(regularizationConstant);

        // {temperature, label} pairs; order matters only for reproducible logs.
        double[][] data = {
                {20.0, 88.6}, {16.0, 71.6}, {19.8, 93.3}, {18.4, 84.3}, {17.1, 80.6},
                {15.5, 75.2}, {14.7, 69.7}, {15.7, 71.6}, {15.4, 69.4}, {16.3, 83.3},
                {15.0, 79.6}, {17.2, 82.6}, {16.0, 80.6}, {17.0, 83.5}, {14.4, 76.3}
        };
        trainingData = new ArrayList<>(data.length);
        for (double[] row : data) {
            trainingData.add(new RidgeInstance(new double[]{row[0]}, row[1]));
        }
        regTrainingData = ridgeInstancesToRegressionInstances(trainingData);
    }

    /**
     * Converts raw-array ridge instances to map-based regression instances, keying the
     * single feature under the attribute name "attr".
     */
    private List<RegressionInstance> ridgeInstancesToRegressionInstances(List<RidgeInstance> ridgeInstances) {
        List<RegressionInstance> regressionInstances = new ArrayList<>(ridgeInstances.size());
        String attr = "attr";
        for (RidgeInstance ridgeInstance : ridgeInstances) {
            AttributesMap attributesMap = AttributesMap.newHashMap();
            attributesMap.put(attr, ridgeInstance.getAttributes()[0]);
            regressionInstances.add(new RegressionInstance(attributesMap, (Double) ridgeInstance.getLabel()));
        }
        return regressionInstances;
    }

    /**
     * Trains a {@link LinearModel} via {@link SimpleRidgeRegressionBuilder} and checks
     * the in-sample RMSE against the Python reference (with a 4% tolerance).
     */
    @Test
    public void simpleRidgeRegressionBuilderTest() {
        LinearModel linearModel = simpleRidgeRegressionBuilder.buildPredictiveModel(regTrainingData);
        double pythonRMSE = Math.sqrt(PYTHON_SSE / trainingData.size());
        double pythonEpsilon = pythonRMSE / 25.0;
        double mse = 0;
        for (RegressionInstance instance : regTrainingData) {
            AttributesMap attributesMap = instance.getAttributes();
            double prediction = linearModel.predict(attributesMap); // predict once, reuse
            logger.info("prediction {}. label: {}", prediction, instance.getLabel());
            mse += Math.pow(prediction - (Double) instance.getLabel(), 2);
            logger.info("un-normalized mse {}", mse);
        }
        mse /= trainingData.size();
        double RMSE = Math.sqrt(mse);
        logger.info("rmse_per_test_instance {} Python rmse: {}", RMSE, pythonRMSE);
        Assert.assertTrue("RMSE " + RMSE + " python RMSE " + pythonRMSE,
                RMSE < pythonRMSE + pythonEpsilon);
    }

    /**
     * Trains a {@link RidgeLinearModel} via {@link RidgeLinearModelBuilder} (double[]
     * feature interface) and checks the in-sample RMSE against the Python reference.
     */
    @Test
    public void ridgeRegressionBuilderTest() {
        RidgeLinearModel ridgeLinearModel = ridgeLinearModelBuilder.buildPredictiveModel(trainingData);
        double pythonRMSE = Math.sqrt(PYTHON_SSE / trainingData.size());
        double pythonEpsilon = pythonRMSE / 25.0;
        double mse = 0;
        for (Instance<double[], Serializable> instance : trainingData) {
            double[] x = instance.getAttributes();
            double prediction = ridgeLinearModel.predict(x); // predict once, reuse
            logger.info("prediction {}. label: {}", prediction, instance.getLabel());
            mse += Math.pow(prediction - (Double) instance.getLabel(), 2);
            logger.info("un-normalized mse {}", mse);
        }
        mse /= trainingData.size();
        double RMSE = Math.sqrt(mse);
        logger.info("rmse_per_test_instance {} Python rmse: {}", RMSE, pythonRMSE);
        Assert.assertTrue("RMSE " + RMSE + " python RMSE " + pythonRMSE,
                RMSE < pythonRMSE + pythonEpsilon);
    }

    /**
     * Placeholder: cross-validated hyperparameter optimization for the ridge builder.
     * The body is commented out pending a rewrite against the current optimizer API.
     */
    @Test //TODO[mk] updateBuilderConfig this test
    public void ridgePMOTest() {
//        CrossValidator<double[] , Serializable, Double> crossValidator = new StationaryCrossValidatorBuilder().setFolds(4).setLossFunction(new SingleVariableRealValuedFunctionMSECVLossFunction()).createCrossValidator();
//        RidgeLinearModelBuilderFactory ridgeLinearModelBuilderFactory = new RidgeLinearModelBuilderFactory().header(header).includeBiasTerm(true).regularizationConstants(new FixedOrderRecommender(0.001, 0.01, 0.1));
//        PredictiveModelOptimizer<double[], Serializable, Double, RidgeLinearModel, RidgeLinearModelBuilder> predictiveModelOptimizer = new PredictiveModelOptimizer<>(ridgeLinearModelBuilderFactory, trainingData, crossValidator);
//        Map<String, Object> optimalParams = predictiveModelOptimizer.determineOptimalConfiguration();
//        for (String key : optimalParams.keySet())
//            logger.info(key + " : " + optimalParams.get(key));
    }
}