/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.fn.harness.control; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; import com.google.common.collect.Collections2; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Lists; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fake.FakeStepContext; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.runners.core.BeamFnDataReadRunner; import org.apache.beam.runners.core.BeamFnDataWriteRunner; import org.apache.beam.runners.core.BoundedSourceRunner; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.DoFnRunners.OutputManager; import org.apache.beam.runners.core.NullSideInputReader; import org.apache.beam.runners.dataflow.util.DoFnInfo; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Processes {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s by materializing * the set of required runners for each {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec}, * wiring them together based upon the {@code input} and {@code output} map definitions. * * <p>Finally executes the DAG based graph by starting all runners in reverse topological order, * and finishing all runners in forward topological order. */ public class ProcessBundleHandler { // TODO: What should the initial set of URNs be? private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1"; private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1"; private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1"; private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; private static final Logger LOG = LoggerFactory.getLogger(ProcessBundleHandler.class); private final PipelineOptions options; private final Function<String, Message> fnApiRegistry; private final BeamFnDataClient beamFnDataClient; public ProcessBundleHandler( PipelineOptions options, Function<String, Message> fnApiRegistry, BeamFnDataClient beamFnDataClient) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; } protected <InputT, OutputT> void createConsumersForPrimitiveTransform( BeamFnApi.PrimitiveTransform primitiveTransform, Supplier<String> processBundleInstructionId, Function<BeamFnApi.Target, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers, BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer, Consumer<ThrowingRunnable> addStartFunction, Consumer<ThrowingRunnable> addFinishFunction) throws IOException { BeamFnApi.FunctionSpec functionSpec = primitiveTransform.getFunctionSpec(); // For every output PCollection, create a map from output name to Consumer ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMapBuilder = ImmutableMap.builder(); for (Map.Entry<String, BeamFnApi.PCollection> entry : primitiveTransform.getOutputsMap().entrySet()) { outputMapBuilder.put( entry.getKey(), consumers.apply( BeamFnApi.Target.newBuilder() .setPrimitiveTransformReference(primitiveTransform.getId()) .setName(entry.getKey()) .build())); } ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap = outputMapBuilder.build(); // Based upon the function spec, populate the start/finish/consumer information. ThrowingConsumer<WindowedValue<InputT>> consumer; switch (functionSpec.getUrn()) { default: BeamFnApi.Target target; BeamFnApi.Coder coderSpec; throw new IllegalArgumentException( String.format("Unknown FunctionSpec %s", functionSpec)); case DATA_OUTPUT_URN: target = BeamFnApi.Target.newBuilder() .setPrimitiveTransformReference(primitiveTransform.getId()) .setName(getOnlyElement(primitiveTransform.getOutputsMap().keySet())) .build(); coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply( getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference()); BeamFnDataWriteRunner<InputT> remoteGrpcWriteRunner = new BeamFnDataWriteRunner<>( functionSpec, processBundleInstructionId, target, coderSpec, beamFnDataClient); addStartFunction.accept(remoteGrpcWriteRunner::registerForOutput); consumer = remoteGrpcWriteRunner::consume; addFinishFunction.accept(remoteGrpcWriteRunner::close); break; case DATA_INPUT_URN: target = BeamFnApi.Target.newBuilder() .setPrimitiveTransformReference(primitiveTransform.getId()) .setName(getOnlyElement(primitiveTransform.getInputsMap().keySet())) .build(); coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply( getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference()); BeamFnDataReadRunner<OutputT> remoteGrpcReadRunner = new BeamFnDataReadRunner<>( functionSpec, processBundleInstructionId, target, coderSpec, beamFnDataClient, outputMap); addStartFunction.accept(remoteGrpcReadRunner::registerInputLocation); consumer = null; addFinishFunction.accept(remoteGrpcReadRunner::blockTillReadFinishes); break; case JAVA_DO_FN_URN: DoFnRunner<InputT, OutputT> doFnRunner = createDoFnRunner(functionSpec, outputMap); addStartFunction.accept(doFnRunner::startBundle); addFinishFunction.accept(doFnRunner::finishBundle); consumer = doFnRunner::processElement; break; case JAVA_SOURCE_URN: @SuppressWarnings({"unchecked", "rawtypes"}) BoundedSourceRunner<BoundedSource<OutputT>, OutputT> sourceRunner = createBoundedSourceRunner(functionSpec, outputMap); @SuppressWarnings({"unchecked", "rawtypes"}) ThrowingConsumer<WindowedValue<?>> sourceConsumer = (ThrowingConsumer) (ThrowingConsumer<WindowedValue<BoundedSource<OutputT>>>) sourceRunner::runReadLoop; // TODO: Remove and replace with source being sent across gRPC port addStartFunction.accept(sourceRunner::start); consumer = (ThrowingConsumer) sourceConsumer; break; } if (consumer != null) { for (Map.Entry<String, BeamFnApi.Target.List> entry : primitiveTransform.getInputsMap().entrySet()) { for (BeamFnApi.Target target : entry.getValue().getTargetList()) { addConsumer.accept(target, consumer); } } } } public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) throws Exception { BeamFnApi.InstructionResponse.Builder response = BeamFnApi.InstructionResponse.newBuilder() .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance()); String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference(); BeamFnApi.ProcessBundleDescriptor bundleDescriptor = (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId); Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<Object>>> outputTargetToConsumer = HashMultimap.create(); List<ThrowingRunnable> startFunctions = new ArrayList<>(); List<ThrowingRunnable> finishFunctions = new ArrayList<>(); // We process the primitive transform list in reverse order // because we assume that the runner provides it in topologically order. // This means that all the start/finish functions will be in reverse topological order. for (BeamFnApi.PrimitiveTransform primitiveTransform : Lists.reverse(bundleDescriptor.getPrimitiveTransformList())) { createConsumersForPrimitiveTransform( primitiveTransform, request::getInstructionId, outputTargetToConsumer::get, outputTargetToConsumer::put, startFunctions::add, finishFunctions::add); } // Already in reverse order so we don't need to do anything. for (ThrowingRunnable startFunction : startFunctions) { LOG.debug("Starting function {}", startFunction); startFunction.run(); } // Need to reverse this since we want to call finish in topological order. for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) { LOG.debug("Finishing function {}", finishFunction); finishFunction.run(); } return response; } /** * Converts a {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec} into a {@link DoFnRunner}. */ private <InputT, OutputT> DoFnRunner<InputT, OutputT> createDoFnRunner( BeamFnApi.FunctionSpec functionSpec, Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) { ByteString serializedFn; try { serializedFn = functionSpec.getData().unpack(BytesValue.class).getValue(); } catch (InvalidProtocolBufferException e) { throw new IllegalArgumentException( String.format("Unable to unwrap DoFn %s", functionSpec), e); } DoFnInfo<?, ?> doFnInfo = (DoFnInfo<?, ?>) SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); checkArgument( Objects.equals( new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), doFnInfo.getOutputMap().keySet()), "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", outputMap.keySet(), doFnInfo.getOutputMap()); ImmutableMultimap.Builder<TupleTag<?>, ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput = ImmutableMultimap.builder(); for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { tagToOutput.putAll(entry.getValue(), outputMap.get(Long.toString(entry.getKey()))); } @SuppressWarnings({"unchecked", "rawtypes"}) final Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap = (Map) tagToOutput.build().asMap(); OutputManager outputManager = new OutputManager() { Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tupleTagToOutput = tagBasedOutputMap; @Override public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { try { Collection<ThrowingConsumer<WindowedValue<?>>> consumers = tupleTagToOutput.get(tag); if (consumers == null) { /* This is a normal case, e.g., if a DoFn has output but that output is not * consumed. Drop the output. */ return; } for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) { consumer.accept(output); } } catch (Throwable t) { throw new RuntimeException(t); } } }; @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) DoFnRunner<InputT, OutputT> runner = DoFnRunners.simpleRunner( PipelineOptionsFactory.create(), /* TODO */ (DoFn) doFnInfo.getDoFn(), NullSideInputReader.empty(), /* TODO */ outputManager, (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), new ArrayList<>(doFnInfo.getOutputMap().values()), new FakeStepContext(), (WindowingStrategy) doFnInfo.getWindowingStrategy()); return runner; } private <InputT extends BoundedSource<OutputT>, OutputT> BoundedSourceRunner<InputT, OutputT> createBoundedSourceRunner( BeamFnApi.FunctionSpec functionSpec, Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) { @SuppressWarnings({"rawtypes", "unchecked"}) BoundedSourceRunner<InputT, OutputT> runner = new BoundedSourceRunner(options, functionSpec, outputMap); return runner; } }