/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.functions;
import java.util.Collections;
import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
/**
 * Encapsulates a {@link DoFn}
 * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}.
 *
 * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index
 * and must tag all outputs with the output number. Afterwards a filter will filter out
 * those elements that are not to be in a specific output.
 *
 * <p>Lifecycle as implemented here: {@link #open(Configuration)} creates the
 * {@link DoFnInvoker} and invokes the {@link DoFn} setup, every call to
 * {@link #mapPartition(Iterable, Collector)} runs the whole partition as one bundle
 * (start bundle, process each element, finish bundle), and {@link #close()} invokes teardown.
 */
public class FlinkDoFnFunction<InputT, OutputT>
    extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> {

  private final SerializedPipelineOptions serializedOptions;

  private final DoFn<InputT, OutputT> doFn;
  private final String stepName;
  private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;

  private final WindowingStrategy<?, ?> windowingStrategy;

  private final Map<TupleTag<?>, Integer> outputMap;
  private final TupleTag<OutputT> mainOutputTag;

  /** Created in {@link #open(Configuration)}; stays {@code null} until then. */
  private transient DoFnInvoker<InputT, OutputT> doFnInvoker;

  /**
   * Creates the function.
   *
   * @param doFn the {@link DoFn} to execute
   * @param stepName name handed to the metrics wrapper when metrics are enabled
   * @param windowingStrategy windowing strategy handed to the {@link DoFnRunner}
   * @param sideInputs side input views and their windowing strategies, handed to the
   *     side input reader
   * @param options pipeline options; stored wrapped in a {@link SerializedPipelineOptions}
   * @param outputMap mapping from output tag to output index; when {@code null} elements are
   *     emitted untagged, otherwise every element is wrapped in a {@link RawUnionValue}
   * @param mainOutputTag tag of the main output
   */
  public FlinkDoFnFunction(
      DoFn<InputT, OutputT> doFn,
      String stepName,
      WindowingStrategy<?, ?> windowingStrategy,
      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
      PipelineOptions options,
      Map<TupleTag<?>, Integer> outputMap,
      TupleTag<OutputT> mainOutputTag) {

    this.doFn = doFn;
    this.stepName = stepName;
    this.sideInputs = sideInputs;
    this.serializedOptions = new SerializedPipelineOptions(options);
    this.windowingStrategy = windowingStrategy;
    this.outputMap = outputMap;
    this.mainOutputTag = mainOutputTag;
  }

  /**
   * Runs the {@link DoFn} over all elements of the partition as a single bundle.
   *
   * @param values the input elements of this partition
   * @param out collector receiving the produced elements
   * @throws Exception if the runner or the wrapped {@link DoFn} fails
   */
  @Override
  public void mapPartition(
      Iterable<WindowedValue<InputT>> values,
      Collector<WindowedValue<OutputT>> out) throws Exception {

    RuntimeContext runtimeContext = getRuntimeContext();
    PipelineOptions options = serializedOptions.getPipelineOptions();

    DoFnRunners.OutputManager outputManager;
    if (outputMap == null) {
      outputManager = new FlinkDoFnFunction.DoFnOutputManager(out);
    } else {
      // it has some additional outputs
      // NOTE(review): the raw cast assumes that whenever an output map is supplied the
      // collector really consumes WindowedValue<RawUnionValue> - confirm against the caller.
      @SuppressWarnings({"unchecked", "rawtypes"})
      Collector<WindowedValue<RawUnionValue>> unionCollector = (Collector) out;
      outputManager = new FlinkDoFnFunction.MultiDoFnOutputManager(unionCollector, outputMap);
    }

    DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner(
        options, doFn,
        new FlinkSideInputReader(sideInputs, runtimeContext),
        outputManager,
        mainOutputTag,
        // see SimpleDoFnRunner, just use it to limit number of additional outputs
        Collections.<TupleTag<?>>emptyList(),
        new FlinkNoOpStepContext(),
        windowingStrategy);

    if (options.as(FlinkPipelineOptions.class).getEnableMetrics()) {
      doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, runtimeContext);
    }

    doFnRunner.startBundle();

    for (WindowedValue<InputT> value : values) {
      doFnRunner.processElement(value);
    }

    doFnRunner.finishBundle();
  }

  /**
   * Creates the {@link DoFnInvoker} for the wrapped {@link DoFn} and invokes its setup.
   *
   * @param parameters not used by this implementation
   * @throws Exception if creating the invoker or the setup call fails
   */
  @Override
  public void open(Configuration parameters) throws Exception {
    doFnInvoker = DoFnInvokers.invokerFor(doFn);
    doFnInvoker.invokeSetup();
  }

  /**
   * Invokes teardown on the wrapped {@link DoFn}, if an invoker was created.
   *
   * @throws Exception if the teardown call fails
   */
  @Override
  public void close() throws Exception {
    // The invoker only exists once open() has assigned it. Without this guard a close()
    // that runs without a successful open() would fail with a NullPointerException and
    // could hide the failure that actually caused the shutdown.
    if (doFnInvoker != null) {
      doFnInvoker.invokeTeardown();
    }
  }

  /** Output manager that forwards every element to the collector unchanged, ignoring the tag. */
  static class DoFnOutputManager
      implements DoFnRunners.OutputManager {

    // Kept as a raw type so the constructor signature stays source compatible for callers.
    private final Collector collector;

    DoFnOutputManager(Collector collector) {
      this.collector = collector;
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
      collector.collect(output);
    }
  }

  /**
   * Output manager that wraps every element in a {@link RawUnionValue} carrying the index
   * that the output map assigns to the element's tag.
   */
  static class MultiDoFnOutputManager
      implements DoFnRunners.OutputManager {

    private final Collector<WindowedValue<RawUnionValue>> collector;
    private final Map<TupleTag<?>, Integer> outputMap;

    MultiDoFnOutputManager(Collector<WindowedValue<RawUnionValue>> collector,
        Map<TupleTag<?>, Integer> outputMap) {
      this.collector = collector;
      this.outputMap = outputMap;
    }

    @Override
    public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
      // Timestamp, windows and pane of the original element are carried over unchanged.
      collector.collect(WindowedValue.of(new RawUnionValue(outputMap.get(tag), output.getValue()),
          output.getTimestamp(), output.getWindows(), output.getPane()));
    }
  }
}