/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.transforms; import static org.apache.beam.sdk.TestUtils.checkCombineFn; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.ApproximateQuantiles.ApproximateQuantilesCombineFn; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.hamcrest.CoreMatchers; import org.hamcrest.Description; import org.hamcrest.Matcher; import org.hamcrest.TypeSafeDiagnosingMatcher; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link ApproximateQuantiles}. */ @RunWith(JUnit4.class) public class ApproximateQuantilesTest { static final List<KV<String, Integer>> TABLE = Arrays.asList( KV.of("a", 1), KV.of("a", 2), KV.of("a", 3), KV.of("b", 1), KV.of("b", 10), KV.of("b", 10), KV.of("b", 100) ); @Rule public TestPipeline p = TestPipeline.create(); public PCollection<KV<String, Integer>> createInputTable(Pipeline p) { return p.apply(Create.of(TABLE).withCoder( KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); } @Test @Category(NeedsRunner.class) public void testQuantilesGlobally() { PCollection<Integer> input = intRangeCollection(p, 101); PCollection<List<Integer>> quantiles = input.apply(ApproximateQuantiles.<Integer>globally(5)); PAssert.that(quantiles) .containsInAnyOrder(Arrays.asList(0, 25, 50, 75, 100)); p.run(); } @Test @Category(NeedsRunner.class) public void testQuantilesGobally_comparable() { PCollection<Integer> input = intRangeCollection(p, 101); PCollection<List<Integer>> quantiles = input.apply( ApproximateQuantiles.globally(5, new DescendingIntComparator())); PAssert.that(quantiles) .containsInAnyOrder(Arrays.asList(100, 75, 50, 25, 0)); p.run(); } @Test @Category(NeedsRunner.class) public void testQuantilesPerKey() { PCollection<KV<String, Integer>> input = createInputTable(p); PCollection<KV<String, List<Integer>>> quantiles = input.apply( ApproximateQuantiles.<String, Integer>perKey(2)); PAssert.that(quantiles) .containsInAnyOrder( KV.of("a", Arrays.asList(1, 3)), KV.of("b", Arrays.asList(1, 100))); p.run(); } @Test @Category(NeedsRunner.class) public void testQuantilesPerKey_reversed() { PCollection<KV<String, Integer>> input = createInputTable(p); PCollection<KV<String, List<Integer>>> quantiles = input.apply( ApproximateQuantiles.<String, Integer, DescendingIntComparator>perKey( 2, new DescendingIntComparator())); PAssert.that(quantiles) .containsInAnyOrder( KV.of("a", Arrays.asList(3, 1)), KV.of("b", Arrays.asList(100, 1))); p.run(); } @Test public void testSingleton() { checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), Arrays.asList(389), Arrays.asList(389, 389, 389, 389, 389)); } @Test public void testSimpleQuantiles() { checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), intRange(101), Arrays.asList(0, 25, 50, 75, 100)); } @Test public void testUnevenQuantiles() { checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(37), intRange(5000), quantileMatcher(5000, 37, 20 /* tolerance */)); } @Test public void testLargerQuantiles() { checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(50), intRange(10001), quantileMatcher(10001, 50, 20 /* tolerance */)); } @Test public void testTightEpsilon() { checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(10).withEpsilon(0.01), intRange(10001), quantileMatcher(10001, 10, 5 /* tolerance */)); } @Test public void testDuplicates() { int size = 101; List<Integer> all = new ArrayList<>(); for (int i = 0; i < 10; i++) { all.addAll(intRange(size)); } checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), all, Arrays.asList(0, 25, 50, 75, 100)); } @Test public void testLotsOfDuplicates() { List<Integer> all = new ArrayList<>(); all.add(1); for (int i = 1; i < 300; i++) { all.add(2); } for (int i = 300; i < 1000; i++) { all.add(3); } checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), all, Arrays.asList(1, 2, 3, 3, 3)); } @Test public void testLogDistribution() { List<Integer> all = new ArrayList<>(); for (int i = 1; i < 1000; i++) { all.add((int) Math.log(i)); } checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), all, Arrays.asList(0, 5, 6, 6, 6)); } @Test public void testZipfianDistribution() { List<Integer> all = new ArrayList<>(); for (int i = 1; i < 1000; i++) { all.add(1000 / i); } checkCombineFn( ApproximateQuantilesCombineFn.<Integer>create(5), all, Arrays.asList(1, 1, 2, 4, 1000)); } @Test public void testAlternateComparator() { List<String> inputs = Arrays.asList( "aa", "aaa", "aaaa", "b", "ccccc", "dddd", "zz"); checkCombineFn( ApproximateQuantilesCombineFn.<String>create(3), inputs, Arrays.asList("aa", "b", "zz")); checkCombineFn( ApproximateQuantilesCombineFn.create(3, new OrderByLength()), inputs, Arrays.asList("b", "aaa", "ccccc")); } @Test public void testDisplayData() { Top.Natural<Integer> comparer = new Top.Natural<Integer>(); PTransform<?, ?> approxQuanitiles = ApproximateQuantiles.globally(20, comparer); DisplayData displayData = DisplayData.from(approxQuanitiles); assertThat(displayData, hasDisplayItem("numQuantiles", 20)); assertThat(displayData, hasDisplayItem("comparer", comparer.getClass())); } private Matcher<Iterable<? extends Integer>> quantileMatcher( int size, int numQuantiles, int absoluteError) { List<Matcher<? super Integer>> quantiles = new ArrayList<>(); quantiles.add(CoreMatchers.is(0)); for (int k = 1; k < numQuantiles - 1; k++) { int expected = (int) (((double) (size - 1)) * k / (numQuantiles - 1)); quantiles.add(new Between<>( expected - absoluteError, expected + absoluteError)); } quantiles.add(CoreMatchers.is(size - 1)); return contains(quantiles); } private static class Between<T extends Comparable<T>> extends TypeSafeDiagnosingMatcher<T> { private final T min; private final T max; private Between(T min, T max) { this.min = min; this.max = max; } @Override public void describeTo(Description description) { description.appendText("is between " + min + " and " + max); } @Override protected boolean matchesSafely(T item, Description mismatchDescription) { return min.compareTo(item) <= 0 && item.compareTo(max) <= 0; } } private static class DescendingIntComparator implements SerializableComparator<Integer> { @Override public int compare(Integer o1, Integer o2) { return o2.compareTo(o1); } } private static class OrderByLength implements Comparator<String>, Serializable { @Override public int compare(String a, String b) { if (a.length() != b.length()) { return a.length() - b.length(); } else { return a.compareTo(b); } } } private PCollection<Integer> intRangeCollection(Pipeline p, int size) { return p.apply("CreateIntsUpTo(" + size + ")", Create.of(intRange(size))); } private List<Integer> intRange(int size) { List<Integer> all = new ArrayList<>(size); for (int i = 0; i < size; i++) { all.add(i); } return all; } }