package edu.stanford.nlp.math; import junit.framework.TestCase; public class ArrayMathTest extends TestCase { private double[] d1 = new double[3]; private double[] d2 = new double[3]; private double[] d3 = new double[3]; private double[] d4 = new double[3]; private double[] d5 = new double[4]; @Override public void setUp() { d1[0] = 1.0; d1[1] = 343.33; d1[2] = -13.1; d2[0] = 1.0; d2[1] = 343.33; d2[2] = -13.1; d3[0] = Double.NaN; d3[1] = Double.POSITIVE_INFINITY ; d3[2] = 2; d4[0] = 0.1; d4[1] = 0.2; d4[2] = 0.3; d5[0] = 0.1; d5[1] = 0.2; d5[2] = 0.3; d5[3] = 0.8; } public void testInnerProduct() { double inner = ArrayMath.innerProduct(d4, d4); assertEquals("Wrong inner product", 0.14, inner, 1e-6); inner = ArrayMath.innerProduct(d5, d5); assertEquals("Wrong inner product", 0.78, inner, 1e-6); } public void testNumRows() { int nRows = ArrayMath.numRows(d1); assertEquals(nRows, 3); } public void testExpLog() { double[] d1prime = ArrayMath.log(ArrayMath.exp(d1)); double[] diff = ArrayMath.pairwiseSubtract(d1, d1prime); double norm2 = ArrayMath.norm(diff); assertTrue(norm2 < 1e-4); } public void testExpLogInplace() { ArrayMath.expInPlace(d1); ArrayMath.logInPlace(d1); ArrayMath.pairwiseSubtractInPlace(d1, d2); double norm2 = ArrayMath.norm(d1); assertTrue(norm2 < 1e-4); } public void testAddInPlace() { ArrayMath.addInPlace(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(d1[i]==d2[i]+3); } } public void testMultiplyInPlace() { ArrayMath.multiplyInPlace(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(d1[i]==d2[i]*3); } } public void testPowInPlace() { ArrayMath.powInPlace(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(d1[i]==Math.pow(d2[i],3)); } } public void testAdd() { double[] d1prime = ArrayMath.add(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1prime); i++) { assertTrue(d1prime[i]==d1[i]+3); } } public void testMultiply() { double[] d1prime = ArrayMath.multiply(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1prime); i++) { assertTrue(d1prime[i]==d1[i]*3); } } public void testPow() { double[] d1prime = ArrayMath.pow(d1, 3); for (int i = 0; i < ArrayMath.numRows(d1prime); i++) { assertTrue(d1prime[i]==Math.pow(d1[i],3)); } } public void testPairwiseAdd() { double[] sum = ArrayMath.pairwiseAdd(d1,d2); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(sum[i] == d1[i]+d2[i]); } } public void testPairwiseSubtract() { double[] diff = ArrayMath.pairwiseSubtract(d1,d2); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(diff[i] == d1[i]-d2[i]); } } public void testPairwiseMultiply() { double[] product = ArrayMath.pairwiseMultiply(d1,d2); for (int i = 0; i < ArrayMath.numRows(d1); i++) { assertTrue(product[i] == d1[i]*d2[i]); } } public void testHasNaN() { assertFalse(ArrayMath.hasNaN(d1)); assertFalse(ArrayMath.hasNaN(d2)); assertTrue(ArrayMath.hasNaN(d3)); } public void testHasInfinite() { assertFalse(ArrayMath.hasInfinite(d1)); assertFalse(ArrayMath.hasInfinite(d2)); assertTrue(ArrayMath.hasInfinite(d3)); } public void testCountNaN() { assertTrue(ArrayMath.countNaN(d1)==0); assertTrue(ArrayMath.countNaN(d2)==0); assertTrue(ArrayMath.countNaN(d3)==1); } public void testFliterNaN() { double[] f_d3 = ArrayMath.filterNaN(d3); assertTrue(ArrayMath.numRows(f_d3)==2); assertTrue(ArrayMath.countNaN(f_d3)==0); } public void testCountInfinite() { assertTrue(ArrayMath.countInfinite(d1)==0); assertTrue(ArrayMath.countInfinite(d2)==0); assertTrue(ArrayMath.countInfinite(d3)==1); } public void testFliterInfinite() { double[] f_d3 = ArrayMath.filterInfinite(d3); assertTrue(ArrayMath.numRows(f_d3)==2); assertTrue(ArrayMath.countInfinite(f_d3)==0); } public void testFliterNaNAndInfinite() { double[] f_d3 = ArrayMath.filterNaNAndInfinite(d3); assertTrue(ArrayMath.numRows(f_d3)==1); assertTrue(ArrayMath.countInfinite(f_d3)==0); assertTrue(ArrayMath.countNaN(f_d3)==0); } public void testSum() { double sum = ArrayMath.sum(d1); double mySum = 0.0; for (double d : d1) { mySum += d; } assertTrue(sum==mySum); } public void testNorm_inf() { double ninf = ArrayMath.norm_inf(d1); double max = ArrayMath.max(d1); assertTrue(ninf==max); ninf = ArrayMath.norm_inf(d2); max = ArrayMath.max(d2); assertTrue(ninf==max); ninf = ArrayMath.norm_inf(d3); max = ArrayMath.max(d3); assertTrue(ninf==max); } public void testArgmax() { assertTrue(ArrayMath.max(d1)==d1[ArrayMath.argmax(d1)]); assertTrue(ArrayMath.max(d2)==d2[ArrayMath.argmax(d2)]); assertTrue(ArrayMath.max(d3)==d3[ArrayMath.argmax(d3)]); } public void testArgmin() { assertTrue(ArrayMath.min(d1)==d1[ArrayMath.argmin(d1)]); assertTrue(ArrayMath.min(d2)==d2[ArrayMath.argmin(d2)]); assertTrue(ArrayMath.min(d3)==d3[ArrayMath.argmin(d3)]); } public void testLogSum() { double lsum = ArrayMath.logSum(d4); double myLsum = 0; for (double d : d4) { myLsum += Math.exp(d); } myLsum = Math.log(myLsum); assertTrue(myLsum == lsum); } public void testNormalize() { double tol = 1e-4; ArrayMath.normalize(d1); ArrayMath.normalize(d2); //ArrayMath.normalize(d3); ArrayMath.normalize(d4); assertTrue(ArrayMath.sum(d1)-1 < tol); assertTrue(ArrayMath.sum(d2)-1 < tol); //assertTrue(ArrayMath.sum(d3)-1 < tol); assertTrue(ArrayMath.sum(d4)-1 < tol); } public void testKLDivergence() { double kld = ArrayMath.klDivergence(d1, d2); assertTrue(kld==0); } public void testSumAndMean() { assertTrue(ArrayMath.sum(d1) == ArrayMath.mean(d1)*d1.length); assertTrue(ArrayMath.sum(d2) == ArrayMath.mean(d2)*d2.length); //assertTrue(ArrayMath.sum(d3) == ArrayMath.mean(d3)*d3.length); assertTrue(ArrayMath.sum(d4) == ArrayMath.mean(d4)*d4.length); } public static void helpTestSafeSumAndMean(double[] d) { double[] dprime = ArrayMath.filterNaNAndInfinite(d); assertTrue(ArrayMath.safeMean(d)*ArrayMath.numRows(dprime)==ArrayMath.sum(dprime)); } public void testSafeSumAndMean() { helpTestSafeSumAndMean(d1); helpTestSafeSumAndMean(d2); helpTestSafeSumAndMean(d3); helpTestSafeSumAndMean(d4); } public void testJensenShannon() { double[] a = { 0.1, 0.1, 0.7, 0.1, 0.0, 0.0 }; double[] b = { 0.0, 0.1, 0.1, 0.7, 0.1, 0.0 }; assertEquals(0.46514844544032313, ArrayMath.jensenShannonDivergence(a, b), 1e-5); double[] c = { 1.0, 0.0, 0.0 }; double[] d = { 0.0, 0.5, 0.5 }; assertEquals(1.0, ArrayMath.jensenShannonDivergence(c, d), 1e-5); } }