/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.spark.translation;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

/**
 * The EvaluationContext allows us to define pipeline instructions and translate between
 * {@code PObject<T>}s or {@code PCollection<T>}s and Ts or DStreams/RDDs of Ts.
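 *
 * <p>For illustration, a transform translator would typically drive the context roughly as in
 * the following sketch ({@code appliedTransform}, {@code transform} and the {@code evaluate}
 * step are hypothetical, not part of this class):
 *
 * <pre>{@code
 * context.setCurrentTransform(appliedTransform);
 * Dataset input = context.borrowDataset(transform);  // look up the translated input
 * Dataset output = evaluate(input);                  // hypothetical Spark translation step
 * context.putDataset(transform, output);             // register the result for downstream use
 * }</pre>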
 */
public class EvaluationContext {
  private final JavaSparkContext jsc;
  private JavaStreamingContext jssc;
  private final SparkRuntimeContext runtime;
  private final Pipeline pipeline;
  private final Map<PValue, Dataset> datasets = new LinkedHashMap<>();
  private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>();
  private final Set<Dataset> leaves = new LinkedHashSet<>();
  private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
  private AppliedPTransform<?, ?, ?> currentTransform;
  private final SparkPCollectionView pviews = new SparkPCollectionView();
  private final Map<PCollection, Long> cacheCandidates = new HashMap<>();
  private final PipelineOptions options;

  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options) {
    this.jsc = jsc;
    this.pipeline = pipeline;
    this.options = options;
    this.runtime = new SparkRuntimeContext(pipeline, options);
  }

  public EvaluationContext(
      JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options,
      JavaStreamingContext jssc) {
    this(jsc, pipeline, options);
    this.jssc = jssc;
  }

  public JavaSparkContext getSparkContext() {
    return jsc;
  }

  public JavaStreamingContext getStreamingContext() {
    return jssc;
  }

  public Pipeline getPipeline() {
    return pipeline;
  }

  public PipelineOptions getOptions() {
    return options;
  }

  public SparkRuntimeContext getRuntimeContext() {
    return runtime;
  }

  public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
    this.currentTransform = transform;
  }

  public AppliedPTransform<?, ?, ?> getCurrentTransform() {
    return currentTransform;
  }

  public <T extends PValue> T getInput(PTransform<T, ?> transform) {
    @SuppressWarnings("unchecked")
    T input = (T) Iterables.getOnlyElement(getInputs(transform).values());
    return input;
  }

  public Map<TupleTag<?>, PValue> getInputs(PTransform<?, ?> transform) {
    checkArgument(
        currentTransform != null && currentTransform.getTransform() == transform,
        "can only be called with current transform");
    return currentTransform.getInputs();
  }

  public <T extends PValue> T getOutput(PTransform<?, T> transform) {
    @SuppressWarnings("unchecked")
    T output = (T) Iterables.getOnlyElement(getOutputs(transform).values());
    return output;
  }

  public Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) {
    checkArgument(
        currentTransform != null && currentTransform.getTransform() == transform,
        "can only be called with current transform");
    return currentTransform.getOutputs();
  }

  private boolean shouldCache(PValue pvalue) {
    return pvalue instanceof PCollection
        && cacheCandidates.containsKey(pvalue)
        && cacheCandidates.get(pvalue) > 1;
  }

  public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
    putDataset(getOutput(transform), dataset);
  }

  public void putDataset(PValue pvalue, Dataset dataset) {
    try {
      dataset.setName(pvalue.getName());
    } catch (IllegalStateException e) {
      // name not set, ignore
    }
    if (shouldCache(pvalue)) {
      dataset.cache(storageLevel());
    }
    datasets.put(pvalue, dataset);
    leaves.add(dataset);
  }

  <T> void putBoundedDatasetFromValues(
      PTransform<?, ? extends PValue> transform, Iterable<T> values, Coder<T> coder) {
    PValue output = getOutput(transform);
    if (shouldCache(output)) {
      // eagerly create the RDD, as it will be reused.
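      // The elements are encoded to byte arrays with the coder before parallelize() and
      // decoded again on the workers: Beam values need not be java.io.Serializable, while
      // byte arrays always are, so the coder round-trip makes them distributable by Spark.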
      Iterable<WindowedValue<T>> elems =
          Iterables.transform(values, WindowingHelpers.<T>windowValueFunction());
      WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder =
          WindowedValue.getValueOnlyCoder(coder);
      JavaRDD<WindowedValue<T>> rdd =
          getSparkContext()
              .parallelize(CoderHelpers.toByteArrays(elems, windowCoder))
              .map(CoderHelpers.fromByteFunction(windowCoder));
      putDataset(transform, new BoundedDataset<>(rdd));
    } else {
      // create a BoundedDataset that will create an RDD on demand.
      datasets.put(getOutput(transform), new BoundedDataset<>(values, jsc, coder));
    }
  }

  public Dataset borrowDataset(PTransform<? extends PValue, ?> transform) {
    return borrowDataset(getInput(transform));
  }

  public Dataset borrowDataset(PValue pvalue) {
    Dataset dataset = datasets.get(pvalue);
    leaves.remove(dataset);
    return dataset;
  }

  /**
   * Computes the outputs for all RDDs that are leaves in the DAG and do not have any actions (like
   * saving to a file) registered on them (i.e. they are performed for side effects).
   */
  public void computeOutputs() {
    for (Dataset dataset : leaves) {
      dataset.action(); // force computation.
    }
  }

  /**
   * Retrieve an object of Type T associated with the PValue passed in.
   *
   * @param value PValue to retrieve associated data for.
   * @param <T> Type of object to return.
   * @return Native object.
   */
  @SuppressWarnings("unchecked")
  public <T> T get(PValue value) {
    if (pobjects.containsKey(value)) {
      T result = (T) pobjects.get(value);
      return result;
    }
    if (pcollections.containsKey(value)) {
      JavaRDD<?> rdd = ((BoundedDataset) pcollections.get(value)).getRDD();
      T res = (T) Iterables.getOnlyElement(rdd.collect());
      pobjects.put(value, res);
      return res;
    }
    throw new IllegalStateException("Cannot resolve unknown PObject: " + value);
  }

  /**
   * Return the current views created in the pipeline.
   *
   * @return SparkPCollectionView
   */
  public SparkPCollectionView getPViews() {
    return pviews;
  }

  /**
   * Adds/Replaces a view in the current views created in the pipeline.
   *
   * @param view - Identifier of the view
   * @param value - Actual value of the view
   * @param coder - Coder of the value
   */
  public void putPView(
      PCollectionView<?> view,
      Iterable<WindowedValue<?>> value,
      Coder<Iterable<WindowedValue<?>>> coder) {
    pviews.putPView(view, value, coder);
  }

  /**
   * Get the map of cache candidates held by the evaluation context.
   *
   * @return The current {@link Map} of cache candidates.
   */
  public Map<PCollection, Long> getCacheCandidates() {
    return this.cacheCandidates;
  }

  <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
    @SuppressWarnings("unchecked")
    BoundedDataset<T> boundedDataset = (BoundedDataset<T>) datasets.get(pcollection);
    leaves.remove(boundedDataset);
    return boundedDataset.getValues(pcollection);
  }

  private String storageLevel() {
    return runtime.getPipelineOptions().as(SparkPipelineOptions.class).getStorageLevel();
  }
}