/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.spark.stateful;

import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Table;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.beam.runners.core.GroupAlsoByWindowsAggregators;
import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow;
import org.apache.beam.runners.core.LateDataUtils;
import org.apache.beam.runners.core.OutputWindowedValue;
import org.apache.beam.runners.core.ReduceFnRunner;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.UnsupportedSideInputReader;
import org.apache.beam.runners.core.construction.Triggers;
import org.apache.beam.runners.core.metrics.CounterCell;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.translation.WindowingHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.metrics.MetricName;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
import org.apache.spark.streaming.dstream.PairDStreamFunctions;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Option;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;

/**
 * An implementation of {@link GroupAlsoByWindow} logic for grouping by windows and controlling
 * trigger firings and pane accumulation.
 *
 * <p>This implementation is a composite of Spark transformations revolving around state management
 * using Spark's
 * {@link PairDStreamFunctions#updateStateByKey(Function1, Partitioner, boolean, ClassTag)}
 * to update state with new data and timers.
 *
 * <p>Using updateStateByKey allows scanning through the entire state, visiting not just the
 * updated state (new values for a key) but also checking whether timers are ready to fire.
 * Since updateStateByKey bounds the types of state and output to be the same,
 * a (state, output) tuple is used, filtering out the state (and the output if nothing fired)
 * in the following steps.
 */
public class SparkGroupAlsoByWindowViaWindowSet {
  private static final Logger LOG =
      LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);

  /** A helper class that is essentially a {@link Serializable} {@link AbstractFunction1}. */
  private abstract static class SerializableFunction1<T1, T2>
      extends AbstractFunction1<T1, T2> implements Serializable {
  }

  public static <K, InputT, W extends BoundedWindow>
      JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
          JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream,
          final Coder<K> keyCoder,
          final Coder<WindowedValue<InputT>> wvCoder,
          final WindowingStrategy<?, W> windowingStrategy,
          final SparkRuntimeContext runtimeContext,
          final List<Integer> sourceIds) {

    final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
    final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
    final Coder<? extends BoundedWindow> wCoder =
        ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
    final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder =
        FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder);
    final TimerInternals.TimerDataCoder timerDataCoder =
        TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());

    long checkpointDurationMillis =
        runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class)
            .getCheckpointDurationMillis();

    // we have to switch to the Scala API to avoid Optional in the Java API, see: SPARK-4819.
    // the Scala API is also broader (it gives access to the actual key and the entire iterator).
    // we use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle and be in serialized form
    // for checkpointing.
    // for readability, we add comments with the actual type next to byte[].
    // to shorten line length, we use:
    //---- WV: WindowedValue
    //---- Iterable: Itr
    //---- AccumT: A
    //---- InputT: I
    DStream<Tuple2</*K*/ ByteArray, /*Itr<WV<I>>*/ byte[]>> pairDStream =
        inputDStream
            .transformToPair(
                new Function<
                    JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>,
                    JavaPairRDD<ByteArray, byte[]>>() {
                  // we use mapPartitions with the RDD API because it's the only available API
                  // that allows preserving partitioning.
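                  // NOTE: the trailing boolean argument in the mapPartitions /
                  // mapPartitionsToPair calls below is preservesPartitioning = true; since
                  // these transformations are element-wise and do not change keys, keeping
                  // any existing partitioner can spare a shuffle ahead of updateStateByKey.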
                  @Override
                  public JavaPairRDD<ByteArray, byte[]> call(
                      JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd)
                      throws Exception {
                    return rdd.mapPartitions(
                            TranslationUtils.functionToFlatMapFunction(
                                WindowingHelpers
                                    .<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()),
                            true)
                        .mapPartitionsToPair(
                            TranslationUtils
                                .<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(),
                            true)
                        // move to a bytes representation and use coders for deserialization
                        // because of checkpointing.
                        .mapPartitionsToPair(
                            TranslationUtils.pairFunctionToPairFlatMapFunction(
                                CoderHelpers.toByteFunction(keyCoder, itrWvCoder)),
                            true);
                  }
                })
            .dstream();

    PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions =
        DStream.toPairDStreamFunctions(
            pairDStream,
            JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
            JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(),
            null);
    int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1();
    Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);

    // use updateStateByKey to scan through the state and update elements and timers.
    DStream<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>
        firedStream = pairDStreamFunctions.updateStateByKey(
            new SerializableFunction1<
                scala.collection.Iterator<Tuple3</*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>,
                    Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>,
                scala.collection.Iterator<Tuple2</*K*/ ByteArray,
                    Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>() {

              @Override
              public scala.collection.Iterator<Tuple2</*K*/ ByteArray,
                  Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> apply(
                      final scala.collection.Iterator<Tuple3</*K*/ ByteArray,
                          Seq</*Itr<WV<I>>*/ byte[]>, Option<Tuple2<StateAndTimers,
                              /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>> iter) {
                //--- ACTUAL STATEFUL OPERATION:
                //
                // Input Iterator: the partition (~bundle) of a cogrouping of the input
                // and the previous state (if it exists).
                //
                // Output Iterator: the output key, and the updated state.
                //
                // possible input scenarios for (K, Seq, Option<S>):
                // (1) Option<S>.isEmpty: new data with no previous state.
                // (2) Seq.isEmpty: no new data, but evaluating previous state
                //     (timer-like behaviour).
                // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state.

                final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn =
                    SystemReduceFn.buffering(
                        ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
                final OutputWindowedValueHolder<K, InputT> outputHolder =
                    new OutputWindowedValueHolder<>();
                // use in-memory Aggregators since Spark Accumulators are not resilient
                // in stateful operators; their values are reported once done with this
                // partition.
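                // (the MetricsContainerImpl below serves only as a factory for in-memory
                // CounterCells; it is not wired into Spark's metrics, so the dropped-element
                // counts are surfaced through the log messages further down.)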
                final MetricsContainerImpl cellProvider =
                    new MetricsContainerImpl("cellProvider");
                final CounterCell droppedDueToClosedWindow = cellProvider.getCounter(
                    MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class,
                        GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
                final CounterCell droppedDueToLateness = cellProvider.getCounter(
                    MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class,
                        GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));

                AbstractIterator<Tuple2</*K*/ ByteArray,
                    Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> outIter =
                        new AbstractIterator<Tuple2</*K*/ ByteArray,
                            Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>() {

                          @Override
                          protected Tuple2</*K*/ ByteArray,
                              Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>
                                  computeNext() {
                            // the input iterator is a Spark partition (~bundle), containing
                            // keys and their (possibly) previous state and (possibly) new data.
                            while (iter.hasNext()) {
                              // for each element in the partition:
                              Tuple3<ByteArray, Seq<byte[]>,
                                  Option<Tuple2<StateAndTimers, List<byte[]>>>> next =
                                      iter.next();

                              ByteArray encodedKey = next._1();
                              K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);

                              Seq<byte[]> seq = next._2();

                              Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt =
                                  next._3();

                              SparkStateInternals<K> stateInternals;
                              SparkTimerInternals timerInternals =
                                  SparkTimerInternals.forStreamFromSources(
                                      sourceIds, GlobalWatermarkHolder.get());
                              // get state(internals) per key.
                              if (prevStateAndTimersOpt.isEmpty()) {
                                // no previous state.
                                stateInternals = SparkStateInternals.forKey(key);
                              } else {
                                // with pre-existing state.
                                StateAndTimers prevStateAndTimers =
                                    prevStateAndTimersOpt.get()._1();
                                stateInternals = SparkStateInternals.forKeyAndState(
                                    key, prevStateAndTimers.getState());
                                Collection<byte[]> serTimers = prevStateAndTimers.getTimers();
                                timerInternals.addTimers(
                                    SparkTimerInternals.deserializeTimers(
                                        serTimers, timerDataCoder));
                              }

                              ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
                                  new ReduceFnRunner<>(
                                      key,
                                      windowingStrategy,
                                      ExecutableTriggerStateMachine.create(
                                          TriggerStateMachines.stateMachineForTrigger(
                                              Triggers.toProto(windowingStrategy.getTrigger()))),
                                      stateInternals,
                                      timerInternals,
                                      outputHolder,
                                      new UnsupportedSideInputReader("GroupAlsoByWindow"),
                                      reduceFn,
                                      runtimeContext.getPipelineOptions());

                              outputHolder.clear(); // clear before potential use.
                              if (!seq.isEmpty()) {
                                // new input for the key.
                                try {
                                  Iterable<WindowedValue<InputT>> elementsIterable =
                                      CoderHelpers.fromByteArray(seq.head(), itrWvCoder);
                                  Iterable<WindowedValue<InputT>> validElements =
                                      LateDataUtils.dropExpiredWindows(
                                          key,
                                          elementsIterable,
                                          timerInternals,
                                          windowingStrategy,
                                          droppedDueToLateness);
                                  reduceFnRunner.processElements(validElements);
                                } catch (Exception e) {
                                  throw new RuntimeException(
                                      "Failed to process element with ReduceFnRunner", e);
                                }
                              } else if (stateInternals.getState().isEmpty()) {
                                // no input and no state -> GC evict now.
                                continue;
                              }
                              try {
                                // advance the watermark to the high watermark so that
                                // ready timers can fire.
                                timerInternals.advanceWatermark();
                                // call on timers that are ready.
                                reduceFnRunner.onTimers(
                                    timerInternals.getTimersReadyToProcess());
                              } catch (Exception e) {
                                throw new RuntimeException(
                                    "Failed to process ReduceFnRunner onTimer.", e);
                              }
                              // this is mostly symbolic since the actual persist is done by
                              // emitting output.
                              reduceFnRunner.persist();
                              // obtain output, if fired.
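                              // a kept key yields a (key, (state, output)) tuple: the state
                              // half carries the key into the next micro-batch, the output
                              // half holds any fired panes; a key with neither is skipped
                              // below, which evicts it from Spark's state.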
                              List<WindowedValue<KV<K, Iterable<InputT>>>> outputs =
                                  outputHolder.get();
                              if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
                                StateAndTimers updated = new StateAndTimers(
                                    stateInternals.getState(),
                                    SparkTimerInternals.serializeTimers(
                                        timerInternals.getTimers(), timerDataCoder));
                                // persist Spark's state by outputting.
                                List<byte[]> serOutput =
                                    CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
                                return new Tuple2<>(
                                    encodedKey, new Tuple2<>(updated, serOutput));
                              }
                              // an empty state with no output can be evicted completely
                              // - do nothing.
                            }
                            return endOfData();
                          }
                        };

                // log if there's something to log.
                long lateDropped = droppedDueToLateness.getCumulative();
                if (lateDropped > 0) {
                  LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped));
                  droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
                }
                long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
                if (closedWindowDropped > 0) {
                  LOG.info(String.format(
                      "Dropped %d elements due to closed window.", closedWindowDropped));
                  droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
                }

                return scala.collection.JavaConversions.asScalaIterator(outIter);
              }
            },
            partitioner,
            true,
            JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());

    if (checkpointDurationMillis > 0) {
      firedStream.checkpoint(new Duration(checkpointDurationMillis));
    }

    // go back to the Java API now.
    JavaPairDStream</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>
        javaFiredStream = JavaPairDStream.fromPairDStream(
            firedStream,
            JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
            JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());

    // filter state-only output (nothing to fire) and remove the state from the output.
    return javaFiredStream.filter(
        new Function<Tuple2</*K*/ ByteArray,
            Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, Boolean>() {
          @Override
          public Boolean call(
              Tuple2</*K*/ ByteArray,
                  Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2)
              throws Exception {
            // keep only keys that actually fired output.
            return !t2._2()._2().isEmpty();
          }
        })
        .flatMap(
            new FlatMapFunction<Tuple2</*K*/ ByteArray,
                Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>,
                WindowedValue<KV<K, Iterable<InputT>>>>() {
              @Override
              public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
                  Tuple2</*K*/ ByteArray,
                      Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2)
                  throws Exception {
                // drop the state since it is already persisted at this point,
                // and return the fired output in deserialized form.
                return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder);
              }
            });
  }
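  /**
   * The per-key payload carried between micro-batches: the key's state (a namespace to
   * state-tag to coded-value {@link Table}) together with its serialized timers, both kept
   * in coded form so they stay serializable for shuffling and checkpointing.
   */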
  private static class StateAndTimers {
    // Serializable state for internals (namespace to state tag to coded value).
    private final Table<String, String, byte[]> state;
    private final Collection<byte[]> serTimers;

    private StateAndTimers(Table<String, String, byte[]> state, Collection<byte[]> timers) {
      this.state = state;
      this.serTimers = timers;
    }

    public Table<String, String, byte[]> getState() {
      return state;
    }

    public Collection<byte[]> getTimers() {
      return serTimers;
    }
  }

  /** Accumulates the output of {@link ReduceFnRunner} firings for a single key. */
  private static class OutputWindowedValueHolder<K, V>
      implements OutputWindowedValue<KV<K, Iterable<V>>> {
    private List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new ArrayList<>();

    @Override
    public void outputWindowedValue(
        KV<K, Iterable<V>> output,
        Instant timestamp,
        Collection<? extends BoundedWindow> windows,
        PaneInfo pane) {
      windowedValues.add(WindowedValue.of(output, timestamp, windows, pane));
    }

    private List<WindowedValue<KV<K, Iterable<V>>>> get() {
      return windowedValues;
    }

    private void clear() {
      windowedValues.clear();
    }

    @Override
    public <AdditionalOutputT> void outputWindowedValue(
        TupleTag<AdditionalOutputT> tag,
        AdditionalOutputT output,
        Instant timestamp,
        Collection<? extends BoundedWindow> windows,
        PaneInfo pane) {
      throw new UnsupportedOperationException(
          "Tagged outputs are not allowed in GroupAlsoByWindow.");
    }
  }
}