/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.commons.math4.stat.descriptive; import org.apache.commons.math4.TestUtils; import org.apache.commons.math4.exception.MathIllegalStateException; import org.apache.commons.math4.stat.descriptive.StatisticalSummary; import org.apache.commons.math4.stat.descriptive.SummaryStatistics; import org.apache.commons.math4.stat.descriptive.moment.GeometricMean; import org.apache.commons.math4.stat.descriptive.moment.Mean; import org.apache.commons.math4.stat.descriptive.moment.Variance; import org.apache.commons.math4.stat.descriptive.summary.Sum; import org.apache.commons.math4.util.FastMath; import org.junit.Assert; import org.junit.Test; /** * Test cases for the {@link SummaryStatistics} class. */ public class SummaryStatisticsTest { private final double one = 1; private final float twoF = 2; private final long twoL = 2; private final int three = 3; private final double mean = 2; private final double sumSq = 18; private final double sum = 8; private final double var = 0.666666666666666666667; private final double popVar = 0.5; private final double std = FastMath.sqrt(var); private final double n = 4; private final double min = 1; private final double max = 3; private final double tolerance = 10E-15; protected SummaryStatistics createSummaryStatistics() { return new SummaryStatistics(); } /** test stats */ @Test public void testStats() { SummaryStatistics u = createSummaryStatistics(); Assert.assertEquals("total count",0,u.getN(),tolerance); u.addValue(one); u.addValue(twoF); u.addValue(twoL); u.addValue(three); Assert.assertEquals("N",n,u.getN(),tolerance); Assert.assertEquals("sum",sum,u.getSum(),tolerance); Assert.assertEquals("sumsq",sumSq,u.getSumsq(),tolerance); Assert.assertEquals("var",var,u.getVariance(),tolerance); Assert.assertEquals("population var",popVar,u.getPopulationVariance(),tolerance); Assert.assertEquals("std",std,u.getStandardDeviation(),tolerance); Assert.assertEquals("mean",mean,u.getMean(),tolerance); Assert.assertEquals("min",min,u.getMin(),tolerance); Assert.assertEquals("max",max,u.getMax(),tolerance); u.clear(); Assert.assertEquals("total count",0,u.getN(),tolerance); } @Test public void testN0andN1Conditions() { SummaryStatistics u = createSummaryStatistics(); Assert.assertTrue("Mean of n = 0 set should be NaN", Double.isNaN( u.getMean() ) ); Assert.assertTrue("Standard Deviation of n = 0 set should be NaN", Double.isNaN( u.getStandardDeviation() ) ); Assert.assertTrue("Variance of n = 0 set should be NaN", Double.isNaN(u.getVariance() ) ); /* n=1 */ u.addValue(one); Assert.assertTrue("mean should be one (n = 1)", u.getMean() == one); Assert.assertTrue("geometric should be one (n = 1) instead it is " + u.getGeometricMean(), u.getGeometricMean() == one); Assert.assertTrue("Std should be zero (n = 1)", u.getStandardDeviation() == 0.0); Assert.assertTrue("variance should be zero (n = 1)", u.getVariance() == 0.0); /* n=2 */ u.addValue(twoF); Assert.assertTrue("Std should not be zero (n = 2)", u.getStandardDeviation() != 0.0); Assert.assertTrue("variance should not be zero (n = 2)", u.getVariance() != 0.0); } @Test public void testProductAndGeometricMean() { SummaryStatistics u = createSummaryStatistics(); u.addValue( 1.0 ); u.addValue( 2.0 ); u.addValue( 3.0 ); u.addValue( 4.0 ); Assert.assertEquals( "Geometric mean not expected", 2.213364, u.getGeometricMean(), 0.00001 ); } @Test public void testNaNContracts() { SummaryStatistics u = createSummaryStatistics(); Assert.assertTrue("mean not NaN",Double.isNaN(u.getMean())); Assert.assertTrue("min not NaN",Double.isNaN(u.getMin())); Assert.assertTrue("std dev not NaN",Double.isNaN(u.getStandardDeviation())); Assert.assertTrue("var not NaN",Double.isNaN(u.getVariance())); Assert.assertTrue("geom mean not NaN",Double.isNaN(u.getGeometricMean())); u.addValue(1.0); Assert.assertEquals( "mean not expected", 1.0, u.getMean(), Double.MIN_VALUE); Assert.assertEquals( "variance not expected", 0.0, u.getVariance(), Double.MIN_VALUE); Assert.assertEquals( "geometric mean not expected", 1.0, u.getGeometricMean(), Double.MIN_VALUE); u.addValue(-1.0); Assert.assertTrue("geom mean not NaN",Double.isNaN(u.getGeometricMean())); u.addValue(0.0); Assert.assertTrue("geom mean not NaN",Double.isNaN(u.getGeometricMean())); //FiXME: test all other NaN contract specs } @Test public void testGetSummary() { SummaryStatistics u = createSummaryStatistics(); StatisticalSummary summary = u.getSummary(); verifySummary(u, summary); u.addValue(1d); summary = u.getSummary(); verifySummary(u, summary); u.addValue(2d); summary = u.getSummary(); verifySummary(u, summary); u.addValue(2d); summary = u.getSummary(); verifySummary(u, summary); } @Test public void testSerialization() { SummaryStatistics u = createSummaryStatistics(); // Empty test TestUtils.checkSerializedEquality(u); SummaryStatistics s = (SummaryStatistics) TestUtils.serializeAndRecover(u); StatisticalSummary summary = s.getSummary(); verifySummary(u, summary); // Add some data u.addValue(2d); u.addValue(1d); u.addValue(3d); u.addValue(4d); u.addValue(5d); // Test again TestUtils.checkSerializedEquality(u); s = (SummaryStatistics) TestUtils.serializeAndRecover(u); summary = s.getSummary(); verifySummary(u, summary); } @Test public void testEqualsAndHashCode() { SummaryStatistics u = createSummaryStatistics(); SummaryStatistics t = null; int emptyHash = u.hashCode(); Assert.assertTrue("reflexive", u.equals(u)); Assert.assertFalse("non-null compared to null", u.equals(t)); Assert.assertFalse("wrong type", u.equals(Double.valueOf(0))); t = createSummaryStatistics(); Assert.assertTrue("empty instances should be equal", t.equals(u)); Assert.assertTrue("empty instances should be equal", u.equals(t)); Assert.assertEquals("empty hash code", emptyHash, t.hashCode()); // Add some data to u u.addValue(2d); u.addValue(1d); u.addValue(3d); u.addValue(4d); Assert.assertFalse("different n's should make instances not equal", t.equals(u)); Assert.assertFalse("different n's should make instances not equal", u.equals(t)); Assert.assertTrue("different n's should make hashcodes different", u.hashCode() != t.hashCode()); //Add data in same order to t t.addValue(2d); t.addValue(1d); t.addValue(3d); t.addValue(4d); Assert.assertTrue("summaries based on same data should be equal", t.equals(u)); Assert.assertTrue("summaries based on same data should be equal", u.equals(t)); Assert.assertEquals("summaries based on same data should have same hashcodes", u.hashCode(), t.hashCode()); // Clear and make sure summaries are indistinguishable from empty summary u.clear(); t.clear(); Assert.assertTrue("empty instances should be equal", t.equals(u)); Assert.assertTrue("empty instances should be equal", u.equals(t)); Assert.assertEquals("empty hash code", emptyHash, t.hashCode()); Assert.assertEquals("empty hash code", emptyHash, u.hashCode()); } @Test public void testCopy() { SummaryStatistics u = createSummaryStatistics(); u.addValue(2d); u.addValue(1d); u.addValue(3d); u.addValue(4d); SummaryStatistics v = new SummaryStatistics(u); Assert.assertEquals(u, v); Assert.assertEquals(v, u); // Make sure both behave the same with additional values added u.addValue(7d); u.addValue(9d); u.addValue(11d); u.addValue(23d); v.addValue(7d); v.addValue(9d); v.addValue(11d); v.addValue(23d); Assert.assertEquals(u, v); Assert.assertEquals(v, u); // Check implementation pointers are preserved u.clear(); u.setSumImpl(new Sum()); SummaryStatistics.copy(u,v); Assert.assertEquals(u.getSumImpl(), v.getSumImpl()); } private void verifySummary(SummaryStatistics u, StatisticalSummary s) { Assert.assertEquals("N",s.getN(),u.getN()); TestUtils.assertEquals("sum",s.getSum(),u.getSum(),tolerance); TestUtils.assertEquals("var",s.getVariance(),u.getVariance(),tolerance); TestUtils.assertEquals("std",s.getStandardDeviation(),u.getStandardDeviation(),tolerance); TestUtils.assertEquals("mean",s.getMean(),u.getMean(),tolerance); TestUtils.assertEquals("min",s.getMin(),u.getMin(),tolerance); TestUtils.assertEquals("max",s.getMax(),u.getMax(),tolerance); } @Test public void testSetterInjection() { SummaryStatistics u = createSummaryStatistics(); u.setMeanImpl(new Sum()); u.setSumLogImpl(new Sum()); u.addValue(1); u.addValue(3); Assert.assertEquals(4, u.getMean(), 1E-14); Assert.assertEquals(4, u.getSumOfLogs(), 1E-14); Assert.assertEquals(FastMath.exp(2), u.getGeometricMean(), 1E-14); u.clear(); u.addValue(1); u.addValue(2); Assert.assertEquals(3, u.getMean(), 1E-14); u.clear(); u.setMeanImpl(new Mean()); // OK after clear } @Test public void testSetterIllegalState() { SummaryStatistics u = createSummaryStatistics(); u.addValue(1); u.addValue(3); try { u.setMeanImpl(new Sum()); Assert.fail("Expecting MathIllegalStateException"); } catch (MathIllegalStateException ex) { // expected } } @Test public void testQuadraticMean() { final double[] values = { 1.2, 3.4, 5.6, 7.89 }; final SummaryStatistics stats = createSummaryStatistics(); final int len = values.length; double expected = 0; for (int i = 0; i < len; i++) { final double v = values[i]; expected += v * v / len; stats.addValue(v); } expected = Math.sqrt(expected); Assert.assertEquals(expected, stats.getQuadraticMean(), Math.ulp(expected)); } /** * JIRA: MATH-691 */ @Test public void testOverrideVarianceWithMathClass() { double[] scores = {1, 2, 3, 4}; SummaryStatistics stats = new SummaryStatistics(); stats.setVarianceImpl(new Variance(false)); //use "population variance" for(double i : scores) { stats.addValue(i); } Assert.assertEquals((new Variance(false)).evaluate(scores),stats.getVariance(), 0); } @Test public void testOverrideMeanWithMathClass() { double[] scores = {1, 2, 3, 4}; SummaryStatistics stats = new SummaryStatistics(); stats.setMeanImpl(new Mean()); for(double i : scores) { stats.addValue(i); } Assert.assertEquals((new Mean()).evaluate(scores),stats.getMean(), 0); } @Test public void testOverrideGeoMeanWithMathClass() { double[] scores = {1, 2, 3, 4}; SummaryStatistics stats = new SummaryStatistics(); stats.setGeoMeanImpl(new GeometricMean()); for(double i : scores) { stats.addValue(i); } Assert.assertEquals((new GeometricMean()).evaluate(scores),stats.getGeometricMean(), 0); } @Test public void testToString() { SummaryStatistics u = createSummaryStatistics(); for (int i = 0; i < 5; i++) { u.addValue(i); } final String[] labels = {"min", "max", "sum", "geometric mean", "variance", "population variance", "second moment", "sum of squares", "standard deviation", "sum of logs"}; final double[] values = {u.getMin(), u.getMax(), u.getSum(), u.getGeometricMean(), u.getVariance(), u.getPopulationVariance(), u.getSecondMoment(), u.getSumsq(), u.getStandardDeviation(), u.getSumOfLogs()}; final String toString = u.toString(); Assert.assertTrue(toString.indexOf("n: " + u.getN()) > 0); // getN() returns a long for (int i = 0; i < values.length; i++) { Assert.assertTrue(toString.indexOf(labels[i] + ": " + String.valueOf(values[i])) > 0); } } }