/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.graph;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.InputFormatSourceFunction;
import org.apache.flink.streaming.api.transformations.CoFeedbackTransformation;
import org.apache.flink.streaming.api.transformations.FeedbackTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.SelectTransformation;
import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
import org.apache.flink.streaming.api.transformations.SinkTransformation;
import org.apache.flink.streaming.api.transformations.SourceTransformation;
import org.apache.flink.streaming.api.transformations.SplitTransformation;
import org.apache.flink.streaming.api.transformations.StreamTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.UnionTransformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A generator that generates a {@link StreamGraph} from a graph of
 * {@link StreamTransformation StreamTransformations}.
 *
 * <p>This traverses the tree of {@code StreamTransformations} starting from the sinks. At each
 * transformation we recursively transform the inputs, then create a node in the {@code StreamGraph}
 * and add edges from the input nodes to our newly created node. The transformation methods
 * return the IDs of the nodes in the StreamGraph that represent the input transformation. Several
 * IDs can be returned to be able to deal with feedback transformations and unions.
 *
 * <p>Partitioning, split/select, side outputs and union don't create actual nodes in the
 * {@code StreamGraph}. Partitioning, select and side outputs are represented by a virtual node
 * that holds the specific property (the partitioner, the selected names or the output tag); a split
 * attaches its output selector to the upstream nodes; a union simply forwards the IDs of all of its
 * inputs. When an edge is created from a virtual node to a downstream node, the
 * {@code StreamGraph} resolves the id of the original node and creates an edge in the graph with
 * the desired property. For example, if you have this graph:
 *
 * <pre>
 *     Map-1 -&gt; HashPartition-2 -&gt; Map-3
 * </pre>
 *
 * <p>where the numbers represent transformation IDs. We first recurse all the way down. {@code Map-1}
 * is transformed, i.e. we create a {@code StreamNode} with ID 1. Then we transform the
 * {@code HashPartition}; for this, we create a virtual node of ID 4 that holds the property
 * {@code HashPartition}. This transformation returns the ID 4. Then we transform the {@code Map-3}.
 * We add the edge {@code 4 -> 3}. The {@code StreamGraph} resolves the actual node with ID 1 and
 * creates an edge {@code 1 -> 3} with the property HashPartition.
 */
@Internal
public class StreamGraphGenerator {

	private static final Logger LOG = LoggerFactory.getLogger(StreamGraphGenerator.class);

	public static final int DEFAULT_LOWER_BOUND_MAX_PARALLELISM = KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM;

	public static final int UPPER_BOUND_MAX_PARALLELISM = KeyGroupRangeAssignment.UPPER_BOUND_MAX_PARALLELISM;

	/** The StreamGraph that is being built; it is created in the constructor. */
	private final StreamGraph streamGraph;

	/** The environment whose configuration (max parallelism, serializer config, ...) is consulted. */
	private final StreamExecutionEnvironment env;

	/**
	 * Used to assign a unique ID to each iteration source/sink. IDs are handed out by decrementing
	 * this counter, so they are negative.
	 *
	 * <p>NOTE(review): this field is static, so it is shared by all generator instances, and it is
	 * updated without synchronization.
	 */
	protected static Integer iterationIdCounter = 0;

	/**
	 * Returns a new ID for an iteration source or sink node.
	 *
	 * @return the decremented value of {@link #iterationIdCounter} (-1, -2, ... when starting from
	 *         the initial value 0)
	 */
	public static int getNewIterationNodeId() {
		iterationIdCounter--;
		return iterationIdCounter;
	}

	/**
	 * Keeps track of which transformations have already been transformed, together with the node
	 * IDs they resolved to. This is necessary because the graph can contain loops (feedback edges).
	 */
	private final Map<StreamTransformation<?>, Collection<Integer>> alreadyTransformed;

	/**
	 * Private constructor. The generator should only be invoked using {@link #generate}.
	 *
	 * @param env the environment used to configure the new {@code StreamGraph}
	 */
	private StreamGraphGenerator(StreamExecutionEnvironment env) {
		this.streamGraph = new StreamGraph(env);
		this.streamGraph.setChaining(env.isChainingEnabled());
		this.streamGraph.setStateBackend(env.getStateBackend());
		this.env = env;
		this.alreadyTransformed = new HashMap<>();
	}

	/**
	 * Generates a {@code StreamGraph} by traversing the graph of {@code StreamTransformations}
	 * starting from the given transformations.
	 *
	 * @param env The {@code StreamExecutionEnvironment} that is used to set some parameters of the
	 *            job
	 * @param transformations The transformations starting from which to transform the graph
	 *
	 * @return The generated {@code StreamGraph}
	 */
	public static StreamGraph generate(StreamExecutionEnvironment env, List<StreamTransformation<?>> transformations) {
		return new StreamGraphGenerator(env).generateInternal(transformations);
	}

	/**
	 * This starts the actual transformation, beginning from the sinks.
	 *
	 * @param transformations the transformations to transform, in order
	 * @return the populated {@code StreamGraph}
	 */
	private StreamGraph generateInternal(List<StreamTransformation<?>> transformations) {
		for (StreamTransformation<?> transformation: transformations) {
			transform(transformation);
		}
		return streamGraph;
	}

	/**
	 * Transforms one {@code StreamTransformation}.
	 *
	 * <p>This checks whether we already transformed it and exits early in that case. If not, it
	 * delegates to one of the transformation specific methods and then copies the generic
	 * properties (buffer timeout, uid, user hash, resources) onto the graph.
	 *
	 * @param transform the transformation to transform
	 * @return the IDs of the graph nodes that represent {@code transform}
	 * @throws IllegalStateException if the transformation type is not supported
	 */
	private Collection<Integer> transform(StreamTransformation<?> transform) {

		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		LOG.debug("Transforming {}", transform);

		if (transform.getMaxParallelism() <= 0) {
			// if the max parallelism hasn't been set, then first use the job wide max parallelism
			// from the ExecutionConfig.
			int globalMaxParallelismFromConfig = env.getConfig().getMaxParallelism();
			if (globalMaxParallelismFromConfig > 0) {
				transform.setMaxParallelism(globalMaxParallelismFromConfig);
			}
		}

		// call at least once to trigger exceptions about MissingTypeInfo
		transform.getOutputType();

		Collection<Integer> transformedIds;
		if (transform instanceof OneInputTransformation<?, ?>) {
			transformedIds = transformOneInputTransform((OneInputTransformation<?, ?>) transform);
		} else if (transform instanceof TwoInputTransformation<?, ?, ?>) {
			transformedIds = transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform);
		} else if (transform instanceof SourceTransformation<?>) {
			transformedIds = transformSource((SourceTransformation<?>) transform);
		} else if (transform instanceof SinkTransformation<?>) {
			transformedIds = transformSink((SinkTransformation<?>) transform);
		} else if (transform instanceof UnionTransformation<?>) {
			transformedIds = transformUnion((UnionTransformation<?>) transform);
		} else if (transform instanceof SplitTransformation<?>) {
			transformedIds = transformSplit((SplitTransformation<?>) transform);
		} else if (transform instanceof SelectTransformation<?>) {
			transformedIds = transformSelect((SelectTransformation<?>) transform);
		} else if (transform instanceof FeedbackTransformation<?>) {
			transformedIds = transformFeedback((FeedbackTransformation<?>) transform);
		} else if (transform instanceof CoFeedbackTransformation<?>) {
			transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform);
		} else if (transform instanceof PartitionTransformation<?>) {
			transformedIds = transformPartition((PartitionTransformation<?>) transform);
		} else if (transform instanceof SideOutputTransformation<?>) {
			transformedIds = transformSideOutput((SideOutputTransformation<?>) transform);
		} else {
			throw new IllegalStateException("Unknown transformation: " + transform);
		}

		// need this check because the iterate transformation adds itself before
		// transforming the feedback edges
		if (!alreadyTransformed.containsKey(transform)) {
			alreadyTransformed.put(transform, transformedIds);
		}

		// NOTE(review): a buffer timeout of exactly 0 is not forwarded by this check; confirm this
		// matches the "unset" sentinel used by StreamTransformation.
		if (transform.getBufferTimeout() > 0) {
			streamGraph.setBufferTimeout(transform.getId(), transform.getBufferTimeout());
		}
		if (transform.getUid() != null) {
			streamGraph.setTransformationUID(transform.getId(), transform.getUid());
		}
		if (transform.getUserProvidedNodeHash() != null) {
			streamGraph.setTransformationUserHash(transform.getId(), transform.getUserProvidedNodeHash());
		}

		if (transform.getMinResources() != null && transform.getPreferredResources() != null) {
			streamGraph.setResources(transform.getId(), transform.getMinResources(), transform.getPreferredResources());
		}

		return transformedIds;
	}

	/**
	 * Transforms a {@code UnionTransformation}.
	 *
	 * <p>This is easy, we only have to transform the inputs and return all the IDs in a list so
	 * that downstream operations can connect to all upstream nodes.
	 *
	 * @param union the union to transform
	 * @return the concatenated IDs of all inputs
	 */
	private <T> Collection<Integer> transformUnion(UnionTransformation<T> union) {
		List<StreamTransformation<T>> inputs = union.getInputs();
		List<Integer> resultIds = new ArrayList<>();

		for (StreamTransformation<T> input: inputs) {
			resultIds.addAll(transform(input));
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code PartitionTransformation}.
	 *
	 * <p>For this we create one virtual node in the {@code StreamGraph} per input ID that holds
	 * the partitioner.
	 *
	 * @param partition the partitioning to transform
	 * @return the IDs of the newly created virtual partition nodes
	 * @see StreamGraphGenerator
	 */
	private <T> Collection<Integer> transformPartition(PartitionTransformation<T> partition) {
		StreamTransformation<T> input = partition.getInput();
		List<Integer> resultIds = new ArrayList<>();

		Collection<Integer> transformedIds = transform(input);
		for (Integer transformedId: transformedIds) {
			int virtualId = StreamTransformation.getNewNodeId();
			streamGraph.addVirtualPartitionNode(transformedId, virtualId, partition.getPartitioner());
			resultIds.add(virtualId);
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code SplitTransformation}.
	 *
	 * <p>We add the output selector to the previously transformed input nodes; no node is created
	 * for the split itself.
	 *
	 * @param split the split to transform
	 * @return the IDs of the input nodes
	 */
	private <T> Collection<Integer> transformSplit(SplitTransformation<T> split) {

		StreamTransformation<T> input = split.getInput();
		Collection<Integer> resultIds = transform(input);

		// the recursive transform call might have transformed this already
		if (alreadyTransformed.containsKey(split)) {
			return alreadyTransformed.get(split);
		}

		for (int inputId : resultIds) {
			streamGraph.addOutputSelector(inputId, split.getOutputSelector());
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code SelectTransformation}.
	 *
	 * <p>For this we create one virtual node per input ID in the {@code StreamGraph} that holds the
	 * selected names.
	 *
	 * @param select the select to transform
	 * @return the IDs of the newly created virtual select nodes
	 * @see org.apache.flink.streaming.api.graph.StreamGraphGenerator
	 */
	private <T> Collection<Integer> transformSelect(SelectTransformation<T> select) {
		StreamTransformation<T> input = select.getInput();
		Collection<Integer> resultIds = transform(input);

		// the recursive transform might have already transformed this
		if (alreadyTransformed.containsKey(select)) {
			return alreadyTransformed.get(select);
		}

		List<Integer> virtualResultIds = new ArrayList<>();

		for (int inputId : resultIds) {
			int virtualId = StreamTransformation.getNewNodeId();
			streamGraph.addVirtualSelectNode(inputId, virtualId, select.getSelectedNames());
			virtualResultIds.add(virtualId);
		}
		return virtualResultIds;
	}

	/**
	 * Transforms a {@code SideOutputTransformation}.
	 *
	 * <p>For this we create one virtual node per input ID in the {@code StreamGraph} that holds the
	 * side-output {@link org.apache.flink.util.OutputTag}.
	 *
	 * @param sideOutput the side output to transform
	 * @return the IDs of the newly created virtual side-output nodes
	 * @see org.apache.flink.streaming.api.graph.StreamGraphGenerator
	 */
	private <T> Collection<Integer> transformSideOutput(SideOutputTransformation<T> sideOutput) {
		StreamTransformation<?> input = sideOutput.getInput();
		Collection<Integer> resultIds = transform(input);

		// the recursive transform might have already transformed this
		if (alreadyTransformed.containsKey(sideOutput)) {
			return alreadyTransformed.get(sideOutput);
		}

		List<Integer> virtualResultIds = new ArrayList<>();

		for (int inputId : resultIds) {
			int virtualId = StreamTransformation.getNewNodeId();
			streamGraph.addVirtualSideOutputNode(inputId, virtualId, sideOutput.getOutputTag());
			virtualResultIds.add(virtualId);
		}
		return virtualResultIds;
	}

	/**
	 * Transforms a {@code FeedbackTransformation}.
	 *
	 * <p>This will recursively transform the input and the feedback edges. We return the
	 * concatenation of the input IDs and the iteration source ID so that downstream operations can
	 * be wired to both.
	 *
	 * <p>This is responsible for creating the IterationSource and IterationSink which are used to
	 * feed back the elements.
	 *
	 * @param iterate the iteration to transform
	 * @return the input IDs followed by the ID of the iteration source
	 * @throws IllegalStateException if the iteration has no feedback edges
	 */
	private <T> Collection<Integer> transformFeedback(FeedbackTransformation<T> iterate) {

		if (iterate.getFeedbackEdges().isEmpty()) {
			throw new IllegalStateException("Iteration " + iterate + " does not have any feedback edges.");
		}

		StreamTransformation<T> input = iterate.getInput();
		List<Integer> resultIds = new ArrayList<>();

		// first transform the input stream(s) and store the result IDs
		Collection<Integer> inputIds = transform(input);
		resultIds.addAll(inputIds);

		// the recursive transform might have already transformed this
		if (alreadyTransformed.containsKey(iterate)) {
			return alreadyTransformed.get(iterate);
		}

		// create the fake iteration source/sink pair
		Tuple2<StreamNode, StreamNode> itSourceAndSink = streamGraph.createIterationSourceAndSink(
			iterate.getId(),
			getNewIterationNodeId(),
			getNewIterationNodeId(),
			iterate.getWaitTime(),
			iterate.getParallelism(),
			iterate.getMaxParallelism(),
			iterate.getMinResources(),
			iterate.getPreferredResources());

		StreamNode itSource = itSourceAndSink.f0;
		StreamNode itSink = itSourceAndSink.f1;

		// We set the proper serializers for the sink/source
		streamGraph.setSerializers(itSource.getId(), null, null, iterate.getOutputType().createSerializer(env.getConfig()));
		streamGraph.setSerializers(itSink.getId(), iterate.getOutputType().createSerializer(env.getConfig()), null, null);

		// also add the feedback source ID to the result IDs, so that downstream operators will
		// add both as input
		resultIds.add(itSource.getId());

		// add the iterate to the already-seen-set with the result IDs, so that we can transform
		// the feedback edges and let them stop when encountering the iterate node
		alreadyTransformed.put(iterate, resultIds);

		transformFeedbackEdges(iterate.getFeedbackEdges(), itSource, itSink);

		return resultIds;
	}

	/**
	 * Transforms a {@code CoFeedbackTransformation}.
	 *
	 * <p>This will only transform feedback edges, the result of this transform will be wired
	 * to the second input of a Co-Transform. The original input is wired directly to the first
	 * input of the downstream Co-Transform.
	 *
	 * <p>This is responsible for creating the IterationSource and IterationSink which
	 * are used to feed back the elements.
	 *
	 * @param coIterate the co-iteration to transform
	 * @return a singleton collection holding the ID of the iteration source
	 */
	private <F> Collection<Integer> transformCoFeedback(CoFeedbackTransformation<F> coIterate) {

		// For Co-Iteration we don't need to transform the input and wire the input to the
		// head operator by returning the input IDs, the input is directly wired to the left
		// input of the co-operation. This transform only needs to return the ids of the feedback
		// edges, since they need to be wired to the second input of the co-operation.

		// create the fake iteration source/sink pair
		Tuple2<StreamNode, StreamNode> itSourceAndSink = streamGraph.createIterationSourceAndSink(
			coIterate.getId(),
			getNewIterationNodeId(),
			getNewIterationNodeId(),
			coIterate.getWaitTime(),
			coIterate.getParallelism(),
			coIterate.getMaxParallelism(),
			coIterate.getMinResources(),
			coIterate.getPreferredResources());

		StreamNode itSource = itSourceAndSink.f0;
		StreamNode itSink = itSourceAndSink.f1;

		// We set the proper serializers for the sink/source
		streamGraph.setSerializers(itSource.getId(), null, null, coIterate.getOutputType().createSerializer(env.getConfig()));
		streamGraph.setSerializers(itSink.getId(), coIterate.getOutputType().createSerializer(env.getConfig()), null, null);

		Collection<Integer> resultIds = Collections.singleton(itSource.getId());

		// add the iterate to the already-seen-set with the result IDs, so that we can transform
		// the feedback edges and let them stop when encountering the iterate node
		alreadyTransformed.put(coIterate, resultIds);

		transformFeedbackEdges(coIterate.getFeedbackEdges(), itSource, itSink);

		return resultIds;
	}

	/**
	 * Transforms the feedback edges of an iteration and wires them into its iteration sink.
	 *
	 * <p>Each feedback edge is transformed recursively and every resulting node gets an edge to
	 * {@code itSink}. Afterwards, the slot sharing group derived from all feedback nodes is applied
	 * to both the iteration sink and the iteration source.
	 *
	 * @param feedbackEdges the feedback transformations of the iteration
	 * @param itSource the iteration source node created for the iteration
	 * @param itSink the iteration sink node created for the iteration
	 */
	private void transformFeedbackEdges(
			Iterable<? extends StreamTransformation<?>> feedbackEdges,
			StreamNode itSource,
			StreamNode itSink) {

		// collect all feedback IDs so that we can determine the slot sharing group from them
		List<Integer> allFeedbackIds = new ArrayList<>();

		for (StreamTransformation<?> feedbackEdge : feedbackEdges) {
			Collection<Integer> feedbackIds = transform(feedbackEdge);
			allFeedbackIds.addAll(feedbackIds);
			for (Integer feedbackId: feedbackIds) {
				streamGraph.addEdge(feedbackId, itSink.getId(), 0);
			}
		}

		String slotSharingGroup = determineSlotSharingGroup(null, allFeedbackIds);

		itSink.setSlotSharingGroup(slotSharingGroup);
		itSource.setSlotSharingGroup(slotSharingGroup);
	}

	/**
	 * Transforms a {@code SourceTransformation}.
	 *
	 * @param source the source to transform
	 * @return a singleton collection holding the ID of the source
	 */
	private <T> Collection<Integer> transformSource(SourceTransformation<T> source) {
		// a source has no inputs, so only a user-specified group (or "default") can apply
		String slotSharingGroup = determineSlotSharingGroup(source.getSlotSharingGroup(), Collections.<Integer>emptyList());

		streamGraph.addSource(source.getId(),
				slotSharingGroup,
				source.getOperator(),
				null,
				source.getOutputType(),
				"Source: " + source.getName());
		if (source.getOperator().getUserFunction() instanceof InputFormatSourceFunction) {
			// the element type of the function is the output type of the source
			@SuppressWarnings("unchecked")
			InputFormatSourceFunction<T> fs = (InputFormatSourceFunction<T>) source.getOperator().getUserFunction();
			streamGraph.setInputFormat(source.getId(), fs.getFormat());
		}
		streamGraph.setParallelism(source.getId(), source.getParallelism());
		streamGraph.setMaxParallelism(source.getId(), source.getMaxParallelism());
		return Collections.singleton(source.getId());
	}

	/**
	 * Transforms a {@code SinkTransformation}.
	 *
	 * @param sink the sink to transform
	 * @return an empty collection, since nothing can be connected downstream of a sink
	 */
	private <T> Collection<Integer> transformSink(SinkTransformation<T> sink) {

		Collection<Integer> inputIds = transform(sink.getInput());

		String slotSharingGroup = determineSlotSharingGroup(sink.getSlotSharingGroup(), inputIds);

		streamGraph.addSink(sink.getId(),
				slotSharingGroup,
				sink.getOperator(),
				sink.getInput().getOutputType(),
				null,
				"Sink: " + sink.getName());

		streamGraph.setParallelism(sink.getId(), sink.getParallelism());
		streamGraph.setMaxParallelism(sink.getId(), sink.getMaxParallelism());

		for (Integer inputId: inputIds) {
			streamGraph.addEdge(inputId,
					sink.getId(),
					0
			);
		}

		if (sink.getStateKeySelector() != null) {
			TypeSerializer<?> keySerializer = sink.getStateKeyType().createSerializer(env.getConfig());
			streamGraph.setOneInputStateKey(sink.getId(), sink.getStateKeySelector(), keySerializer);
		}

		return Collections.emptyList();
	}

	/**
	 * Transforms a {@code OneInputTransformation}.
	 *
	 * <p>This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and
	 * wires the inputs to this new node.
	 *
	 * @param transform the one-input transformation to transform
	 * @return a singleton collection holding the ID of the new operator node
	 */
	private <IN, OUT> Collection<Integer> transformOneInputTransform(OneInputTransformation<IN, OUT> transform) {

		Collection<Integer> inputIds = transform(transform.getInput());

		// the recursive call might have already transformed this
		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		String slotSharingGroup = determineSlotSharingGroup(transform.getSlotSharingGroup(), inputIds);

		streamGraph.addOperator(transform.getId(),
				slotSharingGroup,
				transform.getOperator(),
				transform.getInputType(),
				transform.getOutputType(),
				transform.getName());

		if (transform.getStateKeySelector() != null) {
			TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig());
			streamGraph.setOneInputStateKey(transform.getId(), transform.getStateKeySelector(), keySerializer);
		}

		streamGraph.setParallelism(transform.getId(), transform.getParallelism());
		streamGraph.setMaxParallelism(transform.getId(), transform.getMaxParallelism());

		for (Integer inputId: inputIds) {
			streamGraph.addEdge(inputId, transform.getId(), 0);
		}

		return Collections.singleton(transform.getId());
	}

	/**
	 * Transforms a {@code TwoInputTransformation}.
	 *
	 * <p>This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and
	 * wires the inputs to this new node (first input with type number 1, second with 2).
	 *
	 * @param transform the two-input transformation to transform
	 * @return a singleton collection holding the ID of the new co-operator node
	 */
	private <IN1, IN2, OUT> Collection<Integer> transformTwoInputTransform(TwoInputTransformation<IN1, IN2, OUT> transform) {

		Collection<Integer> inputIds1 = transform(transform.getInput1());
		Collection<Integer> inputIds2 = transform(transform.getInput2());

		// the recursive call might have already transformed this
		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		List<Integer> allInputIds = new ArrayList<>();
		allInputIds.addAll(inputIds1);
		allInputIds.addAll(inputIds2);

		String slotSharingGroup = determineSlotSharingGroup(transform.getSlotSharingGroup(), allInputIds);

		streamGraph.addCoOperator(
				transform.getId(),
				slotSharingGroup,
				transform.getOperator(),
				transform.getInputType1(),
				transform.getInputType2(),
				transform.getOutputType(),
				transform.getName());

		if (transform.getStateKeySelector1() != null) {
			TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig());
			streamGraph.setTwoInputStateKey(transform.getId(), transform.getStateKeySelector1(), transform.getStateKeySelector2(), keySerializer);
		}

		streamGraph.setParallelism(transform.getId(), transform.getParallelism());
		streamGraph.setMaxParallelism(transform.getId(), transform.getMaxParallelism());

		for (Integer inputId: inputIds1) {
			streamGraph.addEdge(inputId,
					transform.getId(),
					1
			);
		}

		for (Integer inputId: inputIds2) {
			streamGraph.addEdge(inputId,
					transform.getId(),
					2
			);
		}

		return Collections.singleton(transform.getId());
	}

	/**
	 * Determines the slot sharing group for an operation based on the slot sharing group set by
	 * the user and the slot sharing groups of the inputs.
	 *
	 * <p>If the user specifies a group name, this is taken as is. If nothing is specified and
	 * the input operations all have the same group name then this name is taken. Otherwise the
	 * default group is chosen.
	 *
	 * @param specifiedGroup The group specified by the user, or {@code null} if none was set.
	 * @param inputIds The IDs of the input operations.
	 * @return the slot sharing group to use; {@code "default"} if the inputs disagree or there are
	 *         no inputs
	 */
	private String determineSlotSharingGroup(String specifiedGroup, Collection<Integer> inputIds) {
		if (specifiedGroup != null) {
			return specifiedGroup;
		} else {
			String inputGroup = null;
			for (int id: inputIds) {
				String inputGroupCandidate = streamGraph.getSlotSharingGroup(id);
				if (inputGroup == null) {
					inputGroup = inputGroupCandidate;
				} else if (!inputGroup.equals(inputGroupCandidate)) {
					return "default";
				}
			}
			return inputGroup == null ? "default" : inputGroup;
		}
	}
}