/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.runners.direct; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformMatchers; import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; import 
org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

/**
 * A {@link PipelineRunner} that executes a {@link Pipeline} within the process that constructed the
 * {@link Pipeline}.
 *
 * <p>The {@link DirectRunner} is suitable for running a {@link Pipeline} on small-scale, example,
 * and test data, and should be used for ensuring that processing logic is correct. It is also
 * appropriate for executing unit tests, and it performs additional work to ensure that behavior
 * contained within a {@link Pipeline} does not break assumptions within the Beam model, to improve
 * the ability to execute a {@link Pipeline} at scale on a distributed backend.
 */
public class DirectRunner extends PipelineRunner<DirectPipelineResult> {

  /**
   * Optional model checks the runner can apply to data flowing through a pipeline. The active set
   * is derived from {@link DirectOptions} by {@link #enabled(DirectOptions)}.
   */
  enum Enforcement {
    /** Applies to every {@link PCollection}. */
    ENCODABILITY {
      @Override
      public boolean appliesTo(PCollection<?> collection, DirectGraph graph) {
        return true;
      }
    },
    /**
     * Applies only to {@link PCollection PCollections} whose producing transform's class is one of
     * the UDF-executing transform classes listed in {@code CONTAINS_UDF}.
     */
    IMMUTABILITY {
      @Override
      public boolean appliesTo(PCollection<?> collection, DirectGraph graph) {
        return CONTAINS_UDF.contains(graph.getProducer(collection).getTransform().getClass());
      }
    };

    /**
     * The set of {@link PTransform PTransforms} that execute a UDF. Useful for some enforcements.
     */
    private static final Set<Class<? extends PTransform>> CONTAINS_UDF =
        ImmutableSet.of(
            Read.Bounded.class, Read.Unbounded.class, ParDo.SingleOutput.class, MultiOutput.class);

    /**
     * Returns whether this enforcement applies to {@code collection}, given the pipeline's {@code
     * graph}.
     */
    public abstract boolean appliesTo(PCollection<?> collection, DirectGraph graph);

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Utilities for creating enforcements

    /**
     * Returns an unmodifiable set containing {@link #ENCODABILITY} and/or {@link #IMMUTABILITY},
     * according to {@link DirectOptions#isEnforceEncodability()} and {@link
     * DirectOptions#isEnforceImmutability()}.
     */
    static Set<Enforcement> enabled(DirectOptions options) {
      EnumSet<Enforcement> enabled = EnumSet.noneOf(Enforcement.class);
      if (options.isEnforceEncodability()) {
        enabled.add(ENCODABILITY);
      }
      if (options.isEnforceImmutability()) {
        enabled.add(IMMUTABILITY);
      }
      return Collections.unmodifiableSet(enabled);
    }

    /**
     * Builds the {@link BundleFactory} for the given enforcements. It uses a {@link
     * CloningBundleFactory} when {@link #ENCODABILITY} is enabled, and an {@link
     * ImmutableListBundleFactory} otherwise. When {@link #IMMUTABILITY} is enabled, the result is
     * additionally wrapped in an {@link ImmutabilityCheckingBundleFactory} over {@code graph}.
     */
    static BundleFactory bundleFactoryFor(Set<Enforcement> enforcements, DirectGraph graph) {
      BundleFactory bundleFactory =
          enforcements.contains(Enforcement.ENCODABILITY)
              ? CloningBundleFactory.create()
              : ImmutableListBundleFactory.create();
      if (enforcements.contains(Enforcement.IMMUTABILITY)) {
        bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory, graph);
      }
      return bundleFactory;
    }

