/*
 * Copyright © 2015-2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.TxRunnable;
import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.dataset.lib.KeyValue;
import co.cask.cdap.api.metrics.Metrics;
import co.cask.cdap.api.plugin.PluginContext;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.api.spark.JavaSparkMain;
import co.cask.cdap.etl.api.Transform;
import co.cask.cdap.etl.api.batch.BatchAggregator;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.batch.BatchPhaseSpec;
import co.cask.cdap.etl.batch.PipelinePluginInstantiator;
import co.cask.cdap.etl.batch.TransformExecutorFactory;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.common.SetMultimapCodec;
import co.cask.cdap.etl.common.TransformExecutor;
import co.cask.cdap.etl.common.TransformResponse;
import co.cask.cdap.etl.planner.StageInfo;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.SetMultimap;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;

/**
 * Spark program to run an ETL pipeline.
 */
public class ETLSparkProgram implements JavaSparkMain, TxRunnable {

  private static final Gson GSON = new GsonBuilder()
    .registerTypeAdapter(SetMultimap.class, new SetMultimapCodec<>()).create();

  private transient JavaSparkContext jsc;
  private transient JavaSparkExecutionContext sec;

  @Override
  public void run(final JavaSparkExecutionContext sec) throws Exception {
    this.jsc = new JavaSparkContext();
    this.sec = sec;

    // Execute the whole pipeline in one long transaction. This is because the Spark execution
    // currently shares the same contract and API as the MapReduce one.
    // The API needs to expose DatasetContext, hence it needs to be executed inside a transaction.
    sec.execute(this);
  }
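
  /**
   * Runs the pipeline body inside the transaction started by {@link #run(JavaSparkExecutionContext)}:
   * deserializes the source and sink factories from the localized "ETLSpark.config" file, creates the
   * source RDD, applies the transform phase, and then writes each sink's records out.
   */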
  @Override
  public void run(DatasetContext datasetContext) throws Exception {
    BatchPhaseSpec phaseSpec = GSON.fromJson(sec.getSpecification().getProperty(Constants.PIPELINEID),
                                             BatchPhaseSpec.class);

    Set<StageInfo> aggregators = phaseSpec.getPhase().getStagesOfType(BatchAggregator.PLUGIN_TYPE);
    String aggregatorName = null;
    if (!aggregators.isEmpty()) {
      aggregatorName = aggregators.iterator().next().getName();
    }

    SparkBatchSourceFactory sourceFactory;
    SparkBatchSinkFactory sinkFactory;
    int numPartitions;
    try (InputStream is = new FileInputStream(sec.getLocalizationContext().getLocalFile("ETLSpark.config"))) {
      sourceFactory = SparkBatchSourceFactory.deserialize(is);
      sinkFactory = SparkBatchSinkFactory.deserialize(is);
      numPartitions = new DataInputStream(is).readInt();
    }

    JavaPairRDD<Object, Object> rdd = sourceFactory.createRDD(sec, jsc, Object.class, Object.class);
    JavaPairRDD<String, Object> resultRDD = doTransform(sec, jsc, datasetContext, phaseSpec, rdd,
                                                        aggregatorName, numPartitions);

    Set<StageInfo> stagesOfTypeSparkSink = phaseSpec.getPhase().getStagesOfType(SparkSink.PLUGIN_TYPE);
    Set<String> namesOfTypeSparkSink = new HashSet<>();
    for (StageInfo stageInfo : stagesOfTypeSparkSink) {
      namesOfTypeSparkSink.add(stageInfo.getName());
    }

    for (final String sinkName : phaseSpec.getPhase().getSinks()) {
      // keep only the records destined for this sink
      JavaPairRDD<String, Object> filteredResultRDD = resultRDD.filter(
        new Function<Tuple2<String, Object>, Boolean>() {
          @Override
          public Boolean call(Tuple2<String, Object> v1) throws Exception {
            return v1._1().equals(sinkName);
          }
        });

      if (namesOfTypeSparkSink.contains(sinkName)) {
        // a SparkSink operates on the RDD directly
        SparkSink sparkSink = sec.getPluginContext().newPluginInstance(sinkName);
        sparkSink.run(new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sinkName),
                      filteredResultRDD.values());
      } else {
        // a batch sink expects key-value pairs, so unwrap each KeyValue record into a tuple
        JavaPairRDD<Object, Object> sinkRDD = filteredResultRDD.flatMapToPair(
          new PairFlatMapFunction<Tuple2<String, Object>, Object, Object>() {
            @Override
            public Iterable<Tuple2<Object, Object>> call(Tuple2<String, Object> input) throws Exception {
              List<Tuple2<Object, Object>> result = new ArrayList<>();
              KeyValue<Object, Object> keyValue = (KeyValue<Object, Object>) input._2();
              result.add(new Tuple2<>(keyValue.getKey(), keyValue.getValue()));
              return result;
            }
          });
        sinkFactory.writeFromRDD(sinkRDD, sec, sinkName, Object.class, Object.class);
      }
    }
  }
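
  /**
   * Applies the transform stages of the phase to the source RDD. If the phase contains no SparkCompute
   * stage, this is either a single flatMap (no aggregator) or a flatMap / groupByKey / flatMap sequence
   * (with an aggregator). Otherwise, the phase is split into the part before and the part after the
   * single SparkCompute stage.
   *
   * @return an RDD of (sink name, record) pairs, where each record is destined for the named sink
   */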
  private JavaPairRDD<String, Object> doTransform(JavaSparkExecutionContext sec, JavaSparkContext jsc,
                                                  DatasetContext datasetContext, BatchPhaseSpec phaseSpec,
                                                  JavaPairRDD<Object, Object> input,
                                                  @Nullable String aggregatorName,
                                                  int numPartitions) throws Exception {
    Set<StageInfo> sparkComputes = phaseSpec.getPhase().getStagesOfType(SparkCompute.PLUGIN_TYPE);
    if (sparkComputes.isEmpty()) {
      // if this is not a phase with a SparkCompute, do the regular transform logic
      if (aggregatorName != null) {
        JavaPairRDD<Object, Object> preGroupRDD = input.flatMapToPair(new PreGroupFunction(sec, aggregatorName));
        // a negative numPartitions means the default Spark partitioning should be used
        JavaPairRDD<Object, Iterable<Object>> groupedRDD = numPartitions < 0 ?
          preGroupRDD.groupByKey() : preGroupRDD.groupByKey(numPartitions);
        return groupedRDD.flatMapToPair(new MapFunction<Iterable<Object>>(sec, null, aggregatorName, false)).cache();
      } else {
        return input.flatMapToPair(new MapFunction<>(sec, null, null, false)).cache();
      }
    }

    // otherwise, special casing for the SparkCompute type:
    // there should be no plugins of type Transform, because of how Smart Workflow breaks up the phases
    Set<StageInfo> stagesOfTypeTransform = phaseSpec.getPhase().getStagesOfType(Transform.PLUGIN_TYPE);
    Preconditions.checkArgument(stagesOfTypeTransform.isEmpty(),
                                "Found non-empty set of transform plugins when expecting none: %s",
                                stagesOfTypeTransform);

    // Smart Workflow should guarantee that only 1 SparkCompute exists per phase. This can be improved
    // in the future for efficiency.
    Preconditions.checkArgument(sparkComputes.size() == 1, "Expected only 1 SparkCompute: %s", sparkComputes);
    String sparkComputeName = Iterables.getOnlyElement(sparkComputes).getName();

    Set<String> sourceStages = phaseSpec.getPhase().getSources();
    Preconditions.checkArgument(sourceStages.size() == 1, "Expected only 1 source stage: %s", sourceStages);
    String sourceStageName = Iterables.getOnlyElement(sourceStages);

    Set<String> sourceNextStages = phaseSpec.getPhase().getStageOutputs(sourceStageName);
    Preconditions.checkArgument(sourceNextStages.size() == 1,
                                "Expected only 1 stage after source stage: %s", sourceNextStages);
    Preconditions.checkArgument(sparkComputeName.equals(Iterables.getOnlyElement(sourceNextStages)),
                                "Expected the single stage after the source stage to be the spark compute: %s",
                                sparkComputeName);

    // phase from the source to the SparkCompute
    PipelinePhase sourcePhase = phaseSpec.getPhase().subsetTo(ImmutableSet.of(sparkComputeName));
    String sourcePipelineStr = GSON.toJson(new BatchPhaseSpec(phaseSpec.getPhaseName(), sourcePhase,
                                                              phaseSpec.getResources(),
                                                              phaseSpec.isStageLoggingEnabled(),
                                                              phaseSpec.getConnectorDatasets()));
    JavaPairRDD<String, Object> sourceTransformed =
      input.flatMapToPair(new MapFunction<>(sec, sourcePipelineStr, null, true)).cache();

    SparkCompute sparkCompute =
      new PipelinePluginInstantiator(sec.getPluginContext(), phaseSpec).newPluginInstance(sparkComputeName);
    JavaRDD<Object> sparkComputed =
      sparkCompute.transform(new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sparkComputeName),
                             sourceTransformed.values());

    // phase from the SparkCompute to the sink(s)
    PipelinePhase sinkPhase = phaseSpec.getPhase().subsetFrom(ImmutableSet.of(sparkComputeName));
    String sinkPipelineStr = GSON.toJson(new BatchPhaseSpec(phaseSpec.getPhaseName(), sinkPhase,
                                                            phaseSpec.getResources(),
                                                            phaseSpec.isStageLoggingEnabled(),
                                                            phaseSpec.getConnectorDatasets()));
    return sparkComputed.flatMapToPair(new SingleTypeRDDMapFunction<Object, Object>(sec, sinkPipelineStr)).cache();
  }
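
  // Illustration (hypothetical stage names): for a phase "source -> compute -> sink1", doTransform() runs
  // "source -> compute" as a MapFunction over the source RDD, hands the values to SparkCompute.transform(),
  // then runs "compute -> sink1" as a SingleTypeRDDMapFunction over the computed RDD, yielding pairs like
  // ("sink1", record) that run(DatasetContext) filters per sink before writing.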
  /**
   * Base function that knows how to set up a transform executor and run it.
   * Subclasses are responsible for massaging the output of the transform executor into the expected output,
   * and for configuring the transform executor with the right part of the pipeline.
   *
   * @param <IN> type of the input
   * @param <EXECUTOR_IN> type of the executor input
   * @param <KEY_OUT> type of the output key
   * @param <VAL_OUT> type of the output value
   */
  public abstract static class TransformExecutorFunction<IN, EXECUTOR_IN, KEY_OUT, VAL_OUT>
    implements PairFlatMapFunction<IN, KEY_OUT, VAL_OUT> {

    protected final PluginContext pluginContext;
    protected final Metrics metrics;
    protected final long logicalStartTime;
    protected final Map<String, String> runtimeArgs;
    protected final String pipelineStr;
    private transient TransformExecutor<EXECUTOR_IN> transformExecutor;

    public TransformExecutorFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr) {
      this.pluginContext = sec.getPluginContext();
      this.metrics = sec.getMetrics();
      this.logicalStartTime = sec.getLogicalStartTime();
      this.runtimeArgs = sec.getRuntimeArguments();
      this.pipelineStr = pipelineStr != null
        ? pipelineStr
        : sec.getSpecification().getProperty(Constants.PIPELINEID);
    }

    @Override
    public Iterable<Tuple2<KEY_OUT, VAL_OUT>> call(IN input) throws Exception {
      if (transformExecutor == null) {
        // TODO: There is no way to call the destroy() method on a Transform.
        // In fact, transforms can be structured so that they don't need destroy().
        // All current usage of destroy() in transforms is really for Source/Sink, which is better done in
        // prepareRun and onRunFinish, which happen outside of the job execution (true for both Spark and
        // MapReduce).
        BatchPhaseSpec phaseSpec = GSON.fromJson(pipelineStr, BatchPhaseSpec.class);
        PipelinePluginInstantiator pluginInstantiator = new PipelinePluginInstantiator(pluginContext, phaseSpec);
        transformExecutor = initialize(phaseSpec, pluginInstantiator);
      }
      TransformResponse response = transformExecutor.runOneIteration(computeInputForExecutor(input));
      Iterable<Tuple2<KEY_OUT, VAL_OUT>> output = getOutput(response);
      transformExecutor.resetEmitter();
      return output;
    }

    protected abstract Iterable<Tuple2<KEY_OUT, VAL_OUT>> getOutput(TransformResponse transformResponse);

    protected abstract TransformExecutor<EXECUTOR_IN> initialize(
      BatchPhaseSpec phaseSpec, PipelinePluginInstantiator pluginInstantiator) throws Exception;

    protected abstract EXECUTOR_IN computeInputForExecutor(IN input);
  }
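
  // Note: instances of TransformExecutorFunction subclasses are serialized into Spark tasks; the
  // TransformExecutor field is transient and is therefore re-created lazily on each executor the first
  // time call() is invoked there.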
  /**
   * Performs all transforms before an aggregator plugin. Outputs tuples whose keys are the group keys and
   * whose values are the group values produced by calling the aggregator's groupBy method.
   */
  public static final class PreGroupFunction
    extends TransformExecutorFunction<Tuple2<Object, Object>, KeyValue<Object, Object>, Object, Object> {

    private final String aggregatorName;

    public PreGroupFunction(JavaSparkExecutionContext sec, @Nullable String aggregatorName) {
      super(sec, null);
      this.aggregatorName = aggregatorName;
    }

    @Override
    protected Iterable<Tuple2<Object, Object>> getOutput(TransformResponse transformResponse) {
      List<Tuple2<Object, Object>> result = new ArrayList<>();
      for (Map.Entry<String, Collection<Object>> transformedEntry : transformResponse.getSinksResults().entrySet()) {
        for (Object output : transformedEntry.getValue()) {
          result.add((Tuple2<Object, Object>) output);
        }
      }
      return result;
    }

    @Override
    protected TransformExecutor<KeyValue<Object, Object>> initialize(
      BatchPhaseSpec phaseSpec, PipelinePluginInstantiator pluginInstantiator) throws Exception {
      TransformExecutorFactory<KeyValue<Object, Object>> transformExecutorFactory =
        new SparkTransformExecutorFactory<>(pluginContext, pluginInstantiator, metrics,
                                            logicalStartTime, runtimeArgs, true);
      // only run the part of the phase up to (and including) the aggregator
      PipelinePhase pipelinePhase = phaseSpec.getPhase().subsetTo(ImmutableSet.of(aggregatorName));
      return transformExecutorFactory.create(pipelinePhase);
    }

    @Override
    protected KeyValue<Object, Object> computeInputForExecutor(Tuple2<Object, Object> input) {
      return new KeyValue<>(input._1(), input._2());
    }
  }

  /**
   * Performs all transforms that happen after an aggregator, or all transforms if there is no aggregator
   * at all. Outputs tuples whose first item is the name of the sink being written to, and whose second
   * item is the key-value pair that should be written to that sink.
   *
   * @param <T> type of the map output value
   */
  public static final class MapFunction<T>
    extends SingleTypeRDDMapFunction<Tuple2<Object, T>, KeyValue<Object, T>> {

    @Nullable
    private final String aggregatorName;
    private final boolean isBeforeBreak;

    public MapFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr,
                       @Nullable String aggregatorName, boolean isBeforeBreak) {
      super(sec, pipelineStr);
      this.aggregatorName = aggregatorName;
      this.isBeforeBreak = isBeforeBreak;
    }

    @Override
    protected TransformExecutor<KeyValue<Object, T>> initialize(
      BatchPhaseSpec phaseSpec, PipelinePluginInstantiator pluginInstantiator) throws Exception {
      TransformExecutorFactory<KeyValue<Object, T>> transformExecutorFactory =
        new SparkTransformExecutorFactory<>(pluginContext, pluginInstantiator, metrics,
                                            logicalStartTime, runtimeArgs, isBeforeBreak);
      PipelinePhase pipelinePhase = phaseSpec.getPhase();
      if (aggregatorName != null) {
        // when running after an aggregator, only execute the part of the phase from the aggregator onwards
        pipelinePhase = pipelinePhase.subsetFrom(ImmutableSet.of(aggregatorName));
      }
      return transformExecutorFactory.create(pipelinePhase);
    }

    @Override
    protected KeyValue<Object, T> computeInputForExecutor(Tuple2<Object, T> input) {
      return new KeyValue<>(input._1(), input._2());
    }
  }
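
  // Example flow for an aggregator phase (hypothetical aggregator "agg"):
  //   PreGroupFunction runs the sub-phase up to "agg" and emits (groupKey, record) tuples from groupBy(),
  //   groupByKey() then gathers them into (groupKey, Iterable<record>) pairs, and
  //   MapFunction<Iterable<Object>> runs the sub-phase from "agg" onwards, producing (sinkName, KeyValue)
  //   output.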
  /**
   * Used for the transforms after a SparkCompute, whereas MapFunction only operates on pair RDDs
   * (JavaPairRDD). This class does not translate from a Tuple to a KeyValue, but sends each RDD element
   * directly to the TransformExecutor, which allows it to operate on a JavaRDD of a single type.
   * It handles no aggregation functionality, because it should not be used in a phase with aggregations.
   *
   * @param <IN> type of the input
   * @param <EXECUTOR_IN> type of the input to the executor
   */
  public static class SingleTypeRDDMapFunction<IN, EXECUTOR_IN>
    extends TransformExecutorFunction<IN, EXECUTOR_IN, String, Object> {

    public SingleTypeRDDMapFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr) {
      super(sec, pipelineStr);
    }

    @Override
    protected Iterable<Tuple2<String, Object>> getOutput(TransformResponse transformResponse) {
      List<Tuple2<String, Object>> result = new ArrayList<>();
      for (Map.Entry<String, Collection<Object>> transformedEntry : transformResponse.getSinksResults().entrySet()) {
        String sinkName = transformedEntry.getKey();
        for (Object outputRecord : transformedEntry.getValue()) {
          result.add(new Tuple2<>(sinkName, outputRecord));
        }
      }
      return result;
    }

    @Override
    protected TransformExecutor<EXECUTOR_IN> initialize(
      BatchPhaseSpec phaseSpec, PipelinePluginInstantiator pluginInstantiator) throws Exception {
      TransformExecutorFactory<EXECUTOR_IN> transformExecutorFactory =
        new SparkTransformExecutorFactory<>(pluginContext, pluginInstantiator, metrics,
                                            logicalStartTime, runtimeArgs, false);
      PipelinePhase pipelinePhase = phaseSpec.getPhase();
      return transformExecutorFactory.create(pipelinePhase);
    }

    @SuppressWarnings("unchecked")
    @Override
    protected EXECUTOR_IN computeInputForExecutor(IN input) {
      // by default, IN is the same type as EXECUTOR_IN
      return (EXECUTOR_IN) input;
    }
  }
}