/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.sdk.TestUtils.checkCombineFn;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasNamespace;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableSet;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineTest.TestCombineFn.Accumulator;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.CombineWithContext.Context;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.TimestampedValue;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
/**
* Tests for Combine transforms.
*/
@RunWith(JUnit4.class)
public class CombineTest implements Serializable {
// This test is Serializable, just so that it's easy to have
// anonymous inner classes inside the non-static test methods.
static final List<KV<String, Integer>> TABLE = Arrays.asList(
KV.of("a", 1),
KV.of("a", 1),
KV.of("a", 4),
KV.of("b", 1),
KV.of("b", 13)
);
static final List<KV<String, Integer>> EMPTY_TABLE = Collections.emptyList();
@Mock private DoFn<?, ?>.ProcessContext processContext;
@Rule
public final transient TestPipeline pipeline = TestPipeline.create();
PCollection<KV<String, Integer>> createInput(Pipeline p,
List<KV<String, Integer>> table) {
return p.apply(Create.of(table).withCoder(
KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())));
}
private void runTestSimpleCombine(List<KV<String, Integer>> table,
int globalSum,
List<KV<String, String>> perKeyCombines) {
PCollection<KV<String, Integer>> input = createInput(pipeline, table);
PCollection<Integer> sum = input
.apply(Values.<Integer>create())
.apply(Combine.globally(new SumInts()));
// Java 8 will infer.
PCollection<KV<String, String>> sumPerKey = input
.apply(Combine.<String, Integer, String>perKey(new TestCombineFn()));
PAssert.that(sum).containsInAnyOrder(globalSum);
PAssert.that(sumPerKey).containsInAnyOrder(perKeyCombines);
pipeline.run();
}
private void runTestSimpleCombineWithContext(List<KV<String, Integer>> table,
int globalSum,
List<KV<String, String>> perKeyCombines,
String[] globallyCombines) {
PCollection<KV<String, Integer>> perKeyInput = createInput(pipeline, table);
PCollection<Integer> globallyInput = perKeyInput.apply(Values.<Integer>create());
PCollection<Integer> sum = globallyInput.apply("Sum", Combine.globally(new SumInts()));
PCollectionView<Integer> globallySumView = sum.apply(View.<Integer>asSingleton());
// Java 8 will infer.
PCollection<KV<String, String>> combinePerKey =
perKeyInput.apply(
Combine.<String, Integer, String>perKey(new TestCombineFnWithContext(globallySumView))
.withSideInputs(Arrays.asList(globallySumView)));
PCollection<String> combineGlobally = globallyInput
.apply(Combine.globally(new TestCombineFnWithContext(globallySumView))
.withoutDefaults()
.withSideInputs(Arrays.asList(globallySumView)));
PAssert.that(sum).containsInAnyOrder(globalSum);
PAssert.that(combinePerKey).containsInAnyOrder(perKeyCombines);
PAssert.that(combineGlobally).containsInAnyOrder(globallyCombines);
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
@SuppressWarnings({"rawtypes", "unchecked"})
public void testSimpleCombine() {
runTestSimpleCombine(TABLE, 20, Arrays.asList(KV.of("a", "114"), KV.of("b", "113")));
}
@Test
@Category(ValidatesRunner.class)
@SuppressWarnings({"rawtypes", "unchecked"})
public void testSimpleCombineWithContext() {
runTestSimpleCombineWithContext(TABLE, 20,
Arrays.asList(KV.of("a", "01124"), KV.of("b", "01123")),
new String[] {"01111234"});
}
@Test
@Category(ValidatesRunner.class)
public void testSimpleCombineWithContextEmpty() {
runTestSimpleCombineWithContext(
EMPTY_TABLE, 0, Collections.<KV<String, String>>emptyList(), new String[] {});
}
@Test
@Category(ValidatesRunner.class)
public void testSimpleCombineEmpty() {
runTestSimpleCombine(EMPTY_TABLE, 0, Collections.<KV<String, String>>emptyList());
}
@SuppressWarnings("unchecked")
private void runTestBasicCombine(List<KV<String, Integer>> table,
Set<Integer> globalUnique,
List<KV<String, Set<Integer>>> perKeyUnique) {
PCollection<KV<String, Integer>> input = createInput(pipeline, table);
PCollection<Set<Integer>> unique = input
.apply(Values.<Integer>create())
.apply(Combine.globally(new UniqueInts()));
// Java 8 will infer.
PCollection<KV<String, Set<Integer>>> uniquePerKey = input
.apply(Combine.<String, Integer, Set<Integer>>perKey(new UniqueInts()));
PAssert.that(unique).containsInAnyOrder(globalUnique);
PAssert.that(uniquePerKey).containsInAnyOrder(perKeyUnique);
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testBasicCombine() {
runTestBasicCombine(TABLE, ImmutableSet.of(1, 13, 4), Arrays.asList(
KV.of("a", (Set<Integer>) ImmutableSet.of(1, 4)),
KV.of("b", (Set<Integer>) ImmutableSet.of(1, 13))));
}
@Test
@Category(ValidatesRunner.class)
public void testBasicCombineEmpty() {
runTestBasicCombine(
EMPTY_TABLE, ImmutableSet.<Integer>of(), Collections.<KV<String, Set<Integer>>>emptyList());
}
private void runTestAccumulatingCombine(List<KV<String, Integer>> table,
Double globalMean,
List<KV<String, Double>> perKeyMeans) {
PCollection<KV<String, Integer>> input = createInput(pipeline, table);
PCollection<Double> mean = input
.apply(Values.<Integer>create())
.apply(Combine.globally(new MeanInts()));
// Java 8 will infer.
PCollection<KV<String, Double>> meanPerKey = input.apply(
Combine.<String, Integer, Double>perKey(new MeanInts()));
PAssert.that(mean).containsInAnyOrder(globalMean);
PAssert.that(meanPerKey).containsInAnyOrder(perKeyMeans);
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testFixedWindowsCombine() {
PCollection<KV<String, Integer>> input =
pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 1L, 6L, 7L, 8L))
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())))
.apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.millis(2))));
PCollection<Integer> sum = input
.apply(Values.<Integer>create())
.apply(Combine.globally(new SumInts()).withoutDefaults());
PCollection<KV<String, String>> sumPerKey = input
.apply(Combine.<String, Integer, String>perKey(new TestCombineFn()));
PAssert.that(sum).containsInAnyOrder(2, 5, 13);
PAssert.that(sumPerKey).containsInAnyOrder(
KV.of("a", "11"),
KV.of("a", "4"),
KV.of("b", "1"),
KV.of("b", "13"));
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testFixedWindowsCombineWithContext() {
PCollection<KV<String, Integer>> perKeyInput =
pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 1L, 6L, 7L, 8L))
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())))
.apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.millis(2))));
PCollection<Integer> globallyInput = perKeyInput.apply(Values.<Integer>create());
PCollection<Integer> sum = globallyInput
.apply("Sum", Combine.globally(new SumInts()).withoutDefaults());
PCollectionView<Integer> globallySumView = sum.apply(View.<Integer>asSingleton());
PCollection<KV<String, String>> combinePerKeyWithContext =
perKeyInput.apply(
Combine.<String, Integer, String>perKey(new TestCombineFnWithContext(globallySumView))
.withSideInputs(Arrays.asList(globallySumView)));
PCollection<String> combineGloballyWithContext = globallyInput
.apply(Combine.globally(new TestCombineFnWithContext(globallySumView))
.withoutDefaults()
.withSideInputs(Arrays.asList(globallySumView)));
PAssert.that(sum).containsInAnyOrder(2, 5, 13);
PAssert.that(combinePerKeyWithContext).containsInAnyOrder(
KV.of("a", "112"),
KV.of("a", "45"),
KV.of("b", "15"),
KV.of("b", "1133"));
PAssert.that(combineGloballyWithContext).containsInAnyOrder("112", "145", "1133");
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testSlidingWindowsCombineWithContext() {
PCollection<KV<String, Integer>> perKeyInput =
pipeline.apply(Create.timestamped(TABLE, Arrays.asList(2L, 3L, 8L, 9L, 10L))
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())))
.apply(Window.<KV<String, Integer>>into(SlidingWindows.of(Duration.millis(2))));
PCollection<Integer> globallyInput = perKeyInput.apply(Values.<Integer>create());
PCollection<Integer> sum = globallyInput
.apply("Sum", Combine.globally(new SumInts()).withoutDefaults());
PCollectionView<Integer> globallySumView = sum.apply(View.<Integer>asSingleton());
PCollection<KV<String, String>> combinePerKeyWithContext =
perKeyInput.apply(
Combine.<String, Integer, String>perKey(new TestCombineFnWithContext(globallySumView))
.withSideInputs(Arrays.asList(globallySumView)));
PCollection<String> combineGloballyWithContext = globallyInput
.apply(Combine.globally(new TestCombineFnWithContext(globallySumView))
.withoutDefaults()
.withSideInputs(Arrays.asList(globallySumView)));
PAssert.that(sum).containsInAnyOrder(1, 2, 1, 4, 5, 14, 13);
PAssert.that(combinePerKeyWithContext).containsInAnyOrder(
KV.of("a", "11"),
KV.of("a", "112"),
KV.of("a", "11"),
KV.of("a", "44"),
KV.of("a", "45"),
KV.of("b", "15"),
KV.of("b", "11134"),
KV.of("b", "1133"));
PAssert.that(combineGloballyWithContext).containsInAnyOrder(
"11", "112", "11", "44", "145", "11134", "1133");
pipeline.run();
}
private static class FormatPaneInfo extends DoFn<Integer, String> {
@ProcessElement
public void processElement(ProcessContext c) {
c.output(c.element() + ": " + c.pane().isLast());
}
}
@Test
@Category(ValidatesRunner.class)
public void testGlobalCombineWithDefaultsAndTriggers() {
PCollection<Integer> input = pipeline.apply(Create.of(1, 1));
PCollection<String> output = input
.apply(Window.<Integer>into(new GlobalWindows())
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
.accumulatingFiredPanes()
.withAllowedLateness(new Duration(0)))
.apply(Sum.integersGlobally())
.apply(ParDo.of(new FormatPaneInfo()));
// The actual elements produced are nondeterministic. Could be one, could be two.
// But it should certainly have a final element with the correct final sum.
PAssert.that(output).satisfies(new SerializableFunction<Iterable<String>, Void>() {
@Override
public Void apply(Iterable<String> input) {
assertThat(input, hasItem("2: true"));
return null;
}
});
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testSessionsCombine() {
PCollection<KV<String, Integer>> input =
pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 4L, 7L, 10L, 16L))
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())))
.apply(Window.<KV<String, Integer>>into(Sessions.withGapDuration(Duration.millis(5))));
PCollection<Integer> sum = input
.apply(Values.<Integer>create())
.apply(Combine.globally(new SumInts()).withoutDefaults());
PCollection<KV<String, String>> sumPerKey = input
.apply(Combine.<String, Integer, String>perKey(new TestCombineFn()));
PAssert.that(sum).containsInAnyOrder(7, 13);
PAssert.that(sumPerKey).containsInAnyOrder(
KV.of("a", "114"),
KV.of("b", "1"),
KV.of("b", "13"));
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testSessionsCombineWithContext() {
PCollection<KV<String, Integer>> perKeyInput =
pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 4L, 7L, 10L, 16L))
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())));
PCollection<Integer> globallyInput = perKeyInput.apply(Values.<Integer>create());
PCollection<Integer> fixedWindowsSum = globallyInput
.apply("FixedWindows",
Window.<Integer>into(FixedWindows.of(Duration.millis(5))))
.apply("Sum", Combine.globally(new SumInts()).withoutDefaults());
PCollectionView<Integer> globallyFixedWindowsView =
fixedWindowsSum.apply(View.<Integer>asSingleton().withDefaultValue(0));
PCollection<KV<String, String>> sessionsCombinePerKey =
perKeyInput
.apply(
"PerKey Input Sessions",
Window.<KV<String, Integer>>into(Sessions.withGapDuration(Duration.millis(5))))
.apply(
Combine.<String, Integer, String>perKey(
new TestCombineFnWithContext(globallyFixedWindowsView))
.withSideInputs(Arrays.asList(globallyFixedWindowsView)));
PCollection<String> sessionsCombineGlobally = globallyInput
.apply("Globally Input Sessions",
Window.<Integer>into(Sessions.withGapDuration(Duration.millis(5))))
.apply(Combine.globally(new TestCombineFnWithContext(globallyFixedWindowsView))
.withoutDefaults()
.withSideInputs(Arrays.asList(globallyFixedWindowsView)));
PAssert.that(fixedWindowsSum).containsInAnyOrder(2, 4, 1, 13);
PAssert.that(sessionsCombinePerKey).containsInAnyOrder(
KV.of("a", "1114"),
KV.of("b", "11"),
KV.of("b", "013"));
PAssert.that(sessionsCombineGlobally).containsInAnyOrder("11114", "013");
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testWindowedCombineEmpty() {
PCollection<Double> mean = pipeline
.apply(Create.empty(BigEndianIntegerCoder.of()))
.apply(Window.<Integer>into(FixedWindows.of(Duration.millis(1))))
.apply(Combine.globally(new MeanInts()).withoutDefaults());
PAssert.that(mean).empty();
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testAccumulatingCombine() {
runTestAccumulatingCombine(TABLE, 4.0, Arrays.asList(KV.of("a", 2.0), KV.of("b", 7.0)));
}
@Test
@Category(ValidatesRunner.class)
public void testAccumulatingCombineEmpty() {
runTestAccumulatingCombine(EMPTY_TABLE, 0.0, Collections.<KV<String, Double>>emptyList());
}
// Checks that Min, Max, Mean, Sum (operations that pass-through to Combine) have good names.
@Test
public void testCombinerNames() {
Combine.PerKey<String, Integer, Integer> min = Min.integersPerKey();
Combine.PerKey<String, Integer, Integer> max = Max.integersPerKey();
Combine.PerKey<String, Integer, Double> mean = Mean.perKey();
Combine.PerKey<String, Integer, Integer> sum = Sum.integersPerKey();
assertThat(min.getName(), equalTo("Combine.perKey(MinInteger)"));
assertThat(max.getName(), equalTo("Combine.perKey(MaxInteger)"));
assertThat(mean.getName(), equalTo("Combine.perKey(Mean)"));
assertThat(sum.getName(), equalTo("Combine.perKey(SumInteger)"));
}
private static final SerializableFunction<String, Integer> hotKeyFanout =
new SerializableFunction<String, Integer>() {
@Override
public Integer apply(String input) {
return input.equals("a") ? 3 : 0;
}
};
private static final SerializableFunction<String, Integer> splitHotKeyFanout =
new SerializableFunction<String, Integer>() {
@Override
public Integer apply(String input) {
return Math.random() < 0.5 ? 3 : 0;
}
};
@Test
@Category(ValidatesRunner.class)
public void testHotKeyCombining() {
PCollection<KV<String, Integer>> input = copy(createInput(pipeline, TABLE), 10);
CombineFn<Integer, ?, Double> mean = new MeanInts();
PCollection<KV<String, Double>> coldMean = input.apply("ColdMean",
Combine.<String, Integer, Double>perKey(mean).withHotKeyFanout(0));
PCollection<KV<String, Double>> warmMean = input.apply("WarmMean",
Combine.<String, Integer, Double>perKey(mean).withHotKeyFanout(hotKeyFanout));
PCollection<KV<String, Double>> hotMean = input.apply("HotMean",
Combine.<String, Integer, Double>perKey(mean).withHotKeyFanout(5));
PCollection<KV<String, Double>> splitMean = input.apply("SplitMean",
Combine.<String, Integer, Double>perKey(mean).withHotKeyFanout(splitHotKeyFanout));
List<KV<String, Double>> expected = Arrays.asList(KV.of("a", 2.0), KV.of("b", 7.0));
PAssert.that(coldMean).containsInAnyOrder(expected);
PAssert.that(warmMean).containsInAnyOrder(expected);
PAssert.that(hotMean).containsInAnyOrder(expected);
PAssert.that(splitMean).containsInAnyOrder(expected);
pipeline.run();
}
private static class GetLast extends DoFn<Integer, Integer> {
@ProcessElement
public void processElement(ProcessContext c) {
if (c.pane().isLast()) {
c.output(c.element());
}
}
}
@Test
@Category(ValidatesRunner.class)
public void testHotKeyCombiningWithAccumulationMode() {
PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3, 4, 5));
PCollection<Integer> output = input
.apply(Window.<Integer>into(new GlobalWindows())
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
.accumulatingFiredPanes()
.withAllowedLateness(new Duration(0), ClosingBehavior.FIRE_ALWAYS))
.apply(Sum.integersGlobally().withoutDefaults().withFanout(2))
.apply(ParDo.of(new GetLast()));
PAssert.that(output).satisfies(new SerializableFunction<Iterable<Integer>, Void>() {
@Override
public Void apply(Iterable<Integer> input) {
assertThat(input, hasItem(15));
return null;
}
});
pipeline.run();
}
@Test
@Category(NeedsRunner.class)
public void testBinaryCombineFn() {
PCollection<KV<String, Integer>> input = copy(createInput(pipeline, TABLE), 2);
PCollection<KV<String, Integer>> intProduct = input
.apply("IntProduct", Combine.<String, Integer, Integer>perKey(new TestProdInt()));
PCollection<KV<String, Integer>> objProduct = input
.apply("ObjProduct", Combine.<String, Integer, Integer>perKey(new TestProdObj()));
List<KV<String, Integer>> expected = Arrays.asList(KV.of("a", 16), KV.of("b", 169));
PAssert.that(intProduct).containsInAnyOrder(expected);
PAssert.that(objProduct).containsInAnyOrder(expected);
pipeline.run();
}
@Test
public void testBinaryCombineFnWithNulls() {
checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, 5), 45);
checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, 5), 30);
checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, null), 18);
checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, null), 12);
checkCombineFn(new NullCombiner(), Arrays.<Integer>asList(null, null, null), 8);
}
private static final class TestProdInt extends Combine.BinaryCombineIntegerFn {
@Override
public int apply(int left, int right) {
return left * right;
}
@Override
public int identity() {
return 1;
}
}
private static final class TestProdObj extends Combine.BinaryCombineFn<Integer> {
@Override
public Integer apply(Integer left, Integer right) {
return left * right;
}
}
/**
* Computes the product, considering null values to be 2.
*/
private static final class NullCombiner extends Combine.BinaryCombineFn<Integer> {
@Override
public Integer apply(Integer left, Integer right) {
return (left == null ? 2 : left) * (right == null ? 2 : right);
}
}
@Test
@Category(ValidatesRunner.class)
public void testCombineGloballyAsSingletonView() {
final PCollectionView<Integer> view = pipeline
.apply("CreateEmptySideInput", Create.empty(BigEndianIntegerCoder.of()))
.apply(Sum.integersGlobally().asSingletonView());
PCollection<Integer> output = pipeline
.apply("CreateVoidMainInput", Create.of((Void) null))
.apply("OutputSideInput", ParDo.of(new DoFn<Void, Integer>() {
@ProcessElement
public void processElement(ProcessContext c) {
c.output(c.sideInput(view));
}
}).withSideInputs(view));
PAssert.thatSingleton(output).isEqualTo(0);
pipeline.run();
}
@Test
@Category(ValidatesRunner.class)
public void testWindowedCombineGloballyAsSingletonView() {
FixedWindows windowFn = FixedWindows.of(Duration.standardMinutes(1));
final PCollectionView<Integer> view =
pipeline
.apply(
"CreateSideInput",
Create.timestamped(
TimestampedValue.of(1, new Instant(100)),
TimestampedValue.of(3, new Instant(100))))
.apply("WindowSideInput", Window.<Integer>into(windowFn))
.apply("CombineSideInput", Sum.integersGlobally().asSingletonView());
TimestampedValue<Void> nonEmptyElement = TimestampedValue.of(null, new Instant(100));
TimestampedValue<Void> emptyElement = TimestampedValue.atMinimumTimestamp(null);
PCollection<Integer> output =
pipeline
.apply(
"CreateMainInput",
Create.<Void>timestamped(nonEmptyElement, emptyElement).withCoder(VoidCoder.of()))
.apply("WindowMainInput", Window.<Void>into(windowFn))
.apply(
"OutputSideInput",
ParDo.of(
new DoFn<Void, Integer>() {
@ProcessElement
public void processElement(ProcessContext c) {
c.output(c.sideInput(view));
}
})
.withSideInputs(view));
PAssert.that(output).containsInAnyOrder(4, 0);
PAssert.that(output)
.inWindow(windowFn.assignWindow(nonEmptyElement.getTimestamp()))
.containsInAnyOrder(4);
PAssert.that(output)
.inWindow(windowFn.assignWindow(emptyElement.getTimestamp()))
.containsInAnyOrder(0);
pipeline.run();
}
@Test
public void testCombineGetName() {
assertEquals("Combine.globally(SumInts)", Combine.globally(new SumInts()).getName());
assertEquals(
"Combine.GloballyAsSingletonView",
Combine.globally(new SumInts()).asSingletonView().getName());
assertEquals("Combine.perKey(Test)", Combine.perKey(new TestCombineFn()).getName());
assertEquals(
"Combine.perKeyWithFanout(Test)",
Combine.perKey(new TestCombineFn()).withHotKeyFanout(10).getName());
}
@Test
public void testDisplayData() {
UniqueInts combineFn = new UniqueInts() {
@Override
public void populateDisplayData(DisplayData.Builder builder) {
builder.add(DisplayData.item("fnMetadata", "foobar"));
}
};
Combine.Globally<?, ?> combine = Combine.globally(combineFn)
.withFanout(1234);
DisplayData displayData = DisplayData.from(combine);
assertThat(displayData, hasDisplayItem("combineFn", combineFn.getClass()));
assertThat(displayData, hasDisplayItem("emitDefaultOnEmptyInput", true));
assertThat(displayData, hasDisplayItem("fanout", 1234));
assertThat(displayData, includesDisplayDataFor("combineFn", combineFn));
}
@Test
public void testDisplayDataForWrappedFn() {
UniqueInts combineFn = new UniqueInts() {
@Override
public void populateDisplayData(DisplayData.Builder builder) {
builder.add(DisplayData.item("foo", "bar"));
}
};
Combine.PerKey<?, ?, ?> combine = Combine.perKey(combineFn);
DisplayData displayData = DisplayData.from(combine);
assertThat(displayData, hasDisplayItem("combineFn", combineFn.getClass()));
assertThat(displayData, hasDisplayItem(hasNamespace(combineFn.getClass())));
}
@Test
@Category(ValidatesRunner.class)
public void testCombinePerKeyPrimitiveDisplayData() {
DisplayDataEvaluator evaluator = DisplayDataEvaluator.create();
CombineTest.UniqueInts combineFn = new CombineTest.UniqueInts();
PTransform<PCollection<KV<Integer, Integer>>, ? extends POutput> combine =
Combine.perKey(combineFn);
Set<DisplayData> displayData = evaluator.displayDataForPrimitiveTransforms(combine,
KvCoder.of(VarIntCoder.of(), VarIntCoder.of()));
assertThat("Combine.perKey should include the combineFn in its primitive transform",
displayData, hasItem(hasDisplayItem("combineFn", combineFn.getClass())));
}
@Test
@Category(ValidatesRunner.class)
public void testCombinePerKeyWithHotKeyFanoutPrimitiveDisplayData() {
int hotKeyFanout = 2;
DisplayDataEvaluator evaluator = DisplayDataEvaluator.create();
CombineTest.UniqueInts combineFn = new CombineTest.UniqueInts();
PTransform<PCollection<KV<Integer, Integer>>, PCollection<KV<Integer, Set<Integer>>>> combine =
Combine.<Integer, Integer, Set<Integer>>perKey(combineFn).withHotKeyFanout(hotKeyFanout);
Set<DisplayData> displayData = evaluator.displayDataForPrimitiveTransforms(combine,
KvCoder.of(VarIntCoder.of(), VarIntCoder.of()));
assertThat("Combine.perKey.withHotKeyFanout should include the combineFn in its primitive "
+ "transform", displayData, hasItem(hasDisplayItem("combineFn", combineFn.getClass())));
assertThat("Combine.perKey.withHotKeyFanout(int) should include the fanout in its primitive "
+ "transform", displayData, hasItem(hasDisplayItem("fanout", hotKeyFanout)));
}
////////////////////////////////////////////////////////////////////////////
// Test classes, for different kinds of combining fns.
/** Example SerializableFunction combiner. */
public static class SumInts
implements SerializableFunction<Iterable<Integer>, Integer> {
@Override
public Integer apply(Iterable<Integer> input) {
int sum = 0;
for (int item : input) {
sum += item;
}
return sum;
}
}
/** Example CombineFn. */
public static class UniqueInts extends
Combine.CombineFn<Integer, Set<Integer>, Set<Integer>> {
@Override
public Set<Integer> createAccumulator() {
return new HashSet<>();
}
@Override
public Set<Integer> addInput(Set<Integer> accumulator, Integer input) {
accumulator.add(input);
return accumulator;
}
@Override
public Set<Integer> mergeAccumulators(Iterable<Set<Integer>> accumulators) {
Set<Integer> all = new HashSet<>();
for (Set<Integer> part : accumulators) {
all.addAll(part);
}
return all;
}
@Override
public Set<Integer> extractOutput(Set<Integer> accumulator) {
return accumulator;
}
}
/** Example AccumulatingCombineFn. */
private static class MeanInts extends
Combine.AccumulatingCombineFn<Integer, MeanInts.CountSum, Double> {
private static final Coder<Long> LONG_CODER = BigEndianLongCoder.of();
private static final Coder<Double> DOUBLE_CODER = DoubleCoder.of();
class CountSum implements
Combine.AccumulatingCombineFn.Accumulator<Integer, CountSum, Double> {
long count = 0;
double sum = 0.0;
CountSum(long count, double sum) {
this.count = count;
this.sum = sum;
}
@Override
public void addInput(Integer element) {
count++;
sum += element.doubleValue();
}
@Override
public void mergeAccumulator(CountSum accumulator) {
count += accumulator.count;
sum += accumulator.sum;
}
@Override
public Double extractOutput() {
return count == 0 ? 0.0 : sum / count;
}
@Override
public int hashCode() {
return Objects.hash(count, sum);
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof CountSum)) {
return false;
}
CountSum other = (CountSum) obj;
return this.count == other.count
&& (Math.abs(this.sum - other.sum) < 0.1);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("count", count)
.add("sum", sum)
.toString();
}
}
@Override
public CountSum createAccumulator() {
return new CountSum(0, 0.0);
}
@Override
public Coder<CountSum> getAccumulatorCoder(
CoderRegistry registry, Coder<Integer> inputCoder) {
return new CountSumCoder();
}
/**
* A {@link Coder} for {@link CountSum}.
*/
private class CountSumCoder extends AtomicCoder<CountSum> {
@Override
public void encode(CountSum value, OutputStream outStream)
throws CoderException, IOException {
LONG_CODER.encode(value.count, outStream);
DOUBLE_CODER.encode(value.sum, outStream);
}
@Override
public CountSum decode(InputStream inStream)
throws CoderException, IOException {
long count = LONG_CODER.decode(inStream);
double sum = DOUBLE_CODER.decode(inStream);
return new CountSum(count, sum);
}
@Override
public void verifyDeterministic() throws NonDeterministicException { }
@Override
public boolean isRegisterByteSizeObserverCheap(
CountSum value) {
return true;
}
@Override
public void registerByteSizeObserver(
CountSum value, ElementByteSizeObserver observer)
throws Exception {
LONG_CODER.registerByteSizeObserver(value.count, observer);
DOUBLE_CODER.registerByteSizeObserver(value.sum, observer);
}
}
}
/**
* A {@link CombineFn} that results in a sorted list of all characters occurring in the key and
* the decimal representations of each value.
*/
public static class TestCombineFn extends CombineFn<Integer, TestCombineFn.Accumulator, String> {
// Not serializable.
static class Accumulator {
String value;
public Accumulator(String value) {
this.value = value;
}
public static Coder<Accumulator> getCoder() {
return new AtomicCoder<Accumulator>() {
@Override
public void encode(Accumulator accumulator, OutputStream outStream)
throws CoderException, IOException {
encode(accumulator, outStream, Coder.Context.NESTED);
}
@Override
public void encode(Accumulator accumulator, OutputStream outStream, Coder.Context context)
throws CoderException, IOException {
StringUtf8Coder.of().encode(accumulator.value, outStream, context);
}
@Override
public Accumulator decode(InputStream inStream) throws CoderException, IOException {
return decode(inStream, Coder.Context.NESTED);
}
@Override
public Accumulator decode(InputStream inStream, Coder.Context context)
throws CoderException, IOException {
return new Accumulator(StringUtf8Coder.of().decode(inStream, context));
}
};
}
}
@Override
public Coder<Accumulator> getAccumulatorCoder(
CoderRegistry registry, Coder<Integer> inputCoder) {
return Accumulator.getCoder();
}
@Override
public Accumulator createAccumulator() {
return new Accumulator("");
}
@Override
public Accumulator addInput(Accumulator accumulator, Integer value) {
try {
return new Accumulator(accumulator.value + String.valueOf(value));
} finally {
accumulator.value = "cleared in addInput";
}
}
@Override
public Accumulator mergeAccumulators(Iterable<Accumulator> accumulators) {
String all = "";
for (Accumulator accumulator : accumulators) {
all += accumulator.value;
accumulator.value = "cleared in mergeAccumulators";
}
return new Accumulator(all);
}
@Override
public String extractOutput(Accumulator accumulator) {
char[] chars = accumulator.value.toCharArray();
Arrays.sort(chars);
return new String(chars);
}
}
/**
* A {@link CombineFnWithContext} that produces a sorted list of all characters occurring in the
* key and the decimal representations of main and side inputs values.
*/
public class TestCombineFnWithContext extends CombineFnWithContext<Integer, Accumulator, String> {
private final PCollectionView<Integer> view;
public TestCombineFnWithContext(PCollectionView<Integer> view) {
this.view = view;
}
@Override
public Coder<TestCombineFn.Accumulator> getAccumulatorCoder(
CoderRegistry registry, Coder<Integer> inputCoder) {
return TestCombineFn.Accumulator.getCoder();
}
@Override
public TestCombineFn.Accumulator createAccumulator(Context c) {
return new TestCombineFn.Accumulator(c.sideInput(view).toString());
}
@Override
public TestCombineFn.Accumulator addInput(
TestCombineFn.Accumulator accumulator, Integer value, Context c) {
try {
assertThat(accumulator.value, Matchers.startsWith(c.sideInput(view).toString()));
return new TestCombineFn.Accumulator(accumulator.value + String.valueOf(value));
} finally {
accumulator.value = "cleared in addInput";
}
}
@Override
public TestCombineFn.Accumulator mergeAccumulators(
Iterable<TestCombineFn.Accumulator> accumulators, Context c) {
String prefix = c.sideInput(view).toString();
String all = prefix;
for (TestCombineFn.Accumulator accumulator : accumulators) {
assertThat(accumulator.value, Matchers.startsWith(prefix));
all += accumulator.value.substring(prefix.length());
accumulator.value = "cleared in mergeAccumulators";
}
return new TestCombineFn.Accumulator(all);
}
@Override
public String extractOutput(TestCombineFn.Accumulator accumulator, Context c) {
assertThat(accumulator.value, Matchers.startsWith(c.sideInput(view).toString()));
char[] chars = accumulator.value.toCharArray();
Arrays.sort(chars);
return new String(chars);
}
}
/** Another example AccumulatingCombineFn. */
public static class TestCounter extends
Combine.AccumulatingCombineFn<
Integer, TestCounter.Counter, Iterable<Long>> {
/** An accumulator that observes its merges and outputs. */
public class Counter implements
Combine.AccumulatingCombineFn.Accumulator<Integer, Counter, Iterable<Long>>,
Serializable {
public long sum = 0;
public long inputs = 0;
public long merges = 0;
public long outputs = 0;
public Counter(long sum, long inputs, long merges, long outputs) {
this.sum = sum;
this.inputs = inputs;
this.merges = merges;
this.outputs = outputs;
}
@Override
public void addInput(Integer element) {
checkState(merges == 0);
checkState(outputs == 0);
inputs++;
sum += element;
}
@Override
public void mergeAccumulator(Counter accumulator) {
checkState(outputs == 0);
checkArgument(accumulator.outputs == 0);
merges += accumulator.merges + 1;
inputs += accumulator.inputs;
sum += accumulator.sum;
}
@Override
public Iterable<Long> extractOutput() {
checkState(outputs == 0);
return Arrays.asList(sum, inputs, merges, outputs);
}
@Override
public int hashCode() {
return (int) (sum * 17 + inputs * 31 + merges * 43 + outputs * 181);
}
@Override
public boolean equals(Object otherObj) {
if (otherObj instanceof Counter) {
Counter other = (Counter) otherObj;
return (sum == other.sum
&& inputs == other.inputs
&& merges == other.merges
&& outputs == other.outputs);
}
return false;
}
@Override
public String toString() {
return sum + ":" + inputs + ":" + merges + ":" + outputs;
}
}
@Override
public Counter createAccumulator() {
return new Counter(0, 0, 0, 0);
}
@Override
public Coder<Counter> getAccumulatorCoder(
CoderRegistry registry, Coder<Integer> inputCoder) {
// This is a *very* inefficient encoding to send over the wire, but suffices
// for tests.
return SerializableCoder.of(Counter.class);
}
}
private static <T> PCollection<T> copy(PCollection<T> pc, final int n) {
return pc.apply(ParDo.of(new DoFn<T, T>() {
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
for (int i = 0; i < n; i++) {
c.output(c.element());
}
}
}));
}
}