/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.transforms;

import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
import static org.junit.Assert.assertThat;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn;
import org.apache.beam.sdk.transforms.CombineFns.CoCombineResult;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Unit tests for {@link CombineFns}.
 *
 * <p>The tests compose an integer {@code Max} with a string-concatenating combine over
 * {@code KV<Integer, UserString>} values, and check the composed result both globally and per
 * key, with and without a side input, and with a combine that yields {@code null}.
 */
@RunWith(JUnit4.class)
public class CombineFnsTest {
  @Rule public final TestPipeline p = TestPipeline.create();
  @Rule public ExpectedException expectedException = ExpectedException.none();

  /**
   * Arms {@link #expectedException} for the failure raised when one tag is registered twice in
   * a composition.
   */
  private void expectDuplicateTagFailure() {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("it is already present in the composition");
  }

  /** Returns a fresh copy of the five keyed elements shared by the pipeline tests. */
  private static List<KV<String, KV<Integer, UserString>>> sampleElements() {
    return Arrays.asList(
        KV.of("a", KV.of(1, UserString.of("1"))),
        KV.of("a", KV.of(1, UserString.of("1"))),
        KV.of("a", KV.of(4, UserString.of("4"))),
        KV.of("b", KV.of(1, UserString.of("1"))),
        KV.of("b", KV.of(13, UserString.of("13"))));
  }

  /**
   * Applies a timestamped {@link Create} of {@link #sampleElements()} to the test pipeline.
   *
   * @param elementCoder the coder to set on the created collection
   * @return the resulting keyed input collection
   */
  private PCollection<KV<String, KV<Integer, UserString>>> applyTimestampedInput(
      Coder<KV<String, KV<Integer, UserString>>> elementCoder) {
    return p.apply(
        Create.timestamped(sampleElements(), Arrays.asList(0L, 4L, 7L, 10L, 16L))
            .withCoder(elementCoder));
  }

  /** Registering the same tag twice in a composition must be rejected. */
  @Test
  public void testDuplicatedTags() {
    expectDuplicateTagFailure();

    TupleTag<Integer> sharedTag = new TupleTag<>();
    CombineFns.compose()
        .with(new GetIntegerFunction(), Max.ofIntegers(), sharedTag)
        .with(new GetIntegerFunction(), Min.ofIntegers(), sharedTag);
  }

  /** Exercises the same composition and expectation as {@link #testDuplicatedTags()}. */
  @Test
  public void testDuplicatedTagsKeyed() {
    expectDuplicateTagFailure();

    TupleTag<Integer> sharedTag = new TupleTag<>();
    CombineFns.compose()
        .with(new GetIntegerFunction(), Max.ofIntegers(), sharedTag)
        .with(new GetIntegerFunction(), Min.ofIntegers(), sharedTag);
  }

  /** Duplicate tags must also be rejected when the component combines take a context. */
  @Test
  public void testDuplicatedTagsWithContext() {
    expectDuplicateTagFailure();

    TupleTag<UserString> sharedTag = new TupleTag<>();
    CombineFns.compose()
        .with(
            new GetUserStringFunction(),
            new ConcatStringWithContext(null /* view */),
            sharedTag)
        .with(
            new GetUserStringFunction(),
            new ConcatStringWithContext(null /* view */),
            sharedTag);
  }

  /** Runs a composed max + concatenation combine globally and per key. */
  @Test
  @Category(ValidatesRunner.class)
  public void testComposedCombine() {
    p.getCoderRegistry().registerCoderForClass(UserString.class, UserStringCoder.of());

    PCollection<KV<String, KV<Integer, UserString>>> keyedInput =
        applyTimestampedInput(
            KvCoder.of(
                StringUtf8Coder.of(),
                KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())));

    TupleTag<Integer> maxTag = new TupleTag<>();
    TupleTag<UserString> concatTag = new TupleTag<>();

    PCollection<KV<String, KV<Integer, String>>> globalResult =
        keyedInput
            .apply(Values.<KV<Integer, UserString>>create())
            .apply(
                Combine.globally(
                    CombineFns.compose()
                        .with(new GetIntegerFunction(), Max.ofIntegers(), maxTag)
                        .with(new GetUserStringFunction(), new ConcatString(), concatTag)))
            .apply(WithKeys.<String, CoCombineResult>of("global"))
            .apply(
                "ExtractGloballyResult",
                ParDo.of(new ExtractResultDoFn(maxTag, concatTag)));

    PCollection<KV<String, KV<Integer, String>>> perKeyResult =
        keyedInput
            .apply(
                Combine.<String, KV<Integer, UserString>, CoCombineResult>perKey(
                    CombineFns.compose()
                        .with(new GetIntegerFunction(), Max.ofIntegers(), maxTag)
                        .with(new GetUserStringFunction(), new ConcatString(), concatTag)))
            .apply(
                "ExtractPerKeyResult",
                ParDo.of(new ExtractResultDoFn(maxTag, concatTag)));

    PAssert.that(globalResult).containsInAnyOrder(KV.of("global", KV.of(13, "111134")));
    PAssert.that(perKeyResult)
        .containsInAnyOrder(
            KV.of("a", KV.of(4, "114")),
            KV.of("b", KV.of(13, "113")));
    p.run();
  }

  /**
   * Runs the composed combine with a concatenation that reads a singleton side input; the side
   * input value {@code "I"} appears in every expected string.
   */
  @Test
  @Category(ValidatesRunner.class)
  public void testComposedCombineWithContext() {
    p.getCoderRegistry().registerCoderForClass(UserString.class, UserStringCoder.of());

    PCollectionView<String> view =
        p.apply(Create.of("I")).apply(View.<String>asSingleton());

    PCollection<KV<String, KV<Integer, UserString>>> keyedInput =
        applyTimestampedInput(
            KvCoder.of(
                StringUtf8Coder.of(),
                KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())));

    TupleTag<Integer> maxTag = new TupleTag<>();
    TupleTag<UserString> concatTag = new TupleTag<>();

    PCollection<KV<String, KV<Integer, String>>> globalResult =
        keyedInput
            .apply(Values.<KV<Integer, UserString>>create())
            .apply(
                Combine.globally(
                        CombineFns.compose()
                            .with(new GetIntegerFunction(), Max.ofIntegers(), maxTag)
                            .with(
                                new GetUserStringFunction(),
                                new ConcatStringWithContext(view),
                                concatTag))
                    .withoutDefaults()
                    .withSideInputs(ImmutableList.of(view)))
            .apply(WithKeys.<String, CoCombineResult>of("global"))
            .apply(
                "ExtractGloballyResult",
                ParDo.of(new ExtractResultDoFn(maxTag, concatTag)));

    PCollection<KV<String, KV<Integer, String>>> perKeyResult =
        keyedInput
            .apply(
                Combine.<String, KV<Integer, UserString>, CoCombineResult>perKey(
                        CombineFns.compose()
                            .with(new GetIntegerFunction(), Max.ofIntegers(), maxTag)
                            .with(
                                new GetUserStringFunction(),
                                new ConcatStringWithContext(view),
                                concatTag))
                    .withSideInputs(ImmutableList.of(view)))
            .apply(
                "ExtractPerKeyResult",
                ParDo.of(new ExtractResultDoFn(maxTag, concatTag)));

    PAssert.that(globalResult).containsInAnyOrder(KV.of("global", KV.of(13, "111134I")));
    PAssert.that(perKeyResult)
        .containsInAnyOrder(
            KV.of("a", KV.of(4, "114I")),
            KV.of("b", KV.of(13, "113I")));
    p.run();
  }

  /** Runs the composed per-key combine with a component whose result is {@code null}. */
  @Test
  @Category(ValidatesRunner.class)
  public void testComposedCombineNullValues() {
    p.getCoderRegistry()
        .registerCoderForClass(UserString.class, NullableCoder.of(UserStringCoder.of()));
    p.getCoderRegistry()
        .registerCoderForClass(String.class, NullableCoder.of(StringUtf8Coder.of()));

    PCollection<KV<String, KV<Integer, UserString>>> keyedInput =
        applyTimestampedInput(
            KvCoder.of(
                NullableCoder.of(StringUtf8Coder.of()),
                KvCoder.of(
                    BigEndianIntegerCoder.of(),
                    NullableCoder.of(UserStringCoder.of()))));

    TupleTag<Integer> maxTag = new TupleTag<>();
    TupleTag<UserString> concatTag = new TupleTag<>();

    PCollection<KV<String, KV<Integer, String>>> perKeyResult =
        keyedInput
            .apply(
                Combine.<String, KV<Integer, UserString>, CoCombineResult>perKey(
                    CombineFns.compose()
                        .with(new GetIntegerFunction(), Max.ofIntegers(), maxTag)
                        .with(
                            new GetUserStringFunction(),
                            new OutputNullString(),
                            concatTag)))
            .apply(
                "ExtractPerKeyResult",
                ParDo.of(new ExtractResultDoFn(maxTag, concatTag)));

    PAssert.that(perKeyResult)
        .containsInAnyOrder(
            KV.of("a", KV.of(4, (String) null)),
            KV.of("b", KV.of(13, (String) null)));
    p.run();
  }

  /** Checks the display data reported by a composition of two component combines. */
  @Test
  public void testComposedCombineDisplayData() {
    SimpleFunction<String, String> passThrough =
        new SimpleFunction<String, String>() {
          @Override
          public String apply(String input) {
            return input;
          }
        };

    DisplayDataCombineFn first = new DisplayDataCombineFn("value1");
    DisplayDataCombineFn second = new DisplayDataCombineFn("value2");

    CombineFns.ComposedCombineFn<String> composedCombine =
        CombineFns.compose()
            .with(passThrough, first, new TupleTag<String>())
            .with(passThrough, second, new TupleTag<String>());

    DisplayData displayData = DisplayData.from(composedCombine);
    assertThat(displayData, hasDisplayItem("combineFn1", first.getClass()));
    assertThat(displayData, hasDisplayItem("combineFn2", second.getClass()));

    assertThat(displayData, includesDisplayDataFor("combineFn1", first));
    assertThat(displayData, includesDisplayDataFor("combineFn2", second));
  }

  /**
   * A combine whose combining methods all return {@code null}; it reports one display item
   * under a per-instance key and one under a key common to every instance.
   */
  private static class DisplayDataCombineFn extends Combine.CombineFn<String, String, String> {
    /** Incremented once per constructed instance to derive {@link #id}. */
    private static int instanceCount;

    private final String value;
    private final int id;

    DisplayDataCombineFn(String value) {
      instanceCount += 1;
      this.id = instanceCount;
      this.value = value;
    }

    @Override
    public String createAccumulator() {
      return null;
    }

    @Override
    public String addInput(String accumulator, String input) {
      return null;
    }

    @Override
    public String mergeAccumulators(Iterable<String> accumulators) {
      return null;
    }

    @Override
    public String extractOutput(String accumulator) {
      return null;
    }

    @Override
    public void populateDisplayData(DisplayData.Builder builder) {
      builder.add(DisplayData.item("uniqueKey" + id, value));
      builder.add(DisplayData.item("sharedKey", value));
    }
  }

  /** A serializable holder of a single mutable string. */
  private static class UserString implements Serializable {
    private String strValue;

    /** Creates a holder wrapping {@code strValue}. */
    static UserString of(String strValue) {
      UserString holder = new UserString();
      holder.strValue = strValue;
      return holder;
    }
  }

  /** Codes a {@link UserString} by delegating its string to {@link StringUtf8Coder}. */
  private static class UserStringCoder extends AtomicCoder<UserString> {
    private static final UserStringCoder INSTANCE = new UserStringCoder();

    /** Returns the shared coder instance. */
    public static UserStringCoder of() {
      return INSTANCE;
    }

    @Override
    public void encode(UserString value, OutputStream outStream)
        throws CoderException, IOException {
      encode(value, outStream, Context.NESTED);
    }

    @Override
    public void encode(UserString value, OutputStream outStream, Context context)
        throws CoderException, IOException {
      StringUtf8Coder.of().encode(value.strValue, outStream, context);
    }

    @Override
    public UserString decode(InputStream inStream) throws CoderException, IOException {
      return decode(inStream, Context.NESTED);
    }

    @Override
    public UserString decode(InputStream inStream, Context context)
        throws CoderException, IOException {
      String decoded = StringUtf8Coder.of().decode(inStream, context);
      return UserString.of(decoded);
    }

    @Override
    public List<? extends Coder<?>> getCoderArguments() {
      return Collections.emptyList();
    }

    @Override
    public void verifyDeterministic() throws NonDeterministicException {}
  }

  /** Extracts the integer key of a {@code KV<Integer, UserString>}. */
  private static class GetIntegerFunction
      extends SimpleFunction<KV<Integer, UserString>, Integer> {
    @Override
    public Integer apply(KV<Integer, UserString> input) {
      return input.getKey();
    }
  }

  /** Extracts the {@link UserString} value of a {@code KV<Integer, UserString>}. */
  private static class GetUserStringFunction
      extends SimpleFunction<KV<Integer, UserString>, UserString> {
    @Override
    public UserString apply(KV<Integer, UserString> input) {
      return input.getValue();
    }
  }

  /** Concatenates two strings and returns the characters of the result in sorted order. */
  private static class ConcatString extends BinaryCombineFn<UserString> {
    @Override
    public UserString apply(UserString left, UserString right) {
      char[] sorted = (left.strValue + right.strValue).toCharArray();
      Arrays.sort(sorted);
      return UserString.of(new String(sorted));
    }
  }

  /** A binary combine that always yields {@code null}. */
  private static class OutputNullString extends BinaryCombineFn<UserString> {
    @Override
    public UserString apply(UserString left, UserString right) {
      return null;
    }
  }

  /**
   * A context-aware concatenation. Accumulators start with the side-input value, and every
   * method asserts that the accumulator it receives still begins with that value.
   */
  private static class ConcatStringWithContext
      extends CombineFnWithContext<UserString, UserString, UserString> {
    private final PCollectionView<String> view;

    private ConcatStringWithContext(PCollectionView<String> view) {
      this.view = view;
    }

    @Override
    public UserString createAccumulator(CombineWithContext.Context c) {
      return UserString.of(c.sideInput(view));
    }

    @Override
    public UserString addInput(
        UserString accumulator, UserString input, CombineWithContext.Context c) {
      assertThat(accumulator.strValue, Matchers.startsWith(c.sideInput(view)));
      accumulator.strValue = accumulator.strValue + input.strValue;
      return accumulator;
    }

    @Override
    public UserString mergeAccumulators(
        Iterable<UserString> accumulators, CombineWithContext.Context c) {
      String prefix = c.sideInput(view);
      String merged = prefix;
      for (UserString partial : accumulators) {
        assertThat(partial.strValue, Matchers.startsWith(prefix));
        // Keep a single copy of the prefix: append only what follows it in each accumulator.
        merged = merged + partial.strValue.substring(prefix.length());
        // Overwrite the consumed accumulator; the startsWith assertions in this class would
        // fail if it were used again afterwards.
        partial.strValue = "cleared in mergeAccumulators";
      }
      return UserString.of(merged);
    }

    @Override
    public UserString extractOutput(UserString accumulator, CombineWithContext.Context c) {
      assertThat(accumulator.strValue, Matchers.startsWith(c.sideInput(view)));
      char[] sorted = accumulator.strValue.toCharArray();
      Arrays.sort(sorted);
      return UserString.of(new String(sorted));
    }
  }

  /**
   * Unpacks a keyed {@link CoCombineResult} into the max integer and the concatenated string,
   * emitting a {@code null} string when the result holds no {@link UserString}.
   */
  private static class ExtractResultDoFn
      extends DoFn<KV<String, CoCombineResult>, KV<String, KV<Integer, String>>> {

    private final TupleTag<Integer> maxIntTag;
    private final TupleTag<UserString> concatStringTag;

    ExtractResultDoFn(TupleTag<Integer> maxIntTag, TupleTag<UserString> concatStringTag) {
      this.maxIntTag = maxIntTag;
      this.concatStringTag = concatStringTag;
    }

    @ProcessElement
    public void processElement(ProcessContext c) throws Exception {
      UserString concatenated = c.element().getValue().get(concatStringTag);
      String text = null;
      if (concatenated != null) {
        text = concatenated.strValue;
      }
      KV<Integer, String> extracted = KV.of(c.element().getValue().get(maxIntTag), text);
      c.output(KV.of(c.element().getKey(), extracted));
    }
  }
}