package org.nd4j.linalg.dimensionalityreduction; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created by rcorbish */ @RunWith(Parameterized.class) public class TestPCA extends BaseNd4jTest { public TestPCA(Nd4jBackend backend) { super(backend); } @Test public void testFactorDims() { int m = 13; int n = 4; double f[] = new double[] {7, 1, 11, 11, 7, 11, 3, 1, 2, 21, 1, 11, 10, 26, 29, 56, 31, 52, 55, 71, 31, 54, 47, 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, 34, 12, 12}; INDArray A = Nd4j.create(f, new int[] {m, n}, 'f'); INDArray A1 = A.dup('f'); INDArray Factor = org.nd4j.linalg.dimensionalityreduction.PCA.pca_factor(A1, 3, true); A1 = A.subiRowVector(A.mean(0)); INDArray Reduced = A1.mmul(Factor); INDArray Reconstructed = Reduced.mmul(Factor.transpose()); INDArray Diff = Reconstructed.sub(A1); for (int i = 0; i < m * n; i++) { assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff.getDouble(i), 1.0); } } @Test public void testFactorVariance() { int m = 13; int n = 4; double f[] = new double[] {7, 1, 11, 11, 7, 11, 3, 1, 2, 21, 1, 11, 10, 26, 29, 56, 31, 52, 55, 71, 31, 54, 47, 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, 34, 12, 12}; INDArray A = Nd4j.create(f, new int[] {m, n}, 'f'); INDArray A1 = A.dup('f'); INDArray Factor1 = org.nd4j.linalg.dimensionalityreduction.PCA.pca_factor(A1, 0.95, true); A1 = A.subiRowVector(A.mean(0)); INDArray Reduced1 = A1.mmul(Factor1); INDArray Reconstructed1 = Reduced1.mmul(Factor1.transpose()); INDArray Diff1 = Reconstructed1.sub(A1); for (int i = 0; i < m * n; i++) { assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff1.getDouble(i), 0.1); } INDArray A2 = A.dup('f'); INDArray Factor2 = org.nd4j.linalg.dimensionalityreduction.PCA.pca_factor(A2, 0.50, true); assertTrue("Variance differences should change factor sizes.", Factor1.columns() > Factor2.columns()); } @Override public char ordering() { return 'f'; } }