/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.regressions; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; import org.apache.ignite.ml.math.exceptions.NullArgumentException; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Assert; import org.junit.Before; import org.junit.Test; /** * This class is based on the corresponding class from Apache Common Math lib. * Abstract base class for implementations of {@link MultipleLinearRegression}. */ public abstract class AbstractMultipleLinearRegressionTest { /** */ protected AbstractMultipleLinearRegression regression; /** */ @Before public void setUp() { regression = createRegression(); } /** */ protected abstract AbstractMultipleLinearRegression createRegression(); /** */ protected abstract int getNumberOfRegressors(); /** */ protected abstract int getSampleSize(); /** */ @Test public void canEstimateRegressionParameters() { double[] beta = regression.estimateRegressionParameters(); Assert.assertEquals(getNumberOfRegressors(), beta.length); } /** */ @Test public void canEstimateResiduals() { double[] e = regression.estimateResiduals(); Assert.assertEquals(getSampleSize(), e.length); } /** */ @Test public void canEstimateRegressionParametersVariance() { Matrix var = regression.estimateRegressionParametersVariance(); Assert.assertEquals(getNumberOfRegressors(), var.rowSize()); } /** */ @Test public void canEstimateRegressandVariance() { if (getSampleSize() > getNumberOfRegressors()) { double variance = regression.estimateRegressandVariance(); Assert.assertTrue(variance > 0.0); } } /** * Verifies that newSampleData methods consistently insert unitary columns * in design matrix. Confirms the fix for MATH-411. */ @Test public void testNewSample() { double[] design = new double[] { 1, 19, 22, 33, 2, 20, 30, 40, 3, 25, 35, 45, 4, 27, 37, 47 }; double[] y = new double[] {1, 2, 3, 4}; double[][] x = new double[][] { {19, 22, 33}, {20, 30, 40}, {25, 35, 45}, {27, 37, 47} }; AbstractMultipleLinearRegression regression = createRegression(); regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix()); Matrix flatX = regression.getX().copy(); Vector flatY = regression.getY().copy(); regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); regression.newYSampleData(new DenseLocalOnHeapVector(y)); Assert.assertEquals(flatX, regression.getX()); Assert.assertEquals(flatY, regression.getY()); // No intercept regression.setNoIntercept(true); regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix()); flatX = regression.getX().copy(); flatY = regression.getY().copy(); regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); regression.newYSampleData(new DenseLocalOnHeapVector(y)); Assert.assertEquals(flatX, regression.getX()); Assert.assertEquals(flatY, regression.getY()); } /** */ @Test(expected = NullArgumentException.class) public void testNewSampleNullData() { double[] data = null; createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix()); } /** */ @Test(expected = MathIllegalArgumentException.class) public void testNewSampleInvalidData() { double[] data = new double[] {1, 2, 3, 4}; createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix()); } /** */ @Test(expected = MathIllegalArgumentException.class) public void testNewSampleInsufficientData() { double[] data = new double[] {1, 2, 3, 4}; createRegression().newSampleData(data, 1, 3, new DenseLocalOnHeapMatrix()); } /** */ @Test(expected = NullArgumentException.class) public void testXSampleDataNull() { createRegression().newXSampleData(null); } /** */ @Test(expected = NullArgumentException.class) public void testYSampleDataNull() { createRegression().newYSampleData(null); } }