/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.util; import java.io.IOException; import java.io.NotSerializableException; import java.io.ObjectOutputStream; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; import org.apache.beam.sdk.transforms.CombineWithContext.Context; import org.apache.beam.sdk.transforms.display.DisplayData; /** * Static utility methods that create combine function instances. */ public class CombineFnUtil { /** * Returns the partial application of the {@link CombineFnWithContext} to a specific context * to produce a {@link CombineFn}. * * <p>The returned {@link CombineFn} cannot be serialized. */ public static <K, InputT, AccumT, OutputT> CombineFn<InputT, AccumT, OutputT> bindContext( CombineFnWithContext<InputT, AccumT, OutputT> combineFn, StateContext<?> stateContext) { Context context = CombineContextFactory.createFromStateContext(stateContext); return new NonSerializableBoundedCombineFn<>(combineFn, context); } /** * Return a {@link CombineFnWithContext} from the given {@link GlobalCombineFn}. */ public static <InputT, AccumT, OutputT> CombineFnWithContext<InputT, AccumT, OutputT> toFnWithContext( GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) { if (globalCombineFn instanceof CombineFnWithContext) { @SuppressWarnings("unchecked") CombineFnWithContext<InputT, AccumT, OutputT> combineFnWithContext = (CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn; return combineFnWithContext; } else { @SuppressWarnings("unchecked") final CombineFn<InputT, AccumT, OutputT> combineFn = (CombineFn<InputT, AccumT, OutputT>) globalCombineFn; return new CombineFnWithContext<InputT, AccumT, OutputT>() { @Override public AccumT createAccumulator(Context c) { return combineFn.createAccumulator(); } @Override public AccumT addInput(AccumT accumulator, InputT input, Context c) { return combineFn.addInput(accumulator, input); } @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) { return combineFn.mergeAccumulators(accumulators); } @Override public OutputT extractOutput(AccumT accumulator, Context c) { return combineFn.extractOutput(accumulator); } @Override public AccumT compact(AccumT accumulator, Context c) { return combineFn.compact(accumulator); } @Override public OutputT defaultValue() { return combineFn.defaultValue(); } @Override public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException { return combineFn.getAccumulatorCoder(registry, inputCoder); } @Override public Coder<OutputT> getDefaultOutputCoder( CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException { return combineFn.getDefaultOutputCoder(registry, inputCoder); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); combineFn.populateDisplayData(builder); } }; } } private static class NonSerializableBoundedCombineFn<InputT, AccumT, OutputT> extends CombineFn<InputT, AccumT, OutputT> { private final CombineFnWithContext<InputT, AccumT, OutputT> combineFn; private final Context context; private NonSerializableBoundedCombineFn( CombineFnWithContext<InputT, AccumT, OutputT> combineFn, Context context) { this.combineFn = combineFn; this.context = context; } @Override public AccumT createAccumulator() { return combineFn.createAccumulator(context); } @Override public AccumT addInput(AccumT accumulator, InputT value) { return combineFn.addInput(accumulator, value, context); } @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { return combineFn.mergeAccumulators(accumulators, context); } @Override public OutputT extractOutput(AccumT accumulator) { return combineFn.extractOutput(accumulator, context); } @Override public AccumT compact(AccumT accumulator) { return combineFn.compact(accumulator, context); } @Override public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException { return combineFn.getAccumulatorCoder(registry, inputCoder); } @Override public Coder<OutputT> getDefaultOutputCoder( CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException { return combineFn.getDefaultOutputCoder(registry, inputCoder); } @Override public void populateDisplayData(DisplayData.Builder builder) { combineFn.populateDisplayData(builder); } private void writeObject(@SuppressWarnings("unused") ObjectOutputStream out) throws IOException { throw new NotSerializableException( "Cannot serialize the CombineFn resulting from CombineFnUtil.bindContext."); } } }