/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.transforms; import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.Matchers.empty; import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; import com.google.common.base.Function; import com.google.common.collect.Iterables; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderProviders; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.InvalidWindows; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.WindowingStrategy; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for GroupByKey. */ @RunWith(JUnit4.class) @SuppressWarnings({"rawtypes", "unchecked"}) public class GroupByKeyTest { @Rule public final TestPipeline p = TestPipeline.create(); @Rule public ExpectedException thrown = ExpectedException.none(); @Test @Category(ValidatesRunner.class) public void testGroupByKey() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList( KV.of("k1", 3), KV.of("k5", Integer.MAX_VALUE), KV.of("k5", Integer.MIN_VALUE), KV.of("k2", 66), KV.of("k1", 4), KV.of("k2", -33), KV.of("k3", 0)); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.<String, Integer>create()); PAssert.that(output) .satisfies(new AssertThatHasExpectedContentsForTestGroupByKey()); p.run(); } static class AssertThatHasExpectedContentsForTestGroupByKey implements SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, Void> { @Override public Void apply(Iterable<KV<String, Iterable<Integer>>> actual) { assertThat(actual, containsInAnyOrder( isKv(is("k1"), containsInAnyOrder(3, 4)), isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, Integer.MIN_VALUE)), isKv(is("k2"), containsInAnyOrder(66, -33)), isKv(is("k3"), containsInAnyOrder(0)))); return null; } } @Test @Category(ValidatesRunner.class) public void testGroupByKeyAndWindows() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList( KV.of("k1", 3), // window [0, 5) KV.of("k5", Integer.MAX_VALUE), // window [0, 5) KV.of("k5", Integer.MIN_VALUE), // window [0, 5) KV.of("k2", 66), // window [0, 5) KV.of("k1", 4), // window [5, 10) KV.of("k2", -33), // window [5, 10) KV.of("k3", 0)); // window [5, 10) PCollection<KV<String, Integer>> input = p.apply(Create.timestamped(ungroupedPairs, Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L)) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); PCollection<KV<String, Iterable<Integer>>> output = input.apply(Window.<KV<String, Integer>>into(FixedWindows.of(new Duration(5)))) .apply(GroupByKey.<String, Integer>create()); PAssert.that(output) .satisfies(new AssertThatHasExpectedContentsForTestGroupByKeyAndWindows()); p.run(); } static class AssertThatHasExpectedContentsForTestGroupByKeyAndWindows implements SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, Void> { @Override public Void apply(Iterable<KV<String, Iterable<Integer>>> actual) { assertThat(actual, containsInAnyOrder( isKv(is("k1"), containsInAnyOrder(3)), isKv(is("k1"), containsInAnyOrder(4)), isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, Integer.MIN_VALUE)), isKv(is("k2"), containsInAnyOrder(66)), isKv(is("k2"), containsInAnyOrder(-33)), isKv(is("k3"), containsInAnyOrder(0)))); return null; } } @Test @Category(ValidatesRunner.class) public void testGroupByKeyEmpty() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.<String, Integer>create()); PAssert.that(output).empty(); p.run(); } @Test public void testGroupByKeyNonDeterministic() throws Exception { List<KV<Map<String, String>, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<Map<String, String>, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder( KvCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), BigEndianIntegerCoder.of()))); thrown.expect(IllegalStateException.class); thrown.expectMessage("must be deterministic"); input.apply(GroupByKey.<Map<String, String>, Integer>create()); } @Test @Category(NeedsRunner.class) public void testIdentityWindowFnPropagation() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) .apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardMinutes(1)))); PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.<String, Integer>create()); p.run(); Assert.assertTrue(output.getWindowingStrategy().getWindowFn().isCompatible( FixedWindows.of(Duration.standardMinutes(1)))); } @Test @Category(NeedsRunner.class) public void testWindowFnInvalidation() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) .apply(Window.<KV<String, Integer>>into( Sessions.withGapDuration(Duration.standardMinutes(1)))); PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.<String, Integer>create()); p.run(); Assert.assertTrue( output.getWindowingStrategy().getWindowFn().isCompatible( new InvalidWindows( "Invalid", Sessions.withGapDuration( Duration.standardMinutes(1))))); } @Test public void testInvalidWindowsDirect() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) .apply(Window.<KV<String, Integer>>into( Sessions.withGapDuration(Duration.standardMinutes(1)))); thrown.expect(IllegalStateException.class); thrown.expectMessage("GroupByKey must have a valid Window merge function"); input .apply("GroupByKey", GroupByKey.<String, Integer>create()) .apply("GroupByKeyAgain", GroupByKey.<String, Iterable<Integer>>create()); } @Test @Category(NeedsRunner.class) public void testRemerge() { List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); PCollection<KV<String, Integer>> input = p.apply(Create.of(ungroupedPairs) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) .apply(Window.<KV<String, Integer>>into( Sessions.withGapDuration(Duration.standardMinutes(1)))); PCollection<KV<String, Iterable<Iterable<Integer>>>> middle = input .apply("GroupByKey", GroupByKey.<String, Integer>create()) .apply("Remerge", Window.<KV<String, Iterable<Integer>>>remerge()) .apply("GroupByKeyAgain", GroupByKey.<String, Iterable<Integer>>create()) .apply("RemergeAgain", Window.<KV<String, Iterable<Iterable<Integer>>>>remerge()); p.run(); Assert.assertTrue( middle.getWindowingStrategy().getWindowFn().isCompatible( Sessions.withGapDuration(Duration.standardMinutes(1)))); } @Test public void testGroupByKeyDirectUnbounded() { PCollection<KV<String, Integer>> input = p.apply( new PTransform<PBegin, PCollection<KV<String, Integer>>>() { @Override public PCollection<KV<String, Integer>> expand(PBegin input) { return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal( input.getPipeline(), WindowingStrategy.globalDefault(), PCollection.IsBounded.UNBOUNDED) .setTypeDescriptor(new TypeDescriptor<KV<String, Integer>>() {}); } }); thrown.expect(IllegalStateException.class); thrown.expectMessage( "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without " + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); input.apply("GroupByKey", GroupByKey.<String, Integer>create()); } /** * Tests that when two elements are combined via a GroupByKey their output timestamp agrees * with the windowing function customized to actually be the same as the default, the earlier of * the two values. */ @Test @Category(ValidatesRunner.class) public void testTimestampCombinerEarliest() { p.apply( Create.timestamped( TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) .withTimestampCombiner(TimestampCombiner.EARLIEST)) .apply(GroupByKey.<Integer, String>create()) .apply(ParDo.of(new AssertTimestamp(new Instant(0)))); p.run(); } /** * Tests that when two elements are combined via a GroupByKey their output timestamp agrees * with the windowing function customized to use the latest value. */ @Test @Category(ValidatesRunner.class) public void testTimestampCombinerLatest() { p.apply( Create.timestamped( TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) .withTimestampCombiner(TimestampCombiner.LATEST)) .apply(GroupByKey.<Integer, String>create()) .apply(ParDo.of(new AssertTimestamp(new Instant(10)))); p.run(); } private static class AssertTimestamp<K, V> extends DoFn<KV<K, V>, Void> { private final Instant timestamp; public AssertTimestamp(Instant timestamp) { this.timestamp = timestamp; } @ProcessElement public void processElement(ProcessContext c) throws Exception { assertThat(c.timestamp(), equalTo(timestamp)); } } @Test public void testGroupByKeyGetName() { Assert.assertEquals("GroupByKey", GroupByKey.<String, Integer>create().getName()); } @Test public void testDisplayData() { GroupByKey<String, String> groupByKey = GroupByKey.create(); GroupByKey<String, String> groupByFewKeys = GroupByKey.createWithFewKeys(); DisplayData gbkDisplayData = DisplayData.from(groupByKey); DisplayData fewKeysDisplayData = DisplayData.from(groupByFewKeys); assertThat(gbkDisplayData.items(), empty()); assertThat(fewKeysDisplayData, hasDisplayItem("fewKeys", true)); } /** * Verify that runners correctly hash/group on the encoded value * and not the value itself. */ @Test @Category(ValidatesRunner.class) public void testGroupByKeyWithBadEqualsHashCode() throws Exception { final int numValues = 10; final int numKeys = 5; p.getCoderRegistry().registerCoderProvider( CoderProviders.fromStaticMethods(BadEqualityKey.class, DeterministicKeyCoder.class)); // construct input data List<KV<BadEqualityKey, Long>> input = new ArrayList<>(); for (int i = 0; i < numValues; i++) { for (int key = 0; key < numKeys; key++) { input.add(KV.of(new BadEqualityKey(key), 1L)); } } // We first ensure that the values are randomly partitioned in the beginning. // Some runners might otherwise keep all values on the machine where // they are initially created. PCollection<KV<BadEqualityKey, Long>> dataset1 = p .apply(Create.of(input)) .apply(ParDo.of(new AssignRandomKey())) .apply(Reshuffle.<Long, KV<BadEqualityKey, Long>>of()) .apply(Values.<KV<BadEqualityKey, Long>>create()); // Make the GroupByKey and Count implicit, in real-world code // this would be a Count.perKey() PCollection<KV<BadEqualityKey, Long>> result = dataset1 .apply(GroupByKey.<BadEqualityKey, Long>create()) .apply(Combine.<BadEqualityKey, Long>groupedValues(new CountFn())); PAssert.that(result).satisfies(new AssertThatCountPerKeyCorrect(numValues)); PAssert.that(result.apply(Keys.<BadEqualityKey>create())) .satisfies(new AssertThatAllKeysExist(numKeys)); p.run(); } /** * This is a bogus key class that returns random hash values from {@link #hashCode()} and always * returns {@code false} for {@link #equals(Object)}. The results of the test are correct if * the runner correctly hashes and sorts on the encoded bytes. */ static class BadEqualityKey { long key; public BadEqualityKey() {} public BadEqualityKey(long key) { this.key = key; } @Override public boolean equals(Object o) { return false; } @Override public int hashCode() { return ThreadLocalRandom.current().nextInt(); } } /** * Deterministic {@link Coder} for {@link BadEqualityKey}. */ static class DeterministicKeyCoder extends AtomicCoder<BadEqualityKey> { public static DeterministicKeyCoder of() { return INSTANCE; } ///////////////////////////////////////////////////////////////////////////// private static final DeterministicKeyCoder INSTANCE = new DeterministicKeyCoder(); private DeterministicKeyCoder() {} @Override public void encode(BadEqualityKey value, OutputStream outStream) throws IOException { new DataOutputStream(outStream).writeLong(value.key); } @Override public BadEqualityKey decode(InputStream inStream) throws IOException { return new BadEqualityKey(new DataInputStream(inStream).readLong()); } @Override public void verifyDeterministic() {} } /** * Creates a KV that wraps the original KV together with a random key. */ static class AssignRandomKey extends DoFn<KV<BadEqualityKey, Long>, KV<Long, KV<BadEqualityKey, Long>>> { @ProcessElement public void processElement(ProcessContext c) throws Exception { c.output(KV.of(ThreadLocalRandom.current().nextLong(), c.element())); } } static class CountFn implements SerializableFunction<Iterable<Long>, Long> { @Override public Long apply(Iterable<Long> input) { long result = 0L; for (Long in: input) { result += in; } return result; } } static class AssertThatCountPerKeyCorrect implements SerializableFunction<Iterable<KV<BadEqualityKey, Long>>, Void> { private final int numValues; AssertThatCountPerKeyCorrect(int numValues) { this.numValues = numValues; } @Override public Void apply(Iterable<KV<BadEqualityKey, Long>> input) { for (KV<BadEqualityKey, Long> val: input) { Assert.assertEquals(numValues, (long) val.getValue()); } return null; } } static class AssertThatAllKeysExist implements SerializableFunction<Iterable<BadEqualityKey>, Void> { private final int numKeys; AssertThatAllKeysExist(int numKeys) { this.numKeys = numKeys; } private static <T> Iterable<Object> asStructural( final Iterable<T> iterable, final Coder<T> coder) { return Iterables.transform( iterable, new Function<T, Object>() { @Override public Object apply(T input) { try { return coder.structuralValue(input); } catch (Exception e) { Assert.fail("Could not structural values."); throw new RuntimeException(); // to satisfy the compiler... } } }); } @Override public Void apply(Iterable<BadEqualityKey> input) { final DeterministicKeyCoder keyCoder = DeterministicKeyCoder.of(); List<BadEqualityKey> expectedList = new ArrayList<>(); for (int key = 0; key < numKeys; key++) { expectedList.add(new BadEqualityKey(key)); } Iterable<Object> structuralInput = asStructural(input, keyCoder); Iterable<Object> structuralExpected = asStructural(expectedList, keyCoder); for (Object expected: structuralExpected) { assertThat(structuralInput, hasItem(expected)); } return null; } } }