/* * File: KernelPrincipalComponentsAnalysisTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry Learning Core * * Copyright December 21, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. */ package gov.sandia.cognition.learning.algorithm.pca; import gov.sandia.cognition.learning.function.kernel.Kernel; import gov.sandia.cognition.learning.function.kernel.LinearKernel; import gov.sandia.cognition.learning.function.kernel.PolynomialKernel; import gov.sandia.cognition.math.MultivariateStatisticsUtil; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.math.matrix.decomposition.SingularValueDecomposition; import gov.sandia.cognition.math.matrix.mtj.DenseMatrix; import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ; import gov.sandia.cognition.math.matrix.mtj.Vector2; import gov.sandia.cognition.math.matrix.mtj.decomposition.SingularValueDecompositionMTJ; import java.util.ArrayList; import java.util.Random; import junit.framework.TestCase; /** * Unit tests for class KernelPrincipalComponentsAnalysis. * * @author Justin Basilico * @since 3.1 */ public class KernelPrincipalComponentsAnalysisTest extends TestCase { /** * Creates a new test. * * @param testName The test name. */ public KernelPrincipalComponentsAnalysisTest( String testName) { super(testName); } /** The random number generator for the tests. */ Random random = new Random(4); public int INPUT_DIM = random.nextInt(3) + 3; public int OUTPUT_DIM = random.nextInt(2) + 1; /** * Test of constructors of class KernelPrincipalComponentsAnalysis. */ public void testConstants() { assertTrue(KernelPrincipalComponentsAnalysis.DEFAULT_COMPONENT_COUNT > 0); assertTrue(KernelPrincipalComponentsAnalysis.DEFAULT_CENTER_DATA); } /** * Test of constructors of class KernelPrincipalComponentsAnalysis. */ public void testConstructors() { Kernel<? super Vector> kernel = null; int componentCount = KernelPrincipalComponentsAnalysis.DEFAULT_COMPONENT_COUNT; boolean centerData = KernelPrincipalComponentsAnalysis.DEFAULT_CENTER_DATA; KernelPrincipalComponentsAnalysis<Vector> instance = new KernelPrincipalComponentsAnalysis<Vector>(); assertSame(kernel, instance.getKernel()); assertEquals(componentCount, instance.getComponentCount()); kernel = new PolynomialKernel(4, 3.0); componentCount = random.nextInt(1000); instance = new KernelPrincipalComponentsAnalysis<Vector>(kernel, componentCount); assertSame(kernel, instance.getKernel()); assertEquals(componentCount, instance.getComponentCount()); } /** * Test of learn method, of class KernelPrincipalComponentsAnalysis. */ public void testLearn() { System.out.println("PCA.learn"); int num = random.nextInt(100) + 10; ArrayList<Vector> data = new ArrayList<Vector>(num); final double r1 = random.nextDouble(); final double r2 = r1 / random.nextDouble(); for (int i = 0; i < num; i++) { data.add(VectorFactory.getDefault().createUniformRandom(INPUT_DIM, r1, r2, random)); } System.out.println("Data:"); for (Vector item : data) { System.out.println(item.getElement(0) +" "+item.getElement(1)); } Vector mean = MultivariateStatisticsUtil.computeMean(data); System.out.println("Mean: " + mean); DenseMatrix X = DenseMatrixFactoryMTJ.INSTANCE.createMatrix(INPUT_DIM, num); for (int n = 0; n < num; n++) { X.setColumn(n, data.get(n).minus(mean)); } long startsvd = System.currentTimeMillis(); SingularValueDecomposition svd = SingularValueDecompositionMTJ.create(X); long stopsvd = System.currentTimeMillis(); long startpca = System.currentTimeMillis(); PrincipalComponentsAnalysis pca = new ThinSingularValueDecomposition(OUTPUT_DIM); PrincipalComponentsAnalysisFunction fpca = pca.learn(data); long stoppca = System.currentTimeMillis(); long startkpca = System.currentTimeMillis(); LinearKernel kernel = new LinearKernel(); KernelPrincipalComponentsAnalysis<Vector> instance = new KernelPrincipalComponentsAnalysis<Vector>(kernel, OUTPUT_DIM); KernelPrincipalComponentsAnalysis.Function<Vector> f = instance.learn(data); long stopkpca = System.currentTimeMillis(); System.out.println("Uhat:\n" + f.getComponents().transpose()); System.out.println("U:\n" + svd.getU()); System.out.println("Time taken: SVD = " + (stopsvd - startsvd) + ", PCA = " + (stoppca - startpca) + ", KPCA = " + (stopkpca - startkpca)); assertEquals(OUTPUT_DIM, instance.getComponentCount()); assertEquals(instance.getComponentCount(), f.getOutputDimensionality()); // The mean should project to zero. Vector zeros = VectorFactory.getDefault().createVector(OUTPUT_DIM); if (!zeros.equals(f.evaluate(mean), 1e-5)) { assertEquals(zeros, f.evaluate(mean)); } for (int i = 0; i < OUTPUT_DIM; i++) { Vector alphaI = f.getComponents().getRow(i); for (int j = 0; j < i; j++) { Vector alphaJ = f.getComponents().getRow(j); assertEquals( "Dot product between " + i + " and " + j + " is too large!", 0.0, alphaI.dotProduct( alphaJ ), 1e-2 ); } assertTrue(alphaI.norm2() > 0.0); } } /** * Test of the learn method on a pre-specified set of data. */ public void testLearnLinearExample() { ArrayList<Vector> data = new ArrayList<Vector>(); data.add(new Vector2(2.5, 2.4)); data.add(new Vector2(0.5, 0.7)); data.add(new Vector2(2.2, 2.9)); data.add(new Vector2(1.9, 2.2)); data.add(new Vector2(3.1, 3.0)); data.add(new Vector2(2.3, 2.7)); data.add(new Vector2(2, 1.6)); data.add(new Vector2(1, 1.1)); data.add(new Vector2(1.5, 1.6)); data.add(new Vector2(1.1, 0.9)); System.out.println("Data:"); for (Vector item : data) { System.out.println(item.getElement(0) +" "+item.getElement(1)); } System.out.println(); KernelPrincipalComponentsAnalysis<Vector> kpca = new KernelPrincipalComponentsAnalysis<Vector>(new LinearKernel(), 2); KernelPrincipalComponentsAnalysis.Function<Vector> fkpca = kpca.learn(data); // The mean should project to zero. final Vector zeros = new Vector2(); final Vector mean = MultivariateStatisticsUtil.computeMean(data); if (!zeros.equals(fkpca.evaluate(mean), 1e-5)) { assertEquals(zeros, fkpca.evaluate(mean)); } for (int i = 0; i < 2; i++) { Vector alphaI = fkpca.getComponents().getRow(i); for (int j = 0; j < i; j++) { Vector alphaJ = fkpca.getComponents().getRow(j); assertEquals( "Dot product between " + i + " and " + j + " is too large!", 0.0, alphaI.dotProduct( alphaJ ), 1e-2 ); } assertTrue(alphaI.norm2() > 0.0); } System.out.println("KPCA:"); for (Vector v : data) { System.out.println(fkpca.evaluate(v)); } ThinSingularValueDecomposition pca = new ThinSingularValueDecomposition(2); PrincipalComponentsAnalysisFunction fpca = pca.learn(data); System.out.println("PCA:"); for (Vector v : data) { System.out.println(fpca.evaluate(v)); } } /** * Test of the learn method using a toy example. */ public void testLearnToyExample() { final Vector2[] clusters = new Vector2[] { new Vector2(-0.5, -0.2), new Vector2(0.0, 0.6), new Vector2(0.5, 0.0) }; ArrayList<Vector> data = new ArrayList<Vector>(); int clusterSize = 30; double noise = 0.1; for (Vector2 cluster : clusters) { for (int i = 0; i < clusterSize; i++) { data.add(VectorFactory.getDenseDefault().copyValues( cluster.getX() + random.nextGaussian() * noise, cluster.getY() + random.nextGaussian() * noise)); } } System.out.println("Data:"); for (Vector item : data) { System.out.println(item); } System.out.println(); KernelPrincipalComponentsAnalysis<Vector> kpca = new KernelPrincipalComponentsAnalysis<Vector>(new LinearKernel(), 2); KernelPrincipalComponentsAnalysis.Function<Vector> fkpca = kpca.learn(data); // The mean should project to zero. final Vector zeros = new Vector2(); final Vector mean = MultivariateStatisticsUtil.computeMean(data); if (!zeros.equals(fkpca.evaluate(mean), 1e-5)) { assertEquals(zeros, fkpca.evaluate(mean)); } for (int i = 0; i < 2; i++) { Vector alphaI = fkpca.getComponents().getRow(i); for (int j = 0; j < i; j++) { Vector alphaJ = fkpca.getComponents().getRow(j); assertEquals( "Dot product between " + i + " and " + j + " is too large!", 0.0, alphaI.dotProduct( alphaJ ), 1e-2 ); } assertTrue(alphaI.norm2() > 0.0); } System.out.println("KPCA:"); for (Vector v : data) { System.out.println(fkpca.evaluate(v)); } ThinSingularValueDecomposition pca = new ThinSingularValueDecomposition(2); PrincipalComponentsAnalysisFunction fpca = pca.learn(data); System.out.println("PCA:"); for (Vector v : data) { System.out.println(fpca.evaluate(v)); } } /** * Test of getComponentCount method, of class KernelPrincipalComponentsAnalysis. */ public void testGetComponentCount() { this.testSetComponentCount(); } /** * Test of setComponentCount method, of class KernelPrincipalComponentsAnalysis. */ public void testSetComponentCount() { int componentCount = KernelPrincipalComponentsAnalysis.DEFAULT_COMPONENT_COUNT; KernelPrincipalComponentsAnalysis<String> instance = new KernelPrincipalComponentsAnalysis<String>(); assertEquals(componentCount, instance.getComponentCount()); componentCount = 1; instance.setComponentCount(componentCount); assertEquals(componentCount, instance.getComponentCount()); componentCount = 5; instance.setComponentCount(componentCount); assertEquals(componentCount, instance.getComponentCount()); boolean exceptionThrown = false; try { instance.setComponentCount(0); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } assertEquals(componentCount, instance.getComponentCount()); exceptionThrown = false; try { instance.setComponentCount(-1); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } assertEquals(componentCount, instance.getComponentCount()); } /** * Test of isCenterData method, of class KernelPrincipalComponentsAnalysis. */ public void testIsCenterData() { this.testSetCenterData(); } /** * Test of setCenterData method, of class KernelPrincipalComponentsAnalysis. */ public void testSetCenterData() { boolean centerData = KernelPrincipalComponentsAnalysis.DEFAULT_CENTER_DATA; KernelPrincipalComponentsAnalysis<?> instance = new KernelPrincipalComponentsAnalysis<Vector>(); assertEquals(centerData, instance.isCenterData()); centerData = false; instance.setCenterData(centerData); assertEquals(centerData, instance.isCenterData()); centerData = true; instance.setCenterData(centerData); assertEquals(centerData, instance.isCenterData()); centerData = false; instance.setCenterData(centerData); assertEquals(centerData, instance.isCenterData()); } }