/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.base.Function; import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.ValueInSingleWindow; import org.joda.time.Instant; /** * A harness for unit-testing a {@link DoFn}. * * <p>For example: * * <pre> {@code * DoFn<InputT, OutputT> fn = ...; * * DoFnTester<InputT, OutputT> fnTester = DoFnTester.of(fn); * * // Set arguments shared across all bundles: * fnTester.setSideInputs(...); // If fn takes side inputs. * fnTester.setOutputTags(...); // If fn writes to more than one output. * * // Process a bundle containing a single input element: * Input testInput = ...; * List<OutputT> testOutputs = fnTester.processBundle(testInput); * Assert.assertThat(testOutputs, Matchers.hasItems(...)); * * // Process a bigger bundle: * Assert.assertThat(fnTester.processBundle(i1, i2, ...), Matchers.hasItems(...)); * } </pre> * * @param <InputT> the type of the {@link DoFn}'s (main) input elements * @param <OutputT> the type of the {@link DoFn}'s (main) output elements */ public class DoFnTester<InputT, OutputT> implements AutoCloseable { /** * Returns a {@code DoFnTester} supporting unit-testing of the given * {@link DoFn}. By default, uses {@link CloningBehavior#CLONE_ONCE}. * * <p>The only supported extra parameter of the {@link DoFn.ProcessElement} method is * {@link BoundedWindow}. */ @SuppressWarnings("unchecked") public static <InputT, OutputT> DoFnTester<InputT, OutputT> of(DoFn<InputT, OutputT> fn) { checkNotNull(fn, "fn can't be null"); return new DoFnTester<>(fn); } /** * Registers the tuple of values of the side input {@link PCollectionView}s to * pass to the {@link DoFn} under test. * * <p>Resets the state of this {@link DoFnTester}. * * <p>If this isn't called, {@code DoFnTester} assumes the * {@link DoFn} takes no side inputs. */ public void setSideInputs(Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) { checkState( state == State.UNINITIALIZED, "Can't add side inputs: DoFnTester is already initialized, in state %s", state); this.sideInputs = sideInputs; } /** * Registers the values of a side input {@link PCollectionView} to pass to the {@link DoFn} * under test. * * <p>The provided value is the final value of the side input in the specified window, not * the value of the input PCollection in that window. * * <p>If this isn't called, {@code DoFnTester} will return the default value for any side input * that is used. */ public <T> void setSideInput(PCollectionView<T> sideInput, BoundedWindow window, T value) { checkState( state == State.UNINITIALIZED, "Can't add side inputs: DoFnTester is already initialized, in state %s", state); Map<BoundedWindow, T> windowValues = (Map<BoundedWindow, T>) sideInputs.get(sideInput); if (windowValues == null) { windowValues = new HashMap<>(); sideInputs.put(sideInput, windowValues); } windowValues.put(window, value); } public PipelineOptions getPipelineOptions() { return options; } /** * When a {@link DoFnTester} should clone the {@link DoFn} under test and how it should manage * the lifecycle of the {@link DoFn}. */ public enum CloningBehavior { /** * Clone the {@link DoFn} and call {@link DoFn.Setup} every time a bundle starts; call {@link * DoFn.Teardown} every time a bundle finishes. */ CLONE_PER_BUNDLE, /** * Clone the {@link DoFn} and call {@link DoFn.Setup} on the first access; call {@link * DoFn.Teardown} only explicitly. */ CLONE_ONCE, /** * Do not clone the {@link DoFn}; call {@link DoFn.Setup} on the first access; call {@link * DoFn.Teardown} only explicitly. */ DO_NOT_CLONE } /** * Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test. */ public void setCloningBehavior(CloningBehavior newValue) { checkState(state == State.UNINITIALIZED, "Wrong state: %s", state); this.cloningBehavior = newValue; } /** * Indicates whether this {@link DoFnTester} will clone the {@link DoFn} under test. */ public CloningBehavior getCloningBehavior() { return cloningBehavior; } /** * A convenience operation that first calls {@link #startBundle}, * then calls {@link #processElement} on each of the input elements, then * calls {@link #finishBundle}, then returns the result of * {@link #takeOutputElements}. */ public List<OutputT> processBundle(Iterable <? extends InputT> inputElements) throws Exception { startBundle(); for (InputT inputElement : inputElements) { processElement(inputElement); } finishBundle(); return takeOutputElements(); } /** * A convenience method for testing {@link DoFn DoFns} with bundles of elements. * Logic proceeds as follows: * * <ol> * <li>Calls {@link #startBundle}.</li> * <li>Calls {@link #processElement} on each of the arguments.</li> * <li>Calls {@link #finishBundle}.</li> * <li>Returns the result of {@link #takeOutputElements}.</li> * </ol> */ @SafeVarargs public final List<OutputT> processBundle(InputT... inputElements) throws Exception { return processBundle(Arrays.asList(inputElements)); } /** * Calls the {@link DoFn.StartBundle} method on the {@link DoFn} under test. * * <p>If needed, first creates a fresh instance of the {@link DoFn} under test and calls * {@link DoFn.Setup}. */ public void startBundle() throws Exception { checkState( state == State.UNINITIALIZED || state == State.BUNDLE_FINISHED, "Wrong state during startBundle: %s", state); if (state == State.UNINITIALIZED) { initializeState(); } try { fnInvoker.invokeStartBundle(new TestStartBundleContext()); } catch (UserCodeException e) { unwrapUserCodeException(e); } state = State.BUNDLE_STARTED; } private static void unwrapUserCodeException(UserCodeException e) throws Exception { if (e.getCause() instanceof Exception) { throw (Exception) e.getCause(); } else if (e.getCause() instanceof Error) { throw (Error) e.getCause(); } else { throw e; } } /** * Calls the {@link DoFn.ProcessElement} method on the {@link DoFn} under test, in a * context where {@link DoFn.ProcessContext#element} returns the * given element and the element is in the global window. * * <p>Will call {@link #startBundle} automatically, if it hasn't * already been called. * * @throws IllegalStateException if the {@code DoFn} under test has already * been finished */ public void processElement(InputT element) throws Exception { processTimestampedElement(TimestampedValue.atMinimumTimestamp(element)); } /** * Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a * context where {@link DoFn.ProcessContext#element} returns the * given element and timestamp and the element is in the global window. * * <p>Will call {@link #startBundle} automatically, if it hasn't * already been called. */ public void processTimestampedElement(TimestampedValue<InputT> element) throws Exception { checkNotNull(element, "Timestamped element cannot be null"); processWindowedElement( element.getValue(), element.getTimestamp(), GlobalWindow.INSTANCE); } /** * Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a * context where {@link DoFn.ProcessContext#element} returns the * given element and timestamp and the element is in the given window. * * <p>Will call {@link #startBundle} automatically, if it hasn't * already been called. */ public void processWindowedElement( InputT element, Instant timestamp, final BoundedWindow window) throws Exception { if (state != State.BUNDLE_STARTED) { startBundle(); } try { final DoFn<InputT, OutputT>.ProcessContext processContext = createProcessContext( ValueInSingleWindow.of(element, timestamp, window, PaneInfo.NO_FIRING)); fnInvoker.invokeProcessElement( new DoFnInvoker.ArgumentProvider<InputT, OutputT>() { @Override public BoundedWindow window() { return window; } @Override public DoFn<InputT, OutputT>.StartBundleContext startBundleContext( DoFn<InputT, OutputT> doFn) { throw new UnsupportedOperationException( "Not expected to access DoFn.StartBundleContext from @ProcessElement"); } @Override public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( DoFn<InputT, OutputT> doFn) { throw new UnsupportedOperationException( "Not expected to access DoFn.FinishBundleContext from @ProcessElement"); } @Override public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) { return processContext; } @Override public OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { throw new UnsupportedOperationException("DoFnTester doesn't support timers yet."); } @Override public RestrictionTracker<?> restrictionTracker() { throw new UnsupportedOperationException( "Not expected to access RestrictionTracker from a regular DoFn in DoFnTester"); } @Override public org.apache.beam.sdk.state.State state(String stateId) { throw new UnsupportedOperationException("DoFnTester doesn't support state yet"); } @Override public Timer timer(String timerId) { throw new UnsupportedOperationException("DoFnTester doesn't support timers yet"); } }); } catch (UserCodeException e) { unwrapUserCodeException(e); } } /** * Calls the {@link DoFn.FinishBundle} method of the {@link DoFn} under test. * * <p>If {@link #setCloningBehavior} was called with {@link CloningBehavior#CLONE_PER_BUNDLE}, * then also calls {@link DoFn.Teardown} on the {@link DoFn}, and it will be cloned and * {@link DoFn.Setup} again when processing the next bundle. * * @throws IllegalStateException if {@link DoFn.FinishBundle} has already been called * for this bundle. */ public void finishBundle() throws Exception { checkState( state == State.BUNDLE_STARTED, "Must be inside bundle to call finishBundle, but was: %s", state); try { fnInvoker.invokeFinishBundle(new TestFinishBundleContext()); } catch (UserCodeException e) { unwrapUserCodeException(e); } if (cloningBehavior == CloningBehavior.CLONE_PER_BUNDLE) { fnInvoker.invokeTeardown(); fn = null; fnInvoker = null; state = State.UNINITIALIZED; } else { state = State.BUNDLE_FINISHED; } } /** * Returns the elements output so far to the main output. Does not * clear them, so subsequent calls will continue to include these * elements. * * @see #takeOutputElements * @see #clearOutputElements * */ public List<OutputT> peekOutputElements() { return Lists.transform( peekOutputElementsWithTimestamp(), new Function<TimestampedValue<OutputT>, OutputT>() { @Override @SuppressWarnings("unchecked") public OutputT apply(TimestampedValue<OutputT> input) { return input.getValue(); } }); } /** * Returns the elements output so far to the main output with associated timestamps. Does not * clear them, so subsequent calls will continue to include these. * elements. * * @see #takeOutputElementsWithTimestamp * @see #clearOutputElements */ @Experimental public List<TimestampedValue<OutputT>> peekOutputElementsWithTimestamp() { // TODO: Should we return an unmodifiable list? return Lists.transform(getImmutableOutput(mainOutputTag), new Function<ValueInSingleWindow<OutputT>, TimestampedValue<OutputT>>() { @Override @SuppressWarnings("unchecked") public TimestampedValue<OutputT> apply(ValueInSingleWindow<OutputT> input) { return TimestampedValue.of(input.getValue(), input.getTimestamp()); } }); } /** * Returns the elements output so far to the main output in the provided window with associated * timestamps. */ public List<TimestampedValue<OutputT>> peekOutputElementsInWindow(BoundedWindow window) { return peekOutputElementsInWindow(mainOutputTag, window); } /** * Returns the elements output so far to the specified output in the provided window with * associated timestamps. */ public List<TimestampedValue<OutputT>> peekOutputElementsInWindow( TupleTag<OutputT> tag, BoundedWindow window) { ImmutableList.Builder<TimestampedValue<OutputT>> valuesBuilder = ImmutableList.builder(); for (ValueInSingleWindow<OutputT> value : getImmutableOutput(tag)) { if (value.getWindow().equals(window)) { valuesBuilder.add(TimestampedValue.of(value.getValue(), value.getTimestamp())); } } return valuesBuilder.build(); } /** * Clears the record of the elements output so far to the main output. * * @see #peekOutputElements */ public void clearOutputElements() { getMutableOutput(mainOutputTag).clear(); } /** * Returns the elements output so far to the main output. * Clears the list so these elements don't appear in future calls. * * @see #peekOutputElements */ public List<OutputT> takeOutputElements() { List<OutputT> resultElems = new ArrayList<>(peekOutputElements()); clearOutputElements(); return resultElems; } /** * Returns the elements output so far to the main output with associated timestamps. * Clears the list so these elements don't appear in future calls. * * @see #peekOutputElementsWithTimestamp * @see #takeOutputElements * @see #clearOutputElements */ @Experimental public List<TimestampedValue<OutputT>> takeOutputElementsWithTimestamp() { List<TimestampedValue<OutputT>> resultElems = new ArrayList<>(peekOutputElementsWithTimestamp()); clearOutputElements(); return resultElems; } /** * Returns the elements output so far to the output with the * given tag. Does not clear them, so subsequent calls will * continue to include these elements. * * @see #takeOutputElements * @see #clearOutputElements */ public <T> List<T> peekOutputElements(TupleTag<T> tag) { // TODO: Should we return an unmodifiable list? return Lists.transform(getImmutableOutput(tag), new Function<ValueInSingleWindow<T>, T>() { @SuppressWarnings("unchecked") @Override public T apply(ValueInSingleWindow<T> input) { return input.getValue(); }}); } /** * Clears the record of the elements output so far to the output with the given tag. * * @see #peekOutputElements */ public <T> void clearOutputElements(TupleTag<T> tag) { getMutableOutput(tag).clear(); } /** * Returns the elements output so far to the output with the given tag. * Clears the list so these elements don't appear in future calls. * * @see #peekOutputElements */ public <T> List<T> takeOutputElements(TupleTag<T> tag) { List<T> resultElems = new ArrayList<>(peekOutputElements(tag)); clearOutputElements(tag); return resultElems; } private <T> List<ValueInSingleWindow<T>> getImmutableOutput(TupleTag<T> tag) { @SuppressWarnings({"unchecked", "rawtypes"}) List<ValueInSingleWindow<T>> elems = (List) outputs.get(tag); return ImmutableList.copyOf( MoreObjects.firstNonNull(elems, Collections.<ValueInSingleWindow<T>>emptyList())); } @SuppressWarnings({"unchecked", "rawtypes"}) public <T> List<ValueInSingleWindow<T>> getMutableOutput(TupleTag<T> tag) { List<ValueInSingleWindow<T>> outputList = (List) outputs.get(tag); if (outputList == null) { outputList = new ArrayList<>(); outputs.put(tag, (List) outputList); } return outputList; } public TupleTag<OutputT> getMainOutputTag() { return mainOutputTag; } private class TestStartBundleContext extends DoFn<InputT, OutputT>.StartBundleContext { private TestStartBundleContext() { fn.super(); } @Override public PipelineOptions getPipelineOptions() { return options; } } private class TestFinishBundleContext extends DoFn<InputT, OutputT>.FinishBundleContext { private TestFinishBundleContext() { fn.super(); } private void throwUnsupportedOutputFromBundleMethods() { throw new UnsupportedOperationException( "DoFnTester doesn't support output from bundle methods"); } @Override public PipelineOptions getPipelineOptions() { return options; } @Override public void output( OutputT output, Instant timestamp, BoundedWindow window) { throwUnsupportedOutputFromBundleMethods(); } @Override public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) { throwUnsupportedOutputFromBundleMethods(); } } public DoFn<InputT, OutputT>.ProcessContext createProcessContext( ValueInSingleWindow<InputT> element) { return new TestProcessContext(element); } private class TestProcessContext extends DoFn<InputT, OutputT>.ProcessContext { private final ValueInSingleWindow<InputT> element; private TestProcessContext(ValueInSingleWindow<InputT> element) { fn.super(); this.element = element; } @Override public InputT element() { return element.getValue(); } @Override public <T> T sideInput(PCollectionView<T> view) { Map<BoundedWindow, ?> viewValues = sideInputs.get(view); if (viewValues != null) { BoundedWindow sideInputWindow = view.getWindowMappingFn() .getSideInputWindow(element.getWindow()); @SuppressWarnings("unchecked") T windowValue = (T) viewValues.get(sideInputWindow); if (windowValue != null) { return windowValue; } } return view.getViewFn().apply(Collections.<WindowedValue<?>>emptyList()); } @Override public Instant timestamp() { return element.getTimestamp(); } @Override public PaneInfo pane() { return element.getPane(); } @Override public void updateWatermark(Instant watermark) { throw new UnsupportedOperationException(); } @Override public PipelineOptions getPipelineOptions() { return options; } @Override public void output(OutputT output) { output(mainOutputTag, output); } @Override public void outputWithTimestamp(OutputT output, Instant timestamp) { outputWithTimestamp(mainOutputTag, output, timestamp); } @Override public <T> void output(TupleTag<T> tag, T output) { outputWithTimestamp(tag, output, element.getTimestamp()); } @Override public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { getMutableOutput(tag) .add(ValueInSingleWindow.of(output, timestamp, element.getWindow(), element.getPane())); } private void throwUnsupportedOutputFromBundleMethods() { throw new UnsupportedOperationException( "DoFnTester doesn't support output from bundle methods"); } } @Override public void close() throws Exception { if (state == State.BUNDLE_STARTED) { finishBundle(); } if (state == State.BUNDLE_FINISHED) { fnInvoker.invokeTeardown(); fn = null; fnInvoker = null; } state = State.TORN_DOWN; } ///////////////////////////////////////////////////////////////////////////// /** The possible states of processing a {@link DoFn}. */ private enum State { UNINITIALIZED, BUNDLE_STARTED, BUNDLE_FINISHED, TORN_DOWN } private final PipelineOptions options = PipelineOptionsFactory.create(); /** The original {@link DoFn} under test. */ private final DoFn<InputT, OutputT> origFn; /** * Whether to clone the original {@link DoFn} or just use it as-is. * * <p>Worker-side {@link DoFn DoFns} may not be serializable, and are not required to be. */ private CloningBehavior cloningBehavior = CloningBehavior.CLONE_ONCE; /** The side input values to provide to the {@link DoFn} under test. */ private Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs = new HashMap<>(); /** The output tags used by the {@link DoFn} under test. */ private TupleTag<OutputT> mainOutputTag = new TupleTag<>(); /** The original DoFn under test, if started. */ private DoFn<InputT, OutputT> fn; private DoFnInvoker<InputT, OutputT> fnInvoker; /** The outputs from the {@link DoFn} under test. */ private Map<TupleTag<?>, List<ValueInSingleWindow<?>>> outputs; /** The state of processing of the {@link DoFn} under test. */ private State state = State.UNINITIALIZED; private DoFnTester(DoFn<InputT, OutputT> origFn) { this.origFn = origFn; DoFnSignature signature = DoFnSignatures.signatureForDoFn(origFn); for (DoFnSignature.Parameter param : signature.processElement().extraParameters()) { param.match( new DoFnSignature.Parameter.Cases.WithDefault<Void>() { @Override public Void dispatch(DoFnSignature.Parameter.ProcessContextParameter p) { // ProcessContext parameter is obviously supported. return null; } @Override public Void dispatch(DoFnSignature.Parameter.WindowParameter p) { // We also support the BoundedWindow parameter. return null; } @Override protected Void dispatchDefault(DoFnSignature.Parameter p) { throw new UnsupportedOperationException( "Parameter " + p + " not supported by DoFnTester"); } }); } } @SuppressWarnings("unchecked") private void initializeState() throws Exception { checkState(state == State.UNINITIALIZED, "Already initialized"); checkState(fn == null, "Uninitialized but fn != null"); if (cloningBehavior.equals(CloningBehavior.DO_NOT_CLONE)) { fn = origFn; } else { fn = (DoFn<InputT, OutputT>) SerializableUtils.deserializeFromByteArray( SerializableUtils.serializeToByteArray(origFn), origFn.toString()); } fnInvoker = DoFnInvokers.invokerFor(fn); fnInvoker.invokeSetup(); outputs = new HashMap<>(); } }