/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.projection;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
import smile.math.Math;
/**
*
* @author Haifeng Li
*/
@SuppressWarnings("unused")
public class GHATest {
double[][] USArrests = {
// Murder Assault UrbanPop Rape
{13.2, 236, 58, 21.2}, // Alabama
{10.0, 263, 48, 44.5}, // Alaska
{8.1, 294, 80, 31.0}, // Arizona
{8.8, 190, 50, 19.5}, // Arkansas
{9.0, 276, 91, 40.6}, // California
{7.9, 204, 78, 38.7}, // Colorado
{3.3, 110, 77, 11.1}, // Connecticut
{5.9, 238, 72, 15.8}, // Delaware
{15.4, 335, 80, 31.9}, // Florida
{17.4, 211, 60, 25.8}, // Georgia
{5.3, 46, 83, 20.2}, // Hawaii
{2.6, 120, 54, 14.2}, // Idaho
{10.4, 249, 83, 24.0}, // Illinois
{7.2, 113, 65, 21.0}, // Indiana
{2.2, 56, 57, 11.3}, // Iowa
{6.0, 115, 66, 18.0}, // Kansas
{9.7, 109, 52, 16.3}, // Kentucky
{15.4, 249, 66, 22.2}, // Louisiana
{2.1, 83, 51, 7.8}, // Maine
{11.3, 300, 67, 27.8}, // Maryland
{4.4, 149, 85, 16.3}, // Massachusetts
{12.1, 255, 74, 35.1}, // Michigan
{2.7, 72, 66, 14.9}, // Michigan
{16.1, 259, 44, 17.1}, // Mississippi
{9.0, 178, 70, 28.2}, // Missouri
{6.0, 109, 53, 16.4}, // Montana
{4.3, 102, 62, 16.5}, // Nebraska
{12.2, 252, 81, 46.0}, // Nevada
{2.1, 57, 56, 9.5}, // New Hampshire
{7.4, 159, 89, 18.8}, // New Jersey
{11.4, 285, 70, 32.1}, // New Mexico
{11.1, 254, 86, 26.1}, // New York
{13.0, 337, 45, 16.1}, // North Carolina
{0.8, 45, 44, 7.3}, // North Dakota
{7.3, 120, 75, 21.4}, // Ohio
{6.6, 151, 68, 20.0}, // Oklahoma
{4.9, 159, 67, 29.3}, // Oregon
{6.3, 106, 72, 14.9}, // Pennsylvania
{3.4, 174, 87, 8.3}, // Rhode Island
{14.4, 279, 48, 22.5}, // South Carolina
{3.8, 86, 45, 12.8}, // South Dakota
{13.2, 188, 59, 26.9}, // Tennessee
{12.7, 201, 80, 25.5}, // Texas
{3.2, 120, 80, 22.9}, // Utah
{2.2, 48, 32, 11.2}, // Vermont
{8.5, 156, 63, 20.7}, // Virginia
{4.0, 145, 73, 26.2}, // Washington
{5.7, 81, 39, 9.3}, // West Virginia
{2.6, 53, 66, 10.8}, // Wisconsin
{6.8, 161, 60, 15.6} // Wyoming
};
double[][] loadings = {
{-0.0417043206282872, -0.0448216562696701, -0.0798906594208108, -0.994921731246978},
{-0.995221281426497, -0.058760027857223, 0.0675697350838043, 0.0389382976351601},
{-0.0463357461197108, 0.97685747990989, 0.200546287353866, -0.0581691430589319},
{-0.075155500585547, 0.200718066450337, -0.974080592182491, 0.0723250196376097}
};
double[] eigenvalues = {7011.114851, 201.992366, 42.112651, 6.164246};
public GHATest() {
}
@BeforeClass
public static void setUpClass() throws Exception {
}
@AfterClass
public static void tearDownClass() throws Exception {
}
@Before
public void setUp() {
}
@After
public void tearDown() {
}
/**
* Test of learn method, of class GHA.
*/
@Test
public void testLearn() {
System.out.println("learn");
int k = 3;
double[] mu = Math.colMean(USArrests);
double[][] cov = Math.cov(USArrests);
for (int i = 0; i < USArrests.length; i++) {
Math.minus(USArrests[i], mu);
}
double r = 0.00001;
GHA gha = new GHA(4, k, r);
for (int iter = 1, t = 0; iter <= 1000; iter++) {
double error = 0.0;
for (int i = 0; i < USArrests.length; i++, t++) {
error += gha.learn(USArrests[i]);
}
error /= USArrests.length;
if (iter % 100 == 0) {
System.out.format("Iter %3d, Error = %.5g%n", iter, error);
}
}
double[][] p = gha.getProjection();
double[][] t = Math.atamm(p);
for (int i = 0; i < t.length; i++) {
for (int j = 0; j < t[i].length; j++) {
System.out.format("% .4f ", t[i][j]);
}
System.out.println();
}
double[][] s = Math.abtmm(Math.abmm(p, cov), p);
double[] ev = new double[k];
System.out.println("Relative error of eigenvalues:");
for (int i = 0; i < k; i++) {
ev[i] = Math.abs(eigenvalues[i] - s[i][i]) / eigenvalues[i];
System.out.format("%.4f ", ev[i]);
}
System.out.println();
for (int i = 0; i < k; i++) {
assertTrue(ev[i] < 0.1);
}
double[][] lt = Math.transpose(loadings);
double[] evdot = new double[k];
System.out.println("Dot products of learned eigenvectors to true eigenvectors:");
for (int i = 0; i < k; i++) {
evdot[i] = Math.dot(lt[i], p[i]);
System.out.format("%.4f ", evdot[i]);
}
System.out.println();
for (int i = 0; i < k; i++) {
assertTrue(Math.abs(1.0-Math.abs(evdot[i])) < 0.1);
}
}
}