/*
 * File:                MultivariateLinearRegressionTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Jun 22, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.VectorFactory;
import java.util.LinkedList;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.Collection;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;

/**
 * Tests for class MultivariateLinearRegression.
 * @author krdixon
 */
public class MultivariateLinearRegressionTest
{

    /**
     * Random number generator to use for a fixed random seed.
     */
    public final Random RANDOM = new Random( 1 );

    /**
     * Default tolerance of the regression tests, {@value}.
     */
    public final double TOLERANCE = 1e-5;

    /**
     * Default number of samples to test against, {@value}.
     */
    public final int NUM_SAMPLES = 1000;

    /**
     * Default Constructor
     */
    public MultivariateLinearRegressionTest()
    {
    }

    /**
     * Tests the constructors of class MultivariateLinearRegression.
     */
    @Test
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        MultivariateLinearRegression instance = new MultivariateLinearRegression();
        assertTrue( instance.getUsePseudoInverse() );
    }

    /**
     * Tests the clone method of class MultivariateLinearRegression.
     */
    @Test
    public void testClone()
    {
        System.out.println( "Clone" );

        MultivariateLinearRegression instance = new MultivariateLinearRegression();
        instance.setUsePseudoInverse( false );
        MultivariateLinearRegression clone = instance.clone();
        assertNotSame( instance, clone );
        assertEquals( instance.getUsePseudoInverse(), clone.getUsePseudoInverse() );
    }

    /**
     * Tests the learn method of class MultivariateLinearRegression.
     */
    @Test
    public void testLearn()
    {
        System.out.println( "learn" );

        // Build a random linear map with bias to serve as the target function.
        int M = RANDOM.nextInt( 5 ) + 1;
        int N = RANDOM.nextInt( 5 ) + 1;
        double r = 1.0;
        Matrix A = MatrixFactory.getDefault().createUniformRandom( M, N, -r, r, RANDOM );
        Vector bias = VectorFactory.getDefault().createUniformRandom( M, -r, r, RANDOM );
        MultivariateDiscriminantWithBias f = new MultivariateDiscriminantWithBias( A, bias );

        // Generate noiseless input/output samples from the target function.
        int num = RANDOM.nextInt( 100 ) + (M * N);
        Collection<InputOutputPair<Vector, Vector>> dataset =
            new LinkedList<InputOutputPair<Vector, Vector>>();
        for (int i = 0; i < num; i++)
        {
            Vector input = VectorFactory.getDefault().createUniformRandom( N, -r, r, RANDOM );
            Vector output = f.evaluate( input );
            dataset.add( new DefaultInputOutputPair<Vector, Vector>( input, output ) );
        }

        MultivariateLinearRegression learner = new MultivariateLinearRegression();

        // The pseudoinverse solver should recover the discriminant matrix exactly.
        learner.setUsePseudoInverse( true );
        MultivariateDiscriminantWithBias fhat = learner.learn( dataset );
        System.out.println( "fhat: " + fhat.convertToVector() );
        System.out.println( "f:    " + f.convertToVector() );
        assertTrue( A.equals( fhat.getDiscriminant(), 1e-5 ) );

        // The direct (non-pseudoinverse) solver should recover the same parameters.
        learner.setUsePseudoInverse( false );
        fhat = learner.learn( dataset );
        Vector p1 = fhat.convertToVector();
        assertTrue( A.equals( fhat.getDiscriminant(), 1e-5 ) );
        System.out.println( "p1: " + p1.norm2() );

        // Increasing the regularization should shrink the parameter norm.
        learner.setRegularization( 0.1 );
        fhat = learner.learn( dataset );
        Vector p2 = fhat.convertToVector();
        System.out.println( "p2: " + p2.norm2() );
        assertTrue( p1.norm2() > p2.norm2() );

        learner.setRegularization( 1.0 );
        fhat = learner.learn( dataset );
        Vector p3 = fhat.convertToVector();
        System.out.println( "p3: " + p3.norm2() );
        assertTrue( p2.norm2() > p3.norm2() );
    }

    /**
     * Tests the learn method of class MultivariateLinearRegression with
     * weighted data.
     */
    @Test
    public void testWeightedLearn()
    {
        System.out.println( "weightedLearn" );

        // Build a random linear map (no bias) to serve as the target function.
        int M = RANDOM.nextInt( 5 ) + 1;
        int N = RANDOM.nextInt( 5 ) + 1;
        double r = 1.0;
        Matrix A = MatrixFactory.getDefault().createUniformRandom( M, N, -r, r, RANDOM );
        MultivariateDiscriminant f = new MultivariateDiscriminant( A );

        // Generate noiseless samples, each with a random weight.
        int num = RANDOM.nextInt( 100 ) + (M * N);
        Collection<InputOutputPair<Vector, Vector>> dataset =
            new LinkedList<InputOutputPair<Vector, Vector>>();
        for (int i = 0; i < num; i++)
        {
            double weight = RANDOM.nextDouble();
            Vector input = VectorFactory.getDefault().createUniformRandom( N, -r, r, RANDOM );
            Vector output = f.evaluate( input );
            dataset.add( DefaultWeightedInputOutputPair.create( input, output, weight ) );
        }

        // Weights should not affect the solution for noiseless, consistent data.
        MultivariateLinearRegression learner = new MultivariateLinearRegression();
        MultivariateDiscriminant fhat = learner.learn( dataset );
        assertTrue( A.equals( fhat.getDiscriminant(), 1e-5 ) );
    }

}