    /**
     * Returns the per-transform-class {@link ModelEnforcementFactory ModelEnforcementFactories}.
     * Both {@link ParDo.SingleOutput} and {@link MultiOutput} map to the same collection. That
     * collection holds an {@link ImmutabilityEnforcementFactory} if {@link #IMMUTABILITY} is
     * enabled, and is empty otherwise.
     */
    @SuppressWarnings("rawtypes")
    private static Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
        defaultModelEnforcements(Set<Enforcement> enabledEnforcements) {
      ImmutableMap.Builder<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
          enforcements = ImmutableMap.builder();
      ImmutableList.Builder<ModelEnforcementFactory> enabledParDoEnforcements =
          ImmutableList.builder();
      if (enabledEnforcements.contains(Enforcement.IMMUTABILITY)) {
        enabledParDoEnforcements.add(ImmutabilityEnforcementFactory.create());
      }
      Collection<ModelEnforcementFactory> parDoEnforcements = enabledParDoEnforcements.build();
      enforcements.put(ParDo.SingleOutput.class, parDoEnforcements);
      enforcements.put(MultiOutput.class, parDoEnforcements);
      return enforcements.build();
    }
  }

  ////////////////////////////////////////////////////////////////////////////////////////////////

  private final DirectOptions options;
  private final Set<Enforcement> enabledEnforcements;
  // Replaceable (package-private setter) so callers in this package can substitute the clock.
  private Supplier<Clock> clockSupplier = new NanosOffsetClockSupplier();

  /**
   * Construct a {@link DirectRunner} from the provided options.
   */
  public static DirectRunner fromOptions(PipelineOptions options) {
    return new DirectRunner(options.as(DirectOptions.class));
  }

  private DirectRunner(DirectOptions options) {
    this.options = options;
    this.enabledEnforcements = Enforcement.enabled(options);
  }

  /**
   * Returns the {@link PipelineOptions} used to create this {@link DirectRunner}.
   */
  public DirectOptions getPipelineOptions() {
    return options;
  }

  /** Returns the supplier used to obtain the {@link Clock} for each run. */
  Supplier<Clock> getClockSupplier() {
    return clockSupplier;
  }

  /** Replaces the supplier used to obtain the {@link Clock} for subsequent runs. */
  void setClockSupplier(Supplier<Clock> supplier) {
    this.clockSupplier = supplier;
  }

  /**
   * Applies the direct-runner transform overrides to {@code pipeline}, builds its graph, and starts
   * executing it.
   *
   * <p>If {@link DirectOptions#isBlockOnRun()} is set, this waits for the pipeline to finish. A
   * {@link UserCodeException} is then rethrown as a {@link PipelineExecutionException} wrapping
   * its cause. Any other non-{@link RuntimeException} throwable is wrapped in a {@link
   * RuntimeException}.
   */
  @Override
  public DirectPipelineResult run(Pipeline pipeline) {
    pipeline.replaceAll(defaultTransformOverrides());
    MetricsEnvironment.setMetricsSupported(true);
    DirectGraphVisitor graphVisitor = new DirectGraphVisitor();
    pipeline.traverseTopologically(graphVisitor);

    @SuppressWarnings("rawtypes")
    KeyedPValueTrackingVisitor keyedPValueVisitor = KeyedPValueTrackingVisitor.create();
    pipeline.traverseTopologically(keyedPValueVisitor);

    DisplayDataValidator.validatePipeline(pipeline);
    DisplayDataValidator.validateOptions(getPipelineOptions());

    DirectGraph graph = graphVisitor.getGraph();
    EvaluationContext context =
        EvaluationContext.create(
            getPipelineOptions(),
            clockSupplier.get(),
            Enforcement.bundleFactoryFor(enabledEnforcements, graph),
            graph,
            keyedPValueVisitor.getKeyedPValues());

    RootProviderRegistry rootInputProvider = RootProviderRegistry.defaultRegistry(context);
    TransformEvaluatorRegistry registry = TransformEvaluatorRegistry.defaultRegistry(context);
    PipelineExecutor executor =
        ExecutorServiceParallelExecutor.create(
            options.getTargetParallelism(),
            graph,
            rootInputProvider,
            registry,
            Enforcement.defaultModelEnforcements(enabledEnforcements),
            context);
    executor.start(graph.getRootTransforms());

    DirectPipelineResult result = new DirectPipelineResult(executor, context);
    if (options.isBlockOnRun()) {
      try {
        result.waitUntilFinish();
      } catch (UserCodeException userException) {
        throw new PipelineExecutionException(userException.getCause());
      } catch (Throwable t) {
        if (t instanceof RuntimeException) {
          throw (RuntimeException) t;
        }
        throw new RuntimeException(t);
      }
    }
    return result;
  }

  /**
   * The default set of transform overrides to use in the {@link DirectRunner}.
   *
   * <p>The order in which overrides are applied is important, as some overrides are expanded into
   * a composite. If the composite contains {@link PTransform PTransforms} which are also
   * overridden, these PTransforms must occur later in the iteration order. {@link ImmutableList}
   * iterates in the order in which elements were added to it, so the order of the {@code add}
   * calls below is the order in which the overrides are listed.
   */
  @SuppressWarnings("rawtypes")
  private List<PTransformOverride> defaultTransformOverrides() {
    return ImmutableList.<PTransformOverride>builder()
        // Uses a view internally.
        .add(
            PTransformOverride.of(
                PTransformMatchers.writeWithRunnerDeterminedSharding(),
                new WriteWithShardingFactory()))
        // Uses pardos and GBKs.
        .add(
            PTransformOverride.of(
                PTransformMatchers.classEqualTo(CreatePCollectionView.class),
                new ViewOverrideFactory()))
        // Primitive.
        .add(
            PTransformOverride.of(
                PTransformMatchers.classEqualTo(TestStream.class),
                new DirectTestStreamFactory(this)))
        // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra
        // primitives.
        .add(
            PTransformOverride.of(
                PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory()))
        // State and timer pardos are implemented in terms of simple ParDos and extra primitives.
        .add(
            PTransformOverride.of(
                PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory()))
        .add(
            PTransformOverride.of(
                PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class),
                new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
        // Returns a GBKO.
        .add(
            PTransformOverride.of(
                PTransformMatchers.classEqualTo(
                    SplittableParDoViaKeyedWorkItems.GBKIntoKeyedWorkItems.class),
                new DirectGBKIntoKeyedWorkItemsOverrideFactory()))
        // Returns two chained primitives.
        .add(
            PTransformOverride.of(
                PTransformMatchers.classEqualTo(GroupByKey.class),
                new DirectGroupByKeyOverrideFactory()))
        .build();
  }

  /**
   * The result of running a {@link Pipeline} with the {@link DirectRunner}.
   */
  public static class DirectPipelineResult implements PipelineResult {
    private final PipelineExecutor executor;
    private final EvaluationContext evaluationContext;
    private State state;

    private DirectPipelineResult(
        PipelineExecutor executor,
        EvaluationContext evaluationContext) {
      this.executor = executor;
      this.evaluationContext = evaluationContext;
      // Only ever constructed after the executor has started.
      this.state = State.RUNNING;
    }

    /**
     * {@inheritDoc}
     *
     * <p>While the last recorded state is non-terminal, the state is refreshed from the executor.
     * Callers that poll without calling {@link #waitUntilFinish()} therefore observe the pipeline
     * finishing or failing. Once a terminal state has been recorded, it is returned unchanged.
     */
    @Override
    public State getState() {
      if (!state.isTerminal()) {
        state = executor.getPipelineState();
      }
      return state;
    }

    @Override
    public MetricResults metrics() {
      return evaluationContext.getMetrics();
    }

    /**
     * {@inheritDoc}.
     *
     * <p>Waits with {@link Duration#ZERO}, which {@link #waitUntilFinish(Duration)} passes through
     * to the executor.
     *
     * <p>If the pipeline terminates abnormally by throwing an {@link Exception}, this will rethrow
     * the original {@link Exception}. Future calls to {@link #getState()} will return {@link
     * org.apache.beam.sdk.PipelineResult.State#FAILED}, provided the executor reports that state.
     */
    @Override
    public State waitUntilFinish() {
      return waitUntilFinish(Duration.ZERO);
    }

    /**
     * {@inheritDoc}
     *
     * <p>If the pipeline is not already terminal, this stops the executor. It returns the state
     * the executor reports afterwards, which is also the state subsequently seen by {@link
     * #getState()}.
     */
    @Override
    public State cancel() {
      this.state = executor.getPipelineState();
      if (!this.state.isTerminal()) {
        executor.stop();
        this.state = executor.getPipelineState();
      }
      // Return the state just recorded rather than re-reading it, so cancel() and getState() agree.
      return this.state;
    }

    /**
     * {@inheritDoc}.
     *
     * <p>If the pipeline terminates abnormally by throwing an {@link Exception}, this will rethrow
     * the original {@link Exception}. Future calls to {@link #getState()} will return {@link
     * org.apache.beam.sdk.PipelineResult.State#FAILED}, provided the executor reports that state.
     */
    @Override
    public State waitUntilFinish(Duration duration) {
      State startState = this.state;
      if (!startState.isTerminal()) {
        try {
          state = executor.waitUntilFinish(duration);
        } catch (UserCodeException uce) {
          // Emulates the behavior of Pipeline#run(), where a stack trace caused by a
          // UserCodeException is truncated and replaced with the stack starting at the call to
          // waitToFinish
          throw new Pipeline.PipelineExecutionException(uce.getCause());
        } catch (Exception e) {
          if (e instanceof InterruptedException) {
            // Restore the interrupt flag before converting to an unchecked exception.
            Thread.currentThread().interrupt();
          }
          if (e instanceof RuntimeException) {
            throw (RuntimeException) e;
          }
          throw new RuntimeException(e);
        }
      }
      return this.state;
    }
  }

  /**
   * A {@link Supplier} that creates a {@link NanosOffsetClock}.
   */
  private static class NanosOffsetClockSupplier implements Supplier<Clock> {
    @Override
    public Clock get() {
      return NanosOffsetClock.create();
    }
  }
}