/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.optimizer.postpass; import java.util.Arrays; import java.util.HashSet; import java.util.Set; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.GenericDataSourceBase; import org.apache.flink.api.common.operators.Operator; import org.apache.flink.api.common.operators.SingleInputOperator; import org.apache.flink.api.common.operators.base.BulkIterationBase; import org.apache.flink.api.common.operators.base.DeltaIterationBase; import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase; import org.apache.flink.api.common.operators.util.FieldList; import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeComparatorFactory; import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerFactory; import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory; import org.apache.flink.optimizer.CompilerException; import org.apache.flink.optimizer.CompilerPostPassException; import org.apache.flink.optimizer.plan.BulkIterationPlanNode; import org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode; import org.apache.flink.optimizer.plan.Channel; import org.apache.flink.optimizer.plan.DualInputPlanNode; import org.apache.flink.optimizer.plan.NAryUnionPlanNode; import org.apache.flink.optimizer.plan.OptimizedPlan; import org.apache.flink.optimizer.plan.PlanNode; import org.apache.flink.optimizer.plan.SingleInputPlanNode; import org.apache.flink.optimizer.plan.SinkPlanNode; import org.apache.flink.optimizer.plan.SolutionSetPlanNode; import org.apache.flink.optimizer.plan.SourcePlanNode; import org.apache.flink.optimizer.plan.WorksetIterationPlanNode; import org.apache.flink.optimizer.plan.WorksetPlanNode; import org.apache.flink.optimizer.util.NoOpUnaryUdfOp; import org.apache.flink.runtime.operators.DriverStrategy; /** * The post-optimizer plan traversal. This traversal fills in the API specific utilities (serializers and * comparators). */ public class JavaApiPostPass implements OptimizerPostPass { private final Set<PlanNode> alreadyDone = new HashSet<PlanNode>(); private ExecutionConfig executionConfig = null; @Override public void postPass(OptimizedPlan plan) { executionConfig = plan.getOriginalPlan().getExecutionConfig(); for (SinkPlanNode sink : plan.getDataSinks()) { traverse(sink); } } protected void traverse(PlanNode node) { if (!alreadyDone.add(node)) { // already worked on that one return; } // distinguish the node types if (node instanceof SinkPlanNode) { // descend to the input channel SinkPlanNode sn = (SinkPlanNode) node; Channel inchannel = sn.getInput(); traverseChannel(inchannel); } else if (node instanceof SourcePlanNode) { TypeInformation<?> typeInfo = getTypeInfoFromSource((SourcePlanNode) node); ((SourcePlanNode) node).setSerializer(createSerializer(typeInfo)); } else if (node instanceof BulkIterationPlanNode) { BulkIterationPlanNode iterationNode = (BulkIterationPlanNode) node; if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) { throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node."); } // traverse the termination criterion for the first time. create schema only, no utilities. Needed in case of intermediate termination criterion if (iterationNode.getRootOfTerminationCriterion() != null) { SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion(); traverseChannel(addMapper.getInput()); } BulkIterationBase<?> operator = (BulkIterationBase<?>) iterationNode.getProgramOperator(); // set the serializer iterationNode.setSerializerForIterationChannel(createSerializer(operator.getOperatorInfo().getOutputType())); // done, we can now propagate our info down traverseChannel(iterationNode.getInput()); traverse(iterationNode.getRootOfStepFunction()); } else if (node instanceof WorksetIterationPlanNode) { WorksetIterationPlanNode iterationNode = (WorksetIterationPlanNode) node; if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) { throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node."); } if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) { throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node."); } DeltaIterationBase<?, ?> operator = (DeltaIterationBase<?, ?>) iterationNode.getProgramOperator(); // set the serializers and comparators for the workset iteration iterationNode.setSolutionSetSerializer(createSerializer(operator.getOperatorInfo().getFirstInputType())); iterationNode.setWorksetSerializer(createSerializer(operator.getOperatorInfo().getSecondInputType())); iterationNode.setSolutionSetComparator(createComparator(operator.getOperatorInfo().getFirstInputType(), iterationNode.getSolutionSetKeyFields(), getSortOrders(iterationNode.getSolutionSetKeyFields(), null))); // traverse the inputs traverseChannel(iterationNode.getInput1()); traverseChannel(iterationNode.getInput2()); // traverse the step function traverse(iterationNode.getSolutionSetDeltaPlanNode()); traverse(iterationNode.getNextWorkSetPlanNode()); } else if (node instanceof SingleInputPlanNode) { SingleInputPlanNode sn = (SingleInputPlanNode) node; if (!(sn.getOptimizerNode().getOperator() instanceof SingleInputOperator)) { // Special case for delta iterations if(sn.getOptimizerNode().getOperator() instanceof NoOpUnaryUdfOp) { traverseChannel(sn.getInput()); return; } else { throw new RuntimeException("Wrong operator type found in post pass."); } } SingleInputOperator<?, ?, ?> singleInputOperator = (SingleInputOperator<?, ?, ?>) sn.getOptimizerNode().getOperator(); // parameterize the node's driver strategy for(int i=0;i<sn.getDriverStrategy().getNumRequiredComparators();i++) { sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(i), getSortOrders(sn.getKeys(i), sn.getSortOrders(i))), i); } // done, we can now propagate our info down traverseChannel(sn.getInput()); // don't forget the broadcast inputs for (Channel c: sn.getBroadcastInputs()) { traverseChannel(c); } } else if (node instanceof DualInputPlanNode) { DualInputPlanNode dn = (DualInputPlanNode) node; if (!(dn.getOptimizerNode().getOperator() instanceof DualInputOperator)) { throw new RuntimeException("Wrong operator type found in post pass."); } DualInputOperator<?, ?, ?, ?> dualInputOperator = (DualInputOperator<?, ?, ?, ?>) dn.getOptimizerNode().getOperator(); // parameterize the node's driver strategy if (dn.getDriverStrategy().getNumRequiredComparators() > 0) { dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(), getSortOrders(dn.getKeysForInput1(), dn.getSortOrders()))); dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(), getSortOrders(dn.getKeysForInput2(), dn.getSortOrders()))); dn.setPairComparator(createPairComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dualInputOperator.getOperatorInfo().getSecondInputType())); } traverseChannel(dn.getInput1()); traverseChannel(dn.getInput2()); // don't forget the broadcast inputs for (Channel c: dn.getBroadcastInputs()) { traverseChannel(c); } } // catch the sources of the iterative step functions else if (node instanceof BulkPartialSolutionPlanNode || node instanceof SolutionSetPlanNode || node instanceof WorksetPlanNode) { // Do nothing :D } else if (node instanceof NAryUnionPlanNode){ // Traverse to all child channels for (Channel channel : node.getInputs()) { traverseChannel(channel); } } else { throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName()); } } private void traverseChannel(Channel channel) { PlanNode source = channel.getSource(); Operator<?> javaOp = source.getProgramOperator(); // if (!(javaOp instanceof BulkIteration) && !(javaOp instanceof JavaPlanNode)) { // throw new RuntimeException("Wrong operator type found in post pass: " + javaOp); // } TypeInformation<?> type = javaOp.getOperatorInfo().getOutputType(); if(javaOp instanceof GroupReduceOperatorBase && (source.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE || source.getDriverStrategy() == DriverStrategy.ALL_GROUP_REDUCE_COMBINE)) { GroupReduceOperatorBase<?, ?, ?> groupNode = (GroupReduceOperatorBase<?, ?, ?>) javaOp; type = groupNode.getInput().getOperatorInfo().getOutputType(); } else if(javaOp instanceof PlanUnwrappingReduceGroupOperator && source.getDriverStrategy().equals(DriverStrategy.SORTED_GROUP_COMBINE)) { PlanUnwrappingReduceGroupOperator<?, ?, ?> groupNode = (PlanUnwrappingReduceGroupOperator<?, ?, ?>) javaOp; type = groupNode.getInput().getOperatorInfo().getOutputType(); } // the serializer always exists channel.setSerializer(createSerializer(type)); // parameterize the ship strategy if (channel.getShipStrategy().requiresComparator()) { channel.setShipStrategyComparator(createComparator(type, channel.getShipStrategyKeys(), getSortOrders(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder()))); } // parameterize the local strategy if (channel.getLocalStrategy().requiresComparator()) { channel.setLocalStrategyComparator(createComparator(type, channel.getLocalStrategyKeys(), getSortOrders(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder()))); } // descend to the channel's source traverse(channel.getSource()); } @SuppressWarnings("unchecked") private static <T> TypeInformation<T> getTypeInfoFromSource(SourcePlanNode node) { Operator<?> op = node.getOptimizerNode().getOperator(); if (op instanceof GenericDataSourceBase) { return ((GenericDataSourceBase<T, ?>) op).getOperatorInfo().getOutputType(); } else { throw new RuntimeException("Wrong operator type found in post pass."); } } private <T> TypeSerializerFactory<?> createSerializer(TypeInformation<T> typeInfo) { TypeSerializer<T> serializer = typeInfo.createSerializer(executionConfig); return new RuntimeSerializerFactory<T>(serializer, typeInfo.getTypeClass()); } @SuppressWarnings("unchecked") private <T> TypeComparatorFactory<?> createComparator(TypeInformation<T> typeInfo, FieldList keys, boolean[] sortOrder) { TypeComparator<T> comparator; if (typeInfo instanceof CompositeType) { comparator = ((CompositeType<T>) typeInfo).createComparator(keys.toArray(), sortOrder, 0, executionConfig); } else if (typeInfo instanceof AtomicType) { // handle grouping of atomic types comparator = ((AtomicType<T>) typeInfo).createComparator(sortOrder[0], executionConfig); } else { throw new RuntimeException("Unrecognized type: " + typeInfo); } return new RuntimeComparatorFactory<T>(comparator); } private static <T1 extends Tuple, T2 extends Tuple> TypePairComparatorFactory<T1,T2> createPairComparator(TypeInformation<?> typeInfo1, TypeInformation<?> typeInfo2) { // @SuppressWarnings("unchecked") // TupleTypeInfo<T1> info1 = (TupleTypeInfo<T1>) typeInfo1; // @SuppressWarnings("unchecked") // TupleTypeInfo<T2> info2 = (TupleTypeInfo<T2>) typeInfo2; return new RuntimePairComparatorFactory<T1,T2>(); } private static final boolean[] getSortOrders(FieldList keys, boolean[] orders) { if (orders == null) { orders = new boolean[keys.size()]; Arrays.fill(orders, true); } return orders; } }