/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import com.google.common.base.Optional;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.PTransforms.TransformPayloadTranslator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.ParDoPayload;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Parameter.Type;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
/**
* Utilities for interacting with {@link ParDo} instances and {@link ParDoPayload} protos.
*/
public class ParDos {
/**
* The URN for a {@link ParDoPayload}.
*/
public static final String PAR_DO_PAYLOAD_URN = "urn:beam:pardo:v1";
/**
* The URN for an unknown Java {@link DoFn}.
*/
public static final String CUSTOM_JAVA_DO_FN_URN = "urn:beam:dofn:javasdk:0.1";
/**
* The URN for an unknown Java {@link ViewFn}.
*/
public static final String CUSTOM_JAVA_VIEW_FN_URN = "urn:beam:viewfn:javasdk:0.1";
/**
* The URN for an unknown Java {@link WindowMappingFn}.
*/
public static final String CUSTOM_JAVA_WINDOW_MAPPING_FN_URN =
"urn:beam:windowmappingfn:javasdk:0.1";
/**
* A {@link TransformPayloadTranslator} for {@link ParDo}.
*/
public static class ParDoPayloadTranslator
implements PTransforms.TransformPayloadTranslator<ParDo.MultiOutput<?, ?>> {
public static TransformPayloadTranslator create() {
return new ParDoPayloadTranslator();
}
private ParDoPayloadTranslator() {}
@Override
public FunctionSpec translate(
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components) {
ParDoPayload payload = toProto(transform.getTransform(), components);
return RunnerApi.FunctionSpec.newBuilder()
.setUrn(PAR_DO_PAYLOAD_URN)
.setParameter(Any.pack(payload))
.build();
}
/**
* Registers {@link ParDoPayloadTranslator}.
*/
@AutoService(TransformPayloadTranslatorRegistrar.class)
public static class Registrar implements TransformPayloadTranslatorRegistrar {
@Override
public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
getTransformPayloadTranslators() {
return Collections.singletonMap(ParDo.MultiOutput.class, new ParDoPayloadTranslator());
}
}
}
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) {
DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass());
Map<String, StateDeclaration> states = signature.stateDeclarations();
Map<String, TimerDeclaration> timers = signature.timerDeclarations();
List<Parameter> parameters = signature.processElement().extraParameters();
ParDoPayload.Builder builder = ParDoPayload.newBuilder();
builder.setDoFn(toProto(parDo.getFn(), parDo.getMainOutputTag()));
for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
builder.putSideInputs(sideInput.getTagInternal().getId(), toProto(sideInput));
}
for (Parameter parameter : parameters) {
Optional<RunnerApi.Parameter> protoParameter = toProto(parameter);
if (protoParameter.isPresent()) {
builder.addParameters(protoParameter.get());
}
}
for (Map.Entry<String, StateDeclaration> state : states.entrySet()) {
StateSpec spec = toProto(state.getValue());
builder.putStateSpecs(state.getKey(), spec);
}
for (Map.Entry<String, TimerDeclaration> timer : timers.entrySet()) {
TimerSpec spec = toProto(timer.getValue());
builder.putTimerSpecs(timer.getKey(), spec);
}
return builder.build();
}
public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
}
public static TupleTag<?> getMainOutputTag(ParDoPayload payload)
throws InvalidProtocolBufferException {
return doFnAndMainOutputTagFromProto(payload.getDoFn()).getMainOutputTag();
}
// TODO: Implement
private static StateSpec toProto(StateDeclaration state) {
throw new UnsupportedOperationException("Not yet supported");
}
// TODO: Implement
private static TimerSpec toProto(TimerDeclaration timer) {
throw new UnsupportedOperationException("Not yet supported");
}
@AutoValue
abstract static class DoFnAndMainOutput implements Serializable {
public static DoFnAndMainOutput of(
DoFn<?, ?> fn, TupleTag<?> tag) {
return new AutoValue_ParDos_DoFnAndMainOutput(fn, tag);
}
abstract DoFn<?, ?> getDoFn();
abstract TupleTag<?> getMainOutputTag();
}
private static SdkFunctionSpec toProto(DoFn<?, ?> fn, TupleTag<?> tag) {
return SdkFunctionSpec.newBuilder()
.setSpec(
FunctionSpec.newBuilder()
.setUrn(CUSTOM_JAVA_DO_FN_URN)
.setParameter(
Any.pack(
BytesValue.newBuilder()
.setValue(
ByteString.copyFrom(
SerializableUtils.serializeToByteArray(
DoFnAndMainOutput.of(fn, tag))))
.build())))
.build();
}
private static DoFnAndMainOutput doFnAndMainOutputTagFromProto(SdkFunctionSpec fnSpec)
throws InvalidProtocolBufferException {
checkArgument(fnSpec.getSpec().getUrn().equals(CUSTOM_JAVA_DO_FN_URN));
byte[] serializedFn =
fnSpec.getSpec().getParameter().unpack(BytesValue.class).getValue().toByteArray();
return (DoFnAndMainOutput)
SerializableUtils.deserializeFromByteArray(serializedFn, "Custom DoFn And Main Output tag");
}
private static Optional<RunnerApi.Parameter> toProto(Parameter parameter) {
return parameter.match(
new Cases.WithDefault<Optional<RunnerApi.Parameter>>() {
@Override
public Optional<RunnerApi.Parameter> dispatch(WindowParameter p) {
return Optional.of(RunnerApi.Parameter.newBuilder().setType(Type.WINDOW).build());
}
@Override
public Optional<RunnerApi.Parameter> dispatch(RestrictionTrackerParameter p) {
return Optional.of(
RunnerApi.Parameter.newBuilder().setType(Type.RESTRICTION_TRACKER).build());
}
@Override
protected Optional<RunnerApi.Parameter> dispatchDefault(Parameter p) {
return Optional.absent();
}
});
}
private static SideInput toProto(PCollectionView<?> view) {
Builder builder = SideInput.newBuilder();
builder.setAccessPattern(
FunctionSpec.newBuilder()
.setUrn(view.getViewFn().getMaterialization().getUrn())
.build());
builder.setViewFn(toProto(view.getViewFn()));
builder.setWindowMappingFn(toProto(view.getWindowMappingFn()));
return builder.build();
}
public static PCollectionView<?> fromProto(
SideInput sideInput, String id, RunnerApi.PTransform parDoTransform, Components components)
throws IOException {
TupleTag<?> tag = new TupleTag<>(id);
WindowMappingFn<?> windowMappingFn = windowMappingFnFromProto(sideInput.getWindowMappingFn());
ViewFn<?, ?> viewFn = viewFnFromProto(sideInput.getViewFn());
RunnerApi.PCollection inputCollection =
components.getPcollectionsOrThrow(parDoTransform.getInputsOrThrow(id));
WindowingStrategy<?, ?> windowingStrategy =
WindowingStrategies.fromProto(
components.getWindowingStrategiesOrThrow(inputCollection.getWindowingStrategyId()),
components);
Coder<?> elemCoder =
Coders.fromProto(components.getCodersOrThrow(inputCollection.getCoderId()), components);
Coder<Iterable<WindowedValue<?>>> coder =
(Coder)
IterableCoder.of(
FullWindowedValueCoder.of(
elemCoder, windowingStrategy.getWindowFn().windowCoder()));
checkArgument(
sideInput.getAccessPattern().getUrn().equals(Materializations.ITERABLE_MATERIALIZATION_URN),
"Unknown View Materialization URN %s",
sideInput.getAccessPattern().getUrn());
PCollectionView<?> view =
new RunnerPCollectionView<>(
(TupleTag<Iterable<WindowedValue<?>>>) tag,
(ViewFn<Iterable<WindowedValue<?>>, ?>) viewFn,
windowMappingFn,
windowingStrategy,
coder);
return view;
}
private static SdkFunctionSpec toProto(ViewFn<?, ?> viewFn) {
return SdkFunctionSpec.newBuilder()
.setSpec(
FunctionSpec.newBuilder()
.setUrn(CUSTOM_JAVA_VIEW_FN_URN)
.setParameter(
Any.pack(
BytesValue.newBuilder()
.setValue(
ByteString.copyFrom(SerializableUtils.serializeToByteArray(viewFn)))
.build())))
.build();
}
private static ViewFn<?, ?> viewFnFromProto(SdkFunctionSpec viewFn)
throws InvalidProtocolBufferException {
FunctionSpec spec = viewFn.getSpec();
checkArgument(
spec.getUrn().equals(CUSTOM_JAVA_VIEW_FN_URN),
"Can't deserialize unknown %s type %s",
ViewFn.class.getSimpleName(),
spec.getUrn());
return (ViewFn<?, ?>)
SerializableUtils.deserializeFromByteArray(
spec.getParameter().unpack(BytesValue.class).getValue().toByteArray(), "Custom ViewFn");
}
private static SdkFunctionSpec toProto(WindowMappingFn<?> windowMappingFn) {
return SdkFunctionSpec.newBuilder()
.setSpec(
FunctionSpec.newBuilder()
.setUrn(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN)
.setParameter(
Any.pack(
BytesValue.newBuilder()
.setValue(
ByteString.copyFrom(
SerializableUtils.serializeToByteArray(windowMappingFn)))
.build())))
.build();
}
private static WindowMappingFn<?> windowMappingFnFromProto(SdkFunctionSpec windowMappingFn)
throws InvalidProtocolBufferException {
FunctionSpec spec = windowMappingFn.getSpec();
checkArgument(
spec.getUrn().equals(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN),
"Can't deserialize unknown %s type %s",
WindowMappingFn.class.getSimpleName(),
spec.getUrn());
return (WindowMappingFn<?>)
SerializableUtils.deserializeFromByteArray(
spec.getParameter().unpack(BytesValue.class).getValue().toByteArray(),
"Custom WinodwMappingFn");
}
}