/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.common.operators;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.io.RichInputFormat;
import org.apache.flink.api.common.io.RichOutputFormat;
import org.apache.flink.api.common.operators.base.BulkIterationBase;
import org.apache.flink.api.common.operators.base.BulkIterationBase.PartialSolutionPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
import org.apache.flink.api.common.operators.base.DeltaIterationBase.SolutionSetPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase.WorksetPlaceHolder;
import org.apache.flink.api.common.operators.util.TypeComparable;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.fs.local.LocalFileSystem;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.types.Value;
import org.apache.flink.util.Visitor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
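// Usage sketch (illustrative only; assumes a fully assembled Plan named "plan" is available,
// e.g. produced by a program's sink graph elsewhere):
//
//     ExecutionConfig config = new ExecutionConfig();
//     CollectionExecutor executor = new CollectionExecutor(config);
//     JobExecutionResult result = executor.execute(plan); // runs every sink serially, in memory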
/**
 * Execution utility for serial, local, collection-based executions of Flink programs.
 */
@Internal
public class CollectionExecutor {

    private final Map<Operator<?>, List<?>> intermediateResults;

    private final Map<String, Accumulator<?, ?>> accumulators;

    private final Map<String, Future<Path>> cachedFiles;

    private final Map<String, Value> previousAggregates;

    private final Map<String, Aggregator<?>> aggregators;

    private final ClassLoader classLoader;

    private final ExecutionConfig executionConfig;

    private int iterationSuperstep;

    // --------------------------------------------------------------------------------------------

    public CollectionExecutor(ExecutionConfig executionConfig) {
        this.executionConfig = executionConfig;

        this.intermediateResults = new HashMap<Operator<?>, List<?>>();
        this.accumulators = new HashMap<String, Accumulator<?, ?>>();
        this.previousAggregates = new HashMap<String, Value>();
        this.aggregators = new HashMap<String, Aggregator<?>>();
        this.cachedFiles = new HashMap<String, Future<Path>>();
        this.classLoader = getClass().getClassLoader();
    }

    // --------------------------------------------------------------------------------------------
    //  General execution methods
    // --------------------------------------------------------------------------------------------

    public JobExecutionResult execute(Plan program) throws Exception {
        long startTime = System.currentTimeMillis();

        initCache(program.getCachedFiles());
        Collection<? extends GenericDataSinkBase<?>> sinks = program.getDataSinks();
        for (Operator<?> sink : sinks) {
            execute(sink);
        }

        long endTime = System.currentTimeMillis();
        Map<String, Object> accumulatorResults = AccumulatorHelper.toResultMap(accumulators);
        return new JobExecutionResult(null, endTime - startTime, accumulatorResults);
    }

    private void initCache(Set<Map.Entry<String, DistributedCache.DistributedCacheEntry>> files) {
        for (Map.Entry<String, DistributedCache.DistributedCacheEntry> file : files) {
            Future<Path> doNothing = new CompletedFuture(new Path(file.getValue().filePath));
            cachedFiles.put(file.getKey(), doNothing);
        }
    }

    private List<?> execute(Operator<?> operator) throws Exception {
        return execute(operator, 0);
    }

    private List<?> execute(Operator<?> operator, int superStep) throws Exception {
        List<?> result = this.intermediateResults.get(operator);

        // if it has already been computed, use the cached variant
        if (result != null) {
            return result;
        }

        if (operator instanceof BulkIterationBase) {
            result = executeBulkIteration((BulkIterationBase<?>) operator);
        }
        else if (operator instanceof DeltaIterationBase) {
            result = executeDeltaIteration((DeltaIterationBase<?, ?>) operator);
        }
        else if (operator instanceof SingleInputOperator) {
            result = executeUnaryOperator((SingleInputOperator<?, ?, ?>) operator, superStep);
        }
        else if (operator instanceof DualInputOperator) {
            result = executeBinaryOperator((DualInputOperator<?, ?, ?, ?>) operator, superStep);
        }
        else if (operator instanceof GenericDataSourceBase) {
            result = executeDataSource((GenericDataSourceBase<?, ?>) operator, superStep);
        }
        else if (operator instanceof GenericDataSinkBase) {
            executeDataSink((GenericDataSinkBase<?>) operator, superStep);
            result = Collections.emptyList();
        }
        else {
            throw new RuntimeException("Cannot execute operator " + operator.getClass().getName());
        }

        this.intermediateResults.put(operator, result);

        return result;
    }

    // --------------------------------------------------------------------------------------------
    //  Operator class specific execution methods
    // --------------------------------------------------------------------------------------------
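    /**
     * Executes a data sink on the fully materialized result of its input. A runtime context is
     * created only when the output format is a {@link RichOutputFormat}, so that accumulators,
     * cached files, and (within iterations) the superstep number are available to user code.
     */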
    private <IN> void executeDataSink(GenericDataSinkBase<?> sink, int superStep) throws Exception {
        Operator<?> inputOp = sink.getInput();
        if (inputOp == null) {
            throw new InvalidProgramException("The data sink " + sink.getName() + " has no input.");
        }

        @SuppressWarnings("unchecked")
        List<IN> input = (List<IN>) execute(inputOp);

        @SuppressWarnings("unchecked")
        GenericDataSinkBase<IN> typedSink = (GenericDataSinkBase<IN>) sink;

        // build the runtime context and compute broadcast variables, if necessary
        TaskInfo taskInfo = new TaskInfo(typedSink.getName(), 1, 0, 1, 0);
        RuntimeUDFContext ctx;

        MetricGroup metrics = new UnregisteredMetricsGroup();

        if (RichOutputFormat.class.isAssignableFrom(typedSink.getUserCodeWrapper().getUserCodeClass())) {
            ctx = superStep == 0
                    ? new RuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics)
                    : new IterationRuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics);
        } else {
            ctx = null;
        }

        typedSink.executeOnCollections(input, ctx, executionConfig);
    }

    private <OUT> List<OUT> executeDataSource(GenericDataSourceBase<?, ?> source, int superStep) throws Exception {
        @SuppressWarnings("unchecked")
        GenericDataSourceBase<OUT, ?> typedSource = (GenericDataSourceBase<OUT, ?>) source;

        // build the runtime context and compute broadcast variables, if necessary
        TaskInfo taskInfo = new TaskInfo(typedSource.getName(), 1, 0, 1, 0);

        RuntimeUDFContext ctx;

        MetricGroup metrics = new UnregisteredMetricsGroup();

        if (RichInputFormat.class.isAssignableFrom(typedSource.getUserCodeWrapper().getUserCodeClass())) {
            ctx = superStep == 0
                    ? new RuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics)
                    : new IterationRuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics);
        } else {
            ctx = null;
        }

        return typedSource.executeOnCollections(ctx, executionConfig);
    }

    private <IN, OUT> List<OUT> executeUnaryOperator(SingleInputOperator<?, ?, ?> operator, int superStep) throws Exception {
        Operator<?> inputOp = operator.getInput();
        if (inputOp == null) {
            throw new InvalidProgramException("The unary operation " + operator.getName() + " has no input.");
        }

        @SuppressWarnings("unchecked")
        List<IN> inputData = (List<IN>) execute(inputOp, superStep);

        @SuppressWarnings("unchecked")
        SingleInputOperator<IN, OUT, ?> typedOp = (SingleInputOperator<IN, OUT, ?>) operator;

        // build the runtime context and compute broadcast variables, if necessary
        TaskInfo taskInfo = new TaskInfo(typedOp.getName(), 1, 0, 1, 0);

        RuntimeUDFContext ctx;

        MetricGroup metrics = new UnregisteredMetricsGroup();

        if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
            ctx = superStep == 0
                    ? new RuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics)
                    : new IterationRuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics);

            for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
                List<?> bcData = execute(bcInputs.getValue());
                ctx.setBroadcastVariable(bcInputs.getKey(), bcData);
            }
        } else {
            ctx = null;
        }

        return typedOp.executeOnCollections(inputData, ctx, executionConfig);
    }
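    /**
     * Executes a two-input operator by first materializing both inputs recursively and then
     * delegating to the operator's {@code executeOnCollections} implementation. Broadcast inputs
     * are evaluated and registered only when the user code is a {@link RichFunction}.
     */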
    private <IN1, IN2, OUT> List<OUT> executeBinaryOperator(DualInputOperator<?, ?, ?, ?> operator, int superStep) throws Exception {
        Operator<?> inputOp1 = operator.getFirstInput();
        Operator<?> inputOp2 = operator.getSecondInput();

        if (inputOp1 == null) {
            throw new InvalidProgramException("The binary operation " + operator.getName() + " has no first input.");
        }
        if (inputOp2 == null) {
            throw new InvalidProgramException("The binary operation " + operator.getName() + " has no second input.");
        }

        // compute inputs
        @SuppressWarnings("unchecked")
        List<IN1> inputData1 = (List<IN1>) execute(inputOp1, superStep);
        @SuppressWarnings("unchecked")
        List<IN2> inputData2 = (List<IN2>) execute(inputOp2, superStep);

        @SuppressWarnings("unchecked")
        DualInputOperator<IN1, IN2, OUT, ?> typedOp = (DualInputOperator<IN1, IN2, OUT, ?>) operator;

        // build the runtime context and compute broadcast variables, if necessary
        TaskInfo taskInfo = new TaskInfo(typedOp.getName(), 1, 0, 1, 0);

        RuntimeUDFContext ctx;

        MetricGroup metrics = new UnregisteredMetricsGroup();

        if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
            ctx = superStep == 0
                    ? new RuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics)
                    : new IterationRuntimeUDFContext(taskInfo, classLoader, executionConfig, cachedFiles, accumulators, metrics);

            for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
                List<?> bcData = execute(bcInputs.getValue());
                ctx.setBroadcastVariable(bcInputs.getKey(), bcData);
            }
        } else {
            ctx = null;
        }

        return typedOp.executeOnCollections(inputData1, inputData2, ctx, executionConfig);
    }

    @SuppressWarnings("unchecked")
    private <T> List<T> executeBulkIteration(BulkIterationBase<?> iteration) throws Exception {
        Operator<?> inputOp = iteration.getInput();
        if (inputOp == null) {
            throw new InvalidProgramException("The iteration " + iteration.getName() + " has no input (initial partial solution).");
        }
        if (iteration.getNextPartialSolution() == null) {
            throw new InvalidProgramException("The iteration " + iteration.getName() + " has no next partial solution defined (is not closed).");
        }

        List<T> inputData = (List<T>) execute(inputOp);

        // get the operators that are iterative
        Set<Operator<?>> dynamics = new LinkedHashSet<Operator<?>>();
        DynamicPathCollector dynCollector = new DynamicPathCollector(dynamics);
        iteration.getNextPartialSolution().accept(dynCollector);
        if (iteration.getTerminationCriterion() != null) {
            iteration.getTerminationCriterion().accept(dynCollector);
        }

        // register the aggregators
        for (AggregatorWithName<?> a : iteration.getAggregators().getAllRegisteredAggregators()) {
            aggregators.put(a.getName(), a.getAggregator());
        }

        String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName();
        ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion();

        List<T> currentResult = inputData;

        final int maxIterations = iteration.getMaximumNumberOfIterations();
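        // Superstep loop: every pass re-executes the dynamic path against the current partial
        // solution. The loop ends after the maximum number of iterations, or earlier when the
        // registered convergence criterion (fed by the iteration's aggregators) reports
        // convergence.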
        for (int superstep = 1; superstep <= maxIterations; superstep++) {

            // set the input to the current partial solution
            this.intermediateResults.put(iteration.getPartialSolution(), currentResult);

            // set the superstep number
            iterationSuperstep = superstep;

            // grab the current iteration result
            currentResult = (List<T>) execute(iteration.getNextPartialSolution(), superstep);

            // evaluate the termination criterion
            if (iteration.getTerminationCriterion() != null) {
                execute(iteration.getTerminationCriterion(), superstep);
            }

            // evaluate the aggregator convergence criterion
            if (convCriterion != null && convCriterionAggName != null) {
                Value v = aggregators.get(convCriterionAggName).getAggregate();
                if (convCriterion.isConverged(superstep, v)) {
                    break;
                }
            }

            // clear the dynamic results
            for (Operator<?> o : dynamics) {
                intermediateResults.remove(o);
            }

            // set the previous iteration's aggregates and reset the aggregators
            for (Map.Entry<String, Aggregator<?>> e : aggregators.entrySet()) {
                previousAggregates.put(e.getKey(), e.getValue().getAggregate());
                e.getValue().reset();
            }
        }

        previousAggregates.clear();
        aggregators.clear();

        return currentResult;
    }

    @SuppressWarnings("unchecked")
    private <T> List<T> executeDeltaIteration(DeltaIterationBase<?, ?> iteration) throws Exception {
        Operator<?> solutionInput = iteration.getInitialSolutionSet();
        Operator<?> worksetInput = iteration.getInitialWorkset();
        if (solutionInput == null) {
            throw new InvalidProgramException("The delta iteration " + iteration.getName() + " has no initial solution set.");
        }
        if (worksetInput == null) {
            throw new InvalidProgramException("The delta iteration " + iteration.getName() + " has no initial workset.");
        }
        if (iteration.getSolutionSetDelta() == null) {
            throw new InvalidProgramException("The iteration " + iteration.getName() + " has no solution set delta defined (is not closed).");
        }
        if (iteration.getNextWorkset() == null) {
            throw new InvalidProgramException("The iteration " + iteration.getName() + " has no workset defined (is not closed).");
        }

        List<T> solutionInputData = (List<T>) execute(solutionInput);
        List<T> worksetInputData = (List<T>) execute(worksetInput);

        // get the operators that are iterative
        Set<Operator<?>> dynamics = new LinkedHashSet<Operator<?>>();
        DynamicPathCollector dynCollector = new DynamicPathCollector(dynamics);
        iteration.getSolutionSetDelta().accept(dynCollector);
        iteration.getNextWorkset().accept(dynCollector);

        BinaryOperatorInformation<?, ?, ?> operatorInfo = iteration.getOperatorInfo();
        TypeInformation<?> solutionType = operatorInfo.getFirstInputType();

        int[] keyColumns = iteration.getSolutionSetKeyFields();
        boolean[] inputOrderings = new boolean[keyColumns.length];
        TypeComparator<T> inputComparator = ((CompositeType<T>) solutionType).createComparator(keyColumns, inputOrderings, 0, executionConfig);

        Map<TypeComparable<T>, T> solutionMap = new HashMap<TypeComparable<T>, T>(solutionInputData.size());
        // fill the solution from the initial input
        for (T delta : solutionInputData) {
            TypeComparable<T> wrapper = new TypeComparable<T>(delta, inputComparator);
            solutionMap.put(wrapper, delta);
        }

        List<?> currentWorkset = worksetInputData;

        // register the aggregators
        for (AggregatorWithName<?> a : iteration.getAggregators().getAllRegisteredAggregators()) {
            aggregators.put(a.getName(), a.getAggregator());
        }

        String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName();
        ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion();
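        // Superstep loop: every pass computes a solution set delta (merged into the solution map
        // under the solution set key) and a new workset. The loop ends when the workset becomes
        // empty, when the convergence criterion is met, or after the maximum number of iterations.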
        final int maxIterations = iteration.getMaximumNumberOfIterations();

        for (int superstep = 1; superstep <= maxIterations; superstep++) {

            List<T> currentSolution = new ArrayList<T>(solutionMap.size());
            currentSolution.addAll(solutionMap.values());

            // set the input to the current partial solution
            this.intermediateResults.put(iteration.getSolutionSet(), currentSolution);
            this.intermediateResults.put(iteration.getWorkset(), currentWorkset);

            // set the superstep number
            iterationSuperstep = superstep;

            // grab the current iteration result
            List<T> solutionSetDelta = (List<T>) execute(iteration.getSolutionSetDelta(), superstep);
            this.intermediateResults.put(iteration.getSolutionSetDelta(), solutionSetDelta);

            // update the solution
            for (T delta : solutionSetDelta) {
                TypeComparable<T> wrapper = new TypeComparable<T>(delta, inputComparator);
                solutionMap.put(wrapper, delta);
            }

            currentWorkset = execute(iteration.getNextWorkset(), superstep);

            if (currentWorkset.isEmpty()) {
                break;
            }

            // evaluate the aggregator convergence criterion
            if (convCriterion != null && convCriterionAggName != null) {
                Value v = aggregators.get(convCriterionAggName).getAggregate();
                if (convCriterion.isConverged(superstep, v)) {
                    break;
                }
            }

            // clear the dynamic results
            for (Operator<?> o : dynamics) {
                intermediateResults.remove(o);
            }

            // set the previous iteration's aggregates and reset the aggregators
            for (Map.Entry<String, Aggregator<?>> e : aggregators.entrySet()) {
                previousAggregates.put(e.getKey(), e.getValue().getAggregate());
                e.getValue().reset();
            }
        }

        previousAggregates.clear();
        aggregators.clear();

        List<T> currentSolution = new ArrayList<T>(solutionMap.size());
        currentSolution.addAll(solutionMap.values());
        return currentSolution;
    }

    // --------------------------------------------------------------------------------------------
    // --------------------------------------------------------------------------------------------

    private static final class DynamicPathCollector implements Visitor<Operator<?>> {

        private final Set<Operator<?>> visited = new HashSet<Operator<?>>();

        private final Set<Operator<?>> dynamicPathOperations;

        public DynamicPathCollector(Set<Operator<?>> dynamicPathOperations) {
            this.dynamicPathOperations = dynamicPathOperations;
        }

        @Override
        public boolean preVisit(Operator<?> op) {
            return visited.add(op);
        }

        @Override
        public void postVisit(Operator<?> op) {
            if (op instanceof SingleInputOperator) {
                SingleInputOperator<?, ?, ?> siop = (SingleInputOperator<?, ?, ?>) op;

                if (dynamicPathOperations.contains(siop.getInput())) {
                    dynamicPathOperations.add(op);
                } else {
                    for (Operator<?> o : siop.getBroadcastInputs().values()) {
                        if (dynamicPathOperations.contains(o)) {
                            dynamicPathOperations.add(op);
                            break;
                        }
                    }
                }
            }
            else if (op instanceof DualInputOperator) {
                DualInputOperator<?, ?, ?, ?> siop = (DualInputOperator<?, ?, ?, ?>) op;

                if (dynamicPathOperations.contains(siop.getFirstInput())) {
                    dynamicPathOperations.add(op);
                }
                else if (dynamicPathOperations.contains(siop.getSecondInput())) {
                    dynamicPathOperations.add(op);
                }
                else {
                    for (Operator<?> o : siop.getBroadcastInputs().values()) {
                        if (dynamicPathOperations.contains(o)) {
                            dynamicPathOperations.add(op);
                            break;
                        }
                    }
                }
            }
            else if (op.getClass() == PartialSolutionPlaceHolder.class
                    || op.getClass() == WorksetPlaceHolder.class
                    || op.getClass() == SolutionSetPlaceHolder.class) {
                dynamicPathOperations.add(op);
            }
            else if (op instanceof GenericDataSourceBase) {
                // skip
            }
            else {
                throw new RuntimeException("Cannot handle operator type " + op.getClass().getName());
            }
        }
    }
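    /**
     * Runtime context handed to rich user functions while an iteration is running. It exposes the
     * executor's current superstep number as well as the registered iteration aggregators and the
     * previous superstep's aggregate values.
     */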
    private class IterationRuntimeUDFContext extends RuntimeUDFContext implements IterationRuntimeContext {

        public IterationRuntimeUDFContext(TaskInfo taskInfo, ClassLoader classloader, ExecutionConfig executionConfig,
                Map<String, Future<Path>> cpTasks, Map<String, Accumulator<?, ?>> accumulators, MetricGroup metrics) {
            super(taskInfo, classloader, executionConfig, cpTasks, accumulators, metrics);
        }

        @Override
        public int getSuperstepNumber() {
            return iterationSuperstep;
        }

        @SuppressWarnings("unchecked")
        @Override
        public <T extends Aggregator<?>> T getIterationAggregator(String name) {
            return (T) aggregators.get(name);
        }

        @SuppressWarnings("unchecked")
        @Override
        public <T extends Value> T getPreviousIterationAggregate(String name) {
            return (T) previousAggregates.get(name);
        }
    }

    /**
     * An already-completed future that hands out the resolved local path of a distributed-cache
     * entry. Only local files are supported in collection environments.
     */
    private static final class CompletedFuture implements Future<Path> {

        private final Path result;

        public CompletedFuture(Path entry) {
            try {
                LocalFileSystem fs = (LocalFileSystem) FileSystem.getUnguardedFileSystem(entry.toUri());
                result = entry.isAbsolute() ? new Path(entry.toUri().getPath()) : new Path(fs.getWorkingDirectory(), entry);
            } catch (Exception e) {
                // chain the cause so the underlying file system error is not lost
                throw new RuntimeException("DistributedCache supports only local files for Collection Environments", e);
            }
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return false;
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return true;
        }

        @Override
        public Path get() throws InterruptedException, ExecutionException {
            return result;
        }

        @Override
        public Path get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return get();
        }
    }
}