/* * Copyright (C) 2012 The Guava Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.common.math; import static com.google.common.math.StatsTesting.ALLOWED_ERROR; import static com.google.common.math.StatsTesting.ALL_MANY_VALUES; import static com.google.common.math.StatsTesting.ALL_PAIRED_STATS; import static com.google.common.math.StatsTesting.CONSTANT_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.DUPLICATE_MANY_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.EMPTY_PAIRED_STATS; import static com.google.common.math.StatsTesting.EMPTY_STATS_ITERABLE; import static com.google.common.math.StatsTesting.HORIZONTAL_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.MANY_VALUES; import static com.google.common.math.StatsTesting.MANY_VALUES_COUNT; import static com.google.common.math.StatsTesting.MANY_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.MANY_VALUES_STATS_ITERABLE; import static com.google.common.math.StatsTesting.MANY_VALUES_STATS_VARARGS; import static com.google.common.math.StatsTesting.MANY_VALUES_SUM_OF_PRODUCTS_OF_DELTAS; import static com.google.common.math.StatsTesting.ONE_VALUE_PAIRED_STATS; import static com.google.common.math.StatsTesting.ONE_VALUE_STATS; import static com.google.common.math.StatsTesting.OTHER_MANY_VALUES; import static com.google.common.math.StatsTesting.OTHER_MANY_VALUES_STATS; import static com.google.common.math.StatsTesting.OTHER_ONE_VALUE_STATS; import static com.google.common.math.StatsTesting.OTHER_TWO_VALUES_STATS; import static com.google.common.math.StatsTesting.TWO_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.TWO_VALUES_STATS; import static com.google.common.math.StatsTesting.TWO_VALUES_SUM_OF_PRODUCTS_OF_DELTAS; import static com.google.common.math.StatsTesting.VERTICAL_VALUES_PAIRED_STATS; import static com.google.common.math.StatsTesting.assertDiagonalLinearTransformation; import static com.google.common.math.StatsTesting.assertHorizontalLinearTransformation; import static com.google.common.math.StatsTesting.assertLinearTransformationNaN; import static com.google.common.math.StatsTesting.assertStatsApproxEqual; import static com.google.common.math.StatsTesting.assertVerticalLinearTransformation; import static com.google.common.math.StatsTesting.createPairedStatsOf; import static com.google.common.truth.Truth.assertThat; import com.google.common.collect.ImmutableList; import com.google.common.math.StatsTesting.ManyValues; import com.google.common.testing.EqualsTester; import com.google.common.testing.SerializableTester; import java.nio.ByteBuffer; import java.nio.ByteOrder; import junit.framework.TestCase; /** * Tests for {@link PairedStats}. This tests instances created by * {@link PairedStatsAccumulator#snapshot}. * * @author Pete Gillin */ public class PairedStatsTest extends TestCase { public void testCount() { assertThat(EMPTY_PAIRED_STATS.count()).isEqualTo(0); assertThat(ONE_VALUE_PAIRED_STATS.count()).isEqualTo(1); assertThat(TWO_VALUES_PAIRED_STATS.count()).isEqualTo(2); assertThat(MANY_VALUES_PAIRED_STATS.count()).isEqualTo(MANY_VALUES_COUNT); } public void testXStats() { assertStatsApproxEqual(EMPTY_STATS_ITERABLE, EMPTY_PAIRED_STATS.xStats()); assertStatsApproxEqual(ONE_VALUE_STATS, ONE_VALUE_PAIRED_STATS.xStats()); assertStatsApproxEqual(TWO_VALUES_STATS, TWO_VALUES_PAIRED_STATS.xStats()); assertStatsApproxEqual(MANY_VALUES_STATS_ITERABLE, MANY_VALUES_PAIRED_STATS.xStats()); } public void testYStats() { assertStatsApproxEqual(EMPTY_STATS_ITERABLE, EMPTY_PAIRED_STATS.yStats()); assertStatsApproxEqual(OTHER_ONE_VALUE_STATS, ONE_VALUE_PAIRED_STATS.yStats()); assertStatsApproxEqual(OTHER_TWO_VALUES_STATS, TWO_VALUES_PAIRED_STATS.yStats()); assertStatsApproxEqual(OTHER_MANY_VALUES_STATS, MANY_VALUES_PAIRED_STATS.yStats()); } public void testPopulationCovariance() { try { EMPTY_PAIRED_STATS.populationCovariance(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } assertThat(ONE_VALUE_PAIRED_STATS.populationCovariance()).isWithin(0.0).of(0.0); assertThat(createSingleStats(Double.POSITIVE_INFINITY, 1.23).populationCovariance()).isNaN(); assertThat(createSingleStats(Double.NEGATIVE_INFINITY, 1.23).populationCovariance()).isNaN(); assertThat(createSingleStats(Double.NaN, 1.23).populationCovariance()).isNaN(); assertThat(TWO_VALUES_PAIRED_STATS.populationCovariance()) .isWithin(ALLOWED_ERROR) .of(TWO_VALUES_SUM_OF_PRODUCTS_OF_DELTAS / 2); // For datasets of many double values, we test many combinations of finite and non-finite // x-values: for (ManyValues values : ALL_MANY_VALUES) { PairedStats stats = createPairedStatsOf(values.asIterable(), OTHER_MANY_VALUES); double populationCovariance = stats.populationCovariance(); if (values.hasAnyNonFinite()) { assertThat(populationCovariance).named("population covariance of " + values).isNaN(); } else { assertThat(populationCovariance) .named("population covariance of " + values) .isWithin(ALLOWED_ERROR) .of(MANY_VALUES_SUM_OF_PRODUCTS_OF_DELTAS / MANY_VALUES_COUNT); } } assertThat(HORIZONTAL_VALUES_PAIRED_STATS.populationCovariance()) .isWithin(ALLOWED_ERROR) .of(0.0); assertThat(VERTICAL_VALUES_PAIRED_STATS.populationCovariance()).isWithin(ALLOWED_ERROR).of(0.0); assertThat(CONSTANT_VALUES_PAIRED_STATS.populationCovariance()).isWithin(ALLOWED_ERROR).of(0.0); } public void testSampleCovariance() { try { EMPTY_PAIRED_STATS.sampleCovariance(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { ONE_VALUE_PAIRED_STATS.sampleCovariance(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } assertThat(TWO_VALUES_PAIRED_STATS.sampleCovariance()) .isWithin(ALLOWED_ERROR) .of(TWO_VALUES_SUM_OF_PRODUCTS_OF_DELTAS); assertThat(MANY_VALUES_PAIRED_STATS.sampleCovariance()) .isWithin(ALLOWED_ERROR) .of(MANY_VALUES_SUM_OF_PRODUCTS_OF_DELTAS / (MANY_VALUES_COUNT - 1)); assertThat(HORIZONTAL_VALUES_PAIRED_STATS.sampleCovariance()).isWithin(ALLOWED_ERROR).of(0.0); assertThat(VERTICAL_VALUES_PAIRED_STATS.sampleCovariance()).isWithin(ALLOWED_ERROR).of(0.0); assertThat(CONSTANT_VALUES_PAIRED_STATS.sampleCovariance()).isWithin(ALLOWED_ERROR).of(0.0); } public void testPearsonsCorrelationCoefficient() { try { EMPTY_PAIRED_STATS.pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { ONE_VALUE_PAIRED_STATS.pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { createSingleStats(Double.POSITIVE_INFINITY, 1.23).pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } assertThat(TWO_VALUES_PAIRED_STATS.pearsonsCorrelationCoefficient()) .isWithin(ALLOWED_ERROR) .of( TWO_VALUES_PAIRED_STATS.populationCovariance() / (TWO_VALUES_PAIRED_STATS.xStats().populationStandardDeviation() * TWO_VALUES_PAIRED_STATS.yStats().populationStandardDeviation())); // For datasets of many double values, we test many combinations of finite and non-finite // y-values: for (ManyValues values : ALL_MANY_VALUES) { PairedStats stats = createPairedStatsOf(MANY_VALUES, values.asIterable()); double pearsonsCorrelationCoefficient = stats.pearsonsCorrelationCoefficient(); if (values.hasAnyNonFinite()) { assertThat(pearsonsCorrelationCoefficient) .named("Pearson's correlation coefficient of " + values) .isNaN(); } else { assertThat(pearsonsCorrelationCoefficient) .named("Pearson's correlation coefficient of " + values) .isWithin(ALLOWED_ERROR) .of( stats.populationCovariance() / (stats.xStats().populationStandardDeviation() * stats.yStats().populationStandardDeviation())); } } try { HORIZONTAL_VALUES_PAIRED_STATS.pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { VERTICAL_VALUES_PAIRED_STATS.pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { CONSTANT_VALUES_PAIRED_STATS.pearsonsCorrelationCoefficient(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } } public void testLeastSquaresFit() { try { EMPTY_PAIRED_STATS.leastSquaresFit(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { ONE_VALUE_PAIRED_STATS.leastSquaresFit(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } try { createSingleStats(Double.POSITIVE_INFINITY, 1.23).leastSquaresFit(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } assertDiagonalLinearTransformation( TWO_VALUES_PAIRED_STATS.leastSquaresFit(), TWO_VALUES_PAIRED_STATS.xStats().mean(), TWO_VALUES_PAIRED_STATS.yStats().mean(), TWO_VALUES_PAIRED_STATS.xStats().populationVariance(), TWO_VALUES_PAIRED_STATS.populationCovariance()); // For datasets of many double values, we test many combinations of finite and non-finite // x-values: for (ManyValues values : ALL_MANY_VALUES) { PairedStats stats = createPairedStatsOf(values.asIterable(), OTHER_MANY_VALUES); LinearTransformation fit = stats.leastSquaresFit(); if (values.hasAnyNonFinite()) { assertLinearTransformationNaN(fit); } else { assertDiagonalLinearTransformation( fit, stats.xStats().mean(), stats.yStats().mean(), stats.xStats().populationVariance(), stats.populationCovariance()); } } assertHorizontalLinearTransformation( HORIZONTAL_VALUES_PAIRED_STATS.leastSquaresFit(), HORIZONTAL_VALUES_PAIRED_STATS.yStats().mean()); assertVerticalLinearTransformation( VERTICAL_VALUES_PAIRED_STATS.leastSquaresFit(), VERTICAL_VALUES_PAIRED_STATS.xStats().mean()); try { CONSTANT_VALUES_PAIRED_STATS.leastSquaresFit(); fail("Expected IllegalStateException"); } catch (IllegalStateException expected) { } } public void testEqualsAndHashCode() { new EqualsTester() .addEqualityGroup( MANY_VALUES_PAIRED_STATS, DUPLICATE_MANY_VALUES_PAIRED_STATS, SerializableTester.reserialize(MANY_VALUES_PAIRED_STATS)) .addEqualityGroup( new PairedStats(MANY_VALUES_STATS_ITERABLE, OTHER_MANY_VALUES_STATS, 1.23), new PairedStats(MANY_VALUES_STATS_VARARGS, OTHER_MANY_VALUES_STATS, 1.23)) .addEqualityGroup( new PairedStats(OTHER_MANY_VALUES_STATS, MANY_VALUES_STATS_ITERABLE, 1.23)) .addEqualityGroup( new PairedStats(MANY_VALUES_STATS_ITERABLE, MANY_VALUES_STATS_ITERABLE, 1.23)) .addEqualityGroup( new PairedStats(TWO_VALUES_STATS, MANY_VALUES_STATS_ITERABLE, 1.23)) .addEqualityGroup( new PairedStats(MANY_VALUES_STATS_ITERABLE, ONE_VALUE_STATS, 1.23)) .addEqualityGroup( new PairedStats(MANY_VALUES_STATS_ITERABLE, MANY_VALUES_STATS_ITERABLE, 1.234)) .testEquals(); } public void testSerializable() { SerializableTester.reserializeAndAssert(MANY_VALUES_PAIRED_STATS); } public void testToString() { assertThat(EMPTY_PAIRED_STATS.toString()) .isEqualTo("PairedStats{xStats=Stats{count=0}, yStats=Stats{count=0}}"); assertThat(MANY_VALUES_PAIRED_STATS.toString()) .isEqualTo( "PairedStats{xStats=" + MANY_VALUES_PAIRED_STATS.xStats() + ", yStats=" + MANY_VALUES_PAIRED_STATS.yStats() + ", populationCovariance=" + MANY_VALUES_PAIRED_STATS.populationCovariance() + "}"); } private PairedStats createSingleStats(double x, double y) { return createPairedStatsOf(ImmutableList.of(x), ImmutableList.of(y)); } public void testToByteArrayAndFromByteArrayRoundTrip() { for (PairedStats pairedStats : ALL_PAIRED_STATS) { byte[] pairedStatsByteArray = pairedStats.toByteArray(); // Round trip to byte array and back assertThat(PairedStats.fromByteArray(pairedStatsByteArray)).isEqualTo(pairedStats); } } public void testFromByteArray_withNullInputThrowsNullPointerException() { try { PairedStats.fromByteArray(null); fail("Expected NullPointerException"); } catch (NullPointerException expected) { } } public void testFromByteArray_withEmptyArrayInputThrowsIllegalArgumentException() { try { PairedStats.fromByteArray(new byte[0]); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException expected) { } } public void testFromByteArray_withTooLongArrayInputThrowsIllegalArgumentException() { byte[] buffer = MANY_VALUES_PAIRED_STATS.toByteArray(); byte[] tooLongByteArray = ByteBuffer.allocate(buffer.length + 2) .order(ByteOrder.LITTLE_ENDIAN) .put(buffer) .putChar('.') .array(); try { PairedStats.fromByteArray(tooLongByteArray); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException expected) { } } public void testFromByteArrayWithTooShortArrayInputThrowsIllegalArgumentException() { byte[] buffer = MANY_VALUES_PAIRED_STATS.toByteArray(); byte[] tooShortByteArray = ByteBuffer.allocate(buffer.length - 1) .order(ByteOrder.LITTLE_ENDIAN) .put(buffer, 0, buffer.length - 1) .array(); try { PairedStats.fromByteArray(tooShortByteArray); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException expected) { } } }