/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.commons.math4.stat.descriptive; import java.util.Locale; import org.apache.commons.math4.TestUtils; import org.apache.commons.math4.exception.DimensionMismatchException; import org.apache.commons.math4.exception.MathIllegalStateException; import org.apache.commons.math4.stat.descriptive.MultivariateSummaryStatistics; import org.apache.commons.math4.stat.descriptive.StorelessUnivariateStatistic; import org.apache.commons.math4.stat.descriptive.moment.Mean; import org.apache.commons.math4.util.FastMath; import org.junit.Test; import org.junit.Assert; /** * Test cases for the {@link MultivariateSummaryStatistics} class. * */ public class MultivariateSummaryStatisticsTest { protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) { return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected); } @Test public void testSetterInjection() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); u.setMeanImpl(new StorelessUnivariateStatistic[] { new sumMean(), new sumMean() }); u.addValue(new double[] { 1, 2 }); u.addValue(new double[] { 3, 4 }); Assert.assertEquals(4, u.getMean()[0], 1E-14); Assert.assertEquals(6, u.getMean()[1], 1E-14); u.clear(); u.addValue(new double[] { 1, 2 }); u.addValue(new double[] { 3, 4 }); Assert.assertEquals(4, u.getMean()[0], 1E-14); Assert.assertEquals(6, u.getMean()[1], 1E-14); u.clear(); u.setMeanImpl(new StorelessUnivariateStatistic[] { new Mean(), new Mean() }); // OK after clear u.addValue(new double[] { 1, 2 }); u.addValue(new double[] { 3, 4 }); Assert.assertEquals(2, u.getMean()[0], 1E-14); Assert.assertEquals(3, u.getMean()[1], 1E-14); Assert.assertEquals(2, u.getDimension()); } @Test public void testSetterIllegalState() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); u.addValue(new double[] { 1, 2 }); u.addValue(new double[] { 3, 4 }); try { u.setMeanImpl(new StorelessUnivariateStatistic[] { new sumMean(), new sumMean() }); Assert.fail("Expecting MathIllegalStateException"); } catch (MathIllegalStateException ex) { // expected } } @Test public void testToString() { MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true); stats.addValue(new double[] {1, 3}); stats.addValue(new double[] {2, 2}); stats.addValue(new double[] {3, 1}); Locale d = Locale.getDefault(); Locale.setDefault(Locale.US); final String suffix = System.getProperty("line.separator"); Assert.assertEquals("MultivariateSummaryStatistics:" + suffix+ "n: 3" +suffix+ "min: 1.0, 1.0" +suffix+ "max: 3.0, 3.0" +suffix+ "mean: 2.0, 2.0" +suffix+ "geometric mean: 1.817..., 1.817..." +suffix+ "sum of squares: 14.0, 14.0" +suffix+ "sum of logarithms: 1.791..., 1.791..." +suffix+ "standard deviation: 1.0, 1.0" +suffix+ "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}" +suffix, stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1...")); Locale.setDefault(d); } @Test public void testShuffledStatistics() { // the purpose of this test is only to check the get/set methods // we are aware shuffling statistics like this is really not // something sensible to do in production ... MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true); MultivariateSummaryStatistics shuffled = createMultivariateSummaryStatistics(2, true); StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl(); shuffled.setGeoMeanImpl(shuffled.getMeanImpl()); shuffled.setMeanImpl(shuffled.getMaxImpl()); shuffled.setMaxImpl(shuffled.getMinImpl()); shuffled.setMinImpl(shuffled.getSumImpl()); shuffled.setSumImpl(shuffled.getSumsqImpl()); shuffled.setSumsqImpl(shuffled.getSumLogImpl()); shuffled.setSumLogImpl(tmp); for (int i = 100; i > 0; --i) { reference.addValue(new double[] {i, i}); shuffled.addValue(new double[] {i, i}); } TestUtils.assertEquals(reference.getMean(), shuffled.getGeometricMean(), 1.0e-10); TestUtils.assertEquals(reference.getMax(), shuffled.getMean(), 1.0e-10); TestUtils.assertEquals(reference.getMin(), shuffled.getMax(), 1.0e-10); TestUtils.assertEquals(reference.getSum(), shuffled.getMin(), 1.0e-10); TestUtils.assertEquals(reference.getSumSq(), shuffled.getSum(), 1.0e-10); TestUtils.assertEquals(reference.getSumLog(), shuffled.getSumSq(), 1.0e-10); TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(), 1.0e-10); } /** * Bogus mean implementation to test setter injection. * Returns the sum instead of the mean. */ static class sumMean implements StorelessUnivariateStatistic { private double sum = 0; private long n = 0; @Override public double evaluate(double[] values, int begin, int length) { return 0; } @Override public double evaluate(double[] values) { return 0; } @Override public void clear() { sum = 0; n = 0; } @Override public long getN() { return n; } @Override public double getResult() { return sum; } @Override public void increment(double d) { sum += d; n++; } @Override public void incrementAll(double[] values, int start, int length) { } @Override public void incrementAll(double[] values) { } @Override public StorelessUnivariateStatistic copy() { return new sumMean(); } } @Test public void testDimension() { try { createMultivariateSummaryStatistics(2, true).addValue(new double[3]); Assert.fail("Expecting DimensionMismatchException"); } catch (DimensionMismatchException dme) { // expected behavior } } /** test stats */ @Test public void testStats() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); Assert.assertEquals(0, u.getN()); u.addValue(new double[] { 1, 2 }); u.addValue(new double[] { 2, 3 }); u.addValue(new double[] { 2, 3 }); u.addValue(new double[] { 3, 4 }); Assert.assertEquals( 4, u.getN()); Assert.assertEquals( 8, u.getSum()[0], 1.0e-10); Assert.assertEquals(12, u.getSum()[1], 1.0e-10); Assert.assertEquals(18, u.getSumSq()[0], 1.0e-10); Assert.assertEquals(38, u.getSumSq()[1], 1.0e-10); Assert.assertEquals( 1, u.getMin()[0], 1.0e-10); Assert.assertEquals( 2, u.getMin()[1], 1.0e-10); Assert.assertEquals( 3, u.getMax()[0], 1.0e-10); Assert.assertEquals( 4, u.getMax()[1], 1.0e-10); Assert.assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10); Assert.assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10); Assert.assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10); Assert.assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10); Assert.assertEquals( 2, u.getMean()[0], 1.0e-10); Assert.assertEquals( 3, u.getMean()[1], 1.0e-10); Assert.assertEquals(FastMath.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10); Assert.assertEquals(FastMath.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10); Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10); Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10); Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10); Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10); u.clear(); Assert.assertEquals(0, u.getN()); } @Test public void testN0andN1Conditions() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true); Assert.assertTrue(Double.isNaN(u.getMean()[0])); Assert.assertTrue(Double.isNaN(u.getStandardDeviation()[0])); /* n=1 */ u.addValue(new double[] { 1 }); Assert.assertEquals(1.0, u.getMean()[0], 1.0e-10); Assert.assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10); Assert.assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10); /* n=2 */ u.addValue(new double[] { 2 }); Assert.assertTrue(u.getStandardDeviation()[0] > 0); } @Test public void testNaNContracts() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true); Assert.assertTrue(Double.isNaN(u.getMean()[0])); Assert.assertTrue(Double.isNaN(u.getMin()[0])); Assert.assertTrue(Double.isNaN(u.getStandardDeviation()[0])); Assert.assertTrue(Double.isNaN(u.getGeometricMean()[0])); u.addValue(new double[] { 1.0 }); Assert.assertFalse(Double.isNaN(u.getMean()[0])); Assert.assertFalse(Double.isNaN(u.getMin()[0])); Assert.assertFalse(Double.isNaN(u.getStandardDeviation()[0])); Assert.assertFalse(Double.isNaN(u.getGeometricMean()[0])); } @Test public void testSerialization() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); // Empty test TestUtils.checkSerializedEquality(u); MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u); Assert.assertEquals(u, s); // Add some data u.addValue(new double[] { 2d, 1d }); u.addValue(new double[] { 1d, 1d }); u.addValue(new double[] { 3d, 1d }); u.addValue(new double[] { 4d, 1d }); u.addValue(new double[] { 5d, 1d }); // Test again TestUtils.checkSerializedEquality(u); s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u); Assert.assertEquals(u, s); } @Test public void testEqualsAndHashCode() { MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); MultivariateSummaryStatistics t = null; int emptyHash = u.hashCode(); Assert.assertTrue(u.equals(u)); Assert.assertFalse(u.equals(t)); Assert.assertFalse(u.equals(Double.valueOf(0))); t = createMultivariateSummaryStatistics(2, true); Assert.assertTrue(t.equals(u)); Assert.assertTrue(u.equals(t)); Assert.assertEquals(emptyHash, t.hashCode()); // Add some data to u u.addValue(new double[] { 2d, 1d }); u.addValue(new double[] { 1d, 1d }); u.addValue(new double[] { 3d, 1d }); u.addValue(new double[] { 4d, 1d }); u.addValue(new double[] { 5d, 1d }); Assert.assertFalse(t.equals(u)); Assert.assertFalse(u.equals(t)); Assert.assertTrue(u.hashCode() != t.hashCode()); //Add data in same order to t t.addValue(new double[] { 2d, 1d }); t.addValue(new double[] { 1d, 1d }); t.addValue(new double[] { 3d, 1d }); t.addValue(new double[] { 4d, 1d }); t.addValue(new double[] { 5d, 1d }); Assert.assertTrue(t.equals(u)); Assert.assertTrue(u.equals(t)); Assert.assertEquals(u.hashCode(), t.hashCode()); // Clear and make sure summaries are indistinguishable from empty summary u.clear(); t.clear(); Assert.assertTrue(t.equals(u)); Assert.assertTrue(u.equals(t)); Assert.assertEquals(emptyHash, t.hashCode()); Assert.assertEquals(emptyHash, u.hashCode()); } }