/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.isEmptyOrNullString;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
import com.google.common.base.Equivalence;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link SdkComponents}. */
@RunWith(JUnit4.class)
public class SdkComponentsTest {
@Rule
public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
@Rule
public ExpectedException thrown = ExpectedException.none();
private SdkComponents components = SdkComponents.create();
@Test
public void translatePipeline() {
BigEndianLongCoder customCoder = BigEndianLongCoder.of();
PCollection<Long> elems = pipeline.apply(GenerateSequence.from(0L).to(207L));
PCollection<Long> counted = elems.apply(Count.<Long>globally()).setCoder(customCoder);
PCollection<Long> windowed =
counted.apply(
Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
.triggering(
AfterWatermark.pastEndOfWindow()
.withEarlyFirings(AfterPane.elementCountAtLeast(19)))
.accumulatingFiredPanes()
.withAllowedLateness(Duration.standardMinutes(3L)));
final WindowingStrategy<?, ?> windowedStrategy = windowed.getWindowingStrategy();
PCollection<KV<String, Long>> keyed = windowed.apply(WithKeys.<String, Long>of("foo"));
PCollection<KV<String, Iterable<Long>>> grouped =
keyed.apply(GroupByKey.<String, Long>create());
final RunnerApi.Pipeline pipelineProto = SdkComponents.translatePipeline(pipeline);
pipeline.traverseTopologically(
new PipelineVisitor() {
Set<Node> transforms = new HashSet<>();
Set<PCollection<?>> pcollections = new HashSet<>();
Set<Equivalence.Wrapper<? extends Coder<?>>> coders = new HashSet<>();
Set<WindowingStrategy<?, ?>> windowingStrategies = new HashSet<>();
@Override
public CompositeBehavior enterCompositeTransform(Node node) {
return CompositeBehavior.ENTER_TRANSFORM;
}
@Override
public void leaveCompositeTransform(Node node) {
if (node.isRootNode()) {
assertThat(
"Unexpected number of PTransforms",
pipelineProto.getComponents().getTransformsCount(),
equalTo(transforms.size()));
assertThat(
"Unexpected number of PCollections",
pipelineProto.getComponents().getPcollectionsCount(),
equalTo(pcollections.size()));
assertThat(
"Unexpected number of Coders",
pipelineProto.getComponents().getCodersCount(),
equalTo(coders.size()));
assertThat(
"Unexpected number of Windowing Strategies",
pipelineProto.getComponents().getWindowingStrategiesCount(),
equalTo(windowingStrategies.size()));
} else {
transforms.add(node);
}
}
@Override
public void visitPrimitiveTransform(Node node) {
transforms.add(node);
}
@Override
public void visitValue(PValue value, Node producer) {
if (value instanceof PCollection) {
PCollection pc = (PCollection) value;
pcollections.add(pc);
addCoders(pc.getCoder());
windowingStrategies.add(pc.getWindowingStrategy());
addCoders(pc.getWindowingStrategy().getWindowFn().windowCoder());
}
}
private void addCoders(Coder<?> coder) {
coders.add(Equivalence.<Coder<?>>identity().wrap(coder));
if (coder instanceof StructuredCoder) {
for (Coder<?> component : ((StructuredCoder <?>) coder).getComponents()) {
addCoders(component);
}
}
}
});
}
@Test
public void registerCoder() throws IOException {
Coder<?> coder =
KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(SetCoder.of(ByteArrayCoder.of())));
String id = components.registerCoder(coder);
assertThat(components.registerCoder(coder), equalTo(id));
assertThat(id, not(isEmptyOrNullString()));
VarLongCoder otherCoder = VarLongCoder.of();
assertThat(components.registerCoder(otherCoder), not(equalTo(id)));
components.toComponents().getCodersOrThrow(id);
components.toComponents().getCodersOrThrow(components.registerCoder(otherCoder));
}
@Test
public void registerCoderEqualsNotSame() throws IOException {
Coder<?> coder =
KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(SetCoder.of(ByteArrayCoder.of())));
Coder<?> otherCoder =
KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(SetCoder.of(ByteArrayCoder.of())));
assertThat(coder, Matchers.<Coder<?>>equalTo(otherCoder));
String id = components.registerCoder(coder);
String otherId = components.registerCoder(otherCoder);
assertThat(otherId, not(equalTo(id)));
components.toComponents().getCodersOrThrow(id);
components.toComponents().getCodersOrThrow(otherId);
}
@Test
public void registerTransformNoChildren() throws IOException {
Create.Values<Integer> create = Create.of(1, 2, 3);
PCollection<Integer> pt = pipeline.apply(create);
String userName = "my_transform/my_nesting";
AppliedPTransform<?, ?, ?> transform =
AppliedPTransform.<PBegin, PCollection<Integer>, Create.Values<Integer>>of(
userName, pipeline.begin().expand(), pt.expand(), create, pipeline);
String componentName =
components.registerPTransform(
transform, Collections.<AppliedPTransform<?, ?, ?>>emptyList());
assertThat(componentName, equalTo(userName));
assertThat(components.getExistingPTransformId(transform), equalTo(componentName));
}
@Test
public void registerTransformAfterChildren() throws IOException {
Create.Values<Long> create = Create.of(1L, 2L, 3L);
GenerateSequence createChild = GenerateSequence.from(0);
PCollection<Long> pt = pipeline.apply(create);
String userName = "my_transform";
String childUserName = "my_transform/my_nesting";
AppliedPTransform<?, ?, ?> transform =
AppliedPTransform.<PBegin, PCollection<Long>, Create.Values<Long>>of(
userName, pipeline.begin().expand(), pt.expand(), create, pipeline);
AppliedPTransform<?, ?, ?> childTransform =
AppliedPTransform.<PBegin, PCollection<Long>, GenerateSequence>of(
childUserName, pipeline.begin().expand(), pt.expand(), createChild, pipeline);
String childId = components.registerPTransform(childTransform,
Collections.<AppliedPTransform<?, ?, ?>>emptyList());
String parentId = components.registerPTransform(transform,
Collections.<AppliedPTransform<?, ?, ?>>singletonList(childTransform));
Components components = this.components.toComponents();
assertThat(components.getTransformsOrThrow(parentId).getSubtransforms(0), equalTo(childId));
assertThat(components.getTransformsOrThrow(childId).getSubtransformsCount(), equalTo(0));
}
@Test
public void registerTransformEmptyFullName() throws IOException {
Create.Values<Integer> create = Create.of(1, 2, 3);
PCollection<Integer> pt = pipeline.apply(create);
AppliedPTransform<?, ?, ?> transform =
AppliedPTransform.<PBegin, PCollection<Integer>, Create.Values<Integer>>of(
"", pipeline.begin().expand(), pt.expand(), create, pipeline);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(transform.toString());
components.getExistingPTransformId(transform);
}
@Test
public void registerTransformNullComponents() throws IOException {
Create.Values<Integer> create = Create.of(1, 2, 3);
PCollection<Integer> pt = pipeline.apply(create);
String userName = "my_transform/my_nesting";
AppliedPTransform<?, ?, ?> transform =
AppliedPTransform.<PBegin, PCollection<Integer>, Create.Values<Integer>>of(
userName, pipeline.begin().expand(), pt.expand(), create, pipeline);
thrown.expect(NullPointerException.class);
thrown.expectMessage("child nodes may not be null");
components.registerPTransform(transform, null);
}
/**
* Tests that trying to register a transform which has unregistered children throws.
*/
@Test
public void registerTransformWithUnregisteredChildren() throws IOException {
Create.Values<Long> create = Create.of(1L, 2L, 3L);
GenerateSequence createChild = GenerateSequence.from(0);
PCollection<Long> pt = pipeline.apply(create);
String userName = "my_transform";
String childUserName = "my_transform/my_nesting";
AppliedPTransform<?, ?, ?> transform =
AppliedPTransform.<PBegin, PCollection<Long>, Create.Values<Long>>of(
userName, pipeline.begin().expand(), pt.expand(), create, pipeline);
AppliedPTransform<?, ?, ?> childTransform =
AppliedPTransform.<PBegin, PCollection<Long>, GenerateSequence>of(
childUserName, pipeline.begin().expand(), pt.expand(), createChild, pipeline);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(childTransform.toString());
components.registerPTransform(
transform, Collections.<AppliedPTransform<?, ?, ?>>singletonList(childTransform));
}
@Test
public void registerPCollection() throws IOException {
PCollection<Long> pCollection = pipeline.apply(GenerateSequence.from(0)).setName("foo");
String id = components.registerPCollection(pCollection);
assertThat(id, equalTo("foo"));
components.toComponents().getPcollectionsOrThrow(id);
}
@Test
public void registerPCollectionExistingNameCollision() throws IOException {
PCollection<Long> pCollection =
pipeline.apply("FirstCount", GenerateSequence.from(0)).setName("foo");
String firstId = components.registerPCollection(pCollection);
PCollection<Long> duplicate =
pipeline.apply("SecondCount", GenerateSequence.from(0)).setName("foo");
String secondId = components.registerPCollection(duplicate);
assertThat(firstId, equalTo("foo"));
assertThat(secondId, containsString("foo"));
assertThat(secondId, not(equalTo("foo")));
components.toComponents().getPcollectionsOrThrow(firstId);
components.toComponents().getPcollectionsOrThrow(secondId);
}
@Test
public void registerWindowingStrategy() throws IOException {
WindowingStrategy<?, ?> strategy =
WindowingStrategy.globalDefault().withMode(AccumulationMode.ACCUMULATING_FIRED_PANES);
String name = components.registerWindowingStrategy(strategy);
assertThat(name, not(isEmptyOrNullString()));
components.toComponents().getWindowingStrategiesOrThrow(name);
}
@Test
public void registerWindowingStrategyIdEqualStrategies() throws IOException {
WindowingStrategy<?, ?> strategy =
WindowingStrategy.globalDefault().withMode(AccumulationMode.ACCUMULATING_FIRED_PANES);
String name = components.registerWindowingStrategy(strategy);
String duplicateName =
components.registerWindowingStrategy(
WindowingStrategy.globalDefault().withMode(AccumulationMode.ACCUMULATING_FIRED_PANES));
assertThat(name, equalTo(duplicateName));
}
}