/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.fn.harness.data;
import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.collection.IsEmptyCollection.empty;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.CallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.test.TestStreams;
import org.apache.beam.fn.v1.BeamFnApi;
import org.apache.beam.fn.v1.BeamFnDataGrpc;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link BeamFnDataGrpcClient}. */
@RunWith(JUnit4.class)
public class BeamFnDataGrpcClientTest {
private static final Coder<WindowedValue<String>> CODER =
LengthPrefixCoder.of(
WindowedValue.getFullCoder(StringUtf8Coder.of(),
GlobalWindow.Coder.INSTANCE));
private static final KV<String, BeamFnApi.Target> KEY_A =
KV.of(
"12L",
BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference("34L")
.setName("targetA")
.build());
private static final KV<String, BeamFnApi.Target> KEY_B =
KV.of(
"56L",
BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference("78L")
.setName("targetB")
.build());
private static final BeamFnApi.Elements ELEMENTS_A_1;
private static final BeamFnApi.Elements ELEMENTS_A_2;
private static final BeamFnApi.Elements ELEMENTS_B_1;
static {
try {
ELEMENTS_A_1 = BeamFnApi.Elements.newBuilder()
.addData(BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(KEY_A.getKey())
.setTarget(KEY_A.getValue())
.setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("ABC")))
.concat(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("DEF"))))))
.build();
ELEMENTS_A_2 = BeamFnApi.Elements.newBuilder()
.addData(BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(KEY_A.getKey())
.setTarget(KEY_A.getValue())
.setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("GHI")))))
.addData(BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(KEY_A.getKey())
.setTarget(KEY_A.getValue()))
.build();
ELEMENTS_B_1 = BeamFnApi.Elements.newBuilder()
.addData(BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(KEY_B.getKey())
.setTarget(KEY_B.getValue())
.setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("JKL")))
.concat(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("MNO"))))))
.addData(BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(KEY_B.getKey())
.setTarget(KEY_B.getValue()))
.build();
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
@Test
public void testForInboundConsumer() throws Exception {
CountDownLatch waitForClientToConnect = new CountDownLatch(1);
Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>();
Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>();
Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver =
new AtomicReference<>();
CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
TestStreams.withOnNext(inboundServerValues::add).build();
BeamFnApi.ApiServiceDescriptor apiServiceDescriptor =
BeamFnApi.ApiServiceDescriptor.newBuilder()
.setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
.build();
Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
.addService(new BeamFnDataGrpc.BeamFnDataImplBase() {
@Override
public StreamObserver<BeamFnApi.Elements> data(
StreamObserver<BeamFnApi.Elements> outboundObserver) {
outboundServerObserver.set(outboundObserver);
waitForClientToConnect.countDown();
return inboundServerObserver;
}
})
.build();
server.start();
try {
ManagedChannel channel =
InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(
PipelineOptionsFactory.create(),
(BeamFnApi.ApiServiceDescriptor descriptor) -> channel,
this::createStreamForTest);
CompletableFuture<Void> readFutureA = clientFactory.forInboundConsumer(
apiServiceDescriptor,
KEY_A,
CODER,
inboundValuesA::add);
waitForClientToConnect.await();
outboundServerObserver.get().onNext(ELEMENTS_A_1);
// Purposefully transmit some data before the consumer for B is bound showing that
// data is not lost
outboundServerObserver.get().onNext(ELEMENTS_B_1);
Thread.sleep(100);
CompletableFuture<Void> readFutureB = clientFactory.forInboundConsumer(
apiServiceDescriptor,
KEY_B,
CODER,
inboundValuesB::add);
// Show that out of order stream completion can occur.
readFutureB.get();
assertThat(inboundValuesB, contains(
valueInGlobalWindow("JKL"), valueInGlobalWindow("MNO")));
outboundServerObserver.get().onNext(ELEMENTS_A_2);
readFutureA.get();
assertThat(inboundValuesA, contains(
valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"), valueInGlobalWindow("GHI")));
} finally {
server.shutdownNow();
}
}
@Test
public void testForInboundConsumerThatThrows() throws Exception {
CountDownLatch waitForClientToConnect = new CountDownLatch(1);
AtomicInteger consumerInvoked = new AtomicInteger();
Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver =
new AtomicReference<>();
CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
TestStreams.withOnNext(inboundServerValues::add).build();
BeamFnApi.ApiServiceDescriptor apiServiceDescriptor =
BeamFnApi.ApiServiceDescriptor.newBuilder()
.setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
.build();
Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
.addService(new BeamFnDataGrpc.BeamFnDataImplBase() {
@Override
public StreamObserver<BeamFnApi.Elements> data(
StreamObserver<BeamFnApi.Elements> outboundObserver) {
outboundServerObserver.set(outboundObserver);
waitForClientToConnect.countDown();
return inboundServerObserver;
}
})
.build();
server.start();
RuntimeException exceptionToThrow = new RuntimeException("TestFailure");
try {
ManagedChannel channel =
InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(
PipelineOptionsFactory.create(),
(BeamFnApi.ApiServiceDescriptor descriptor) -> channel,
this::createStreamForTest);
CompletableFuture<Void> readFuture = clientFactory.forInboundConsumer(
apiServiceDescriptor,
KEY_A,
CODER,
new ThrowingConsumer<WindowedValue<String>>() {
@Override
public void accept(WindowedValue<String> t) throws Exception {
consumerInvoked.incrementAndGet();
throw exceptionToThrow;
}
});
waitForClientToConnect.await();
// This first message should cause a failure afterwards all other messages are dropped.
outboundServerObserver.get().onNext(ELEMENTS_A_1);
outboundServerObserver.get().onNext(ELEMENTS_A_2);
try {
readFuture.get();
fail("Expected channel to fail");
} catch (ExecutionException e) {
assertEquals(exceptionToThrow, e.getCause());
}
// The server should not have received any values
assertThat(inboundServerValues, empty());
// The consumer should have only been invoked once
assertEquals(1, consumerInvoked.get());
} finally {
server.shutdownNow();
}
}
@Test
public void testForOutboundConsumer() throws Exception {
CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2);
Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
TestStreams.withOnNext(
new Consumer<BeamFnApi.Elements>() {
@Override
public void accept(BeamFnApi.Elements t) {
inboundServerValues.add(t);
waitForInboundServerValuesCompletion.countDown();
}
}
).build();
BeamFnApi.ApiServiceDescriptor apiServiceDescriptor =
BeamFnApi.ApiServiceDescriptor.newBuilder()
.setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
.build();
Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
.addService(new BeamFnDataGrpc.BeamFnDataImplBase() {
@Override
public StreamObserver<BeamFnApi.Elements> data(
StreamObserver<BeamFnApi.Elements> outboundObserver) {
return inboundServerObserver;
}
})
.build();
server.start();
try {
ManagedChannel channel =
InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(
PipelineOptionsFactory.fromArgs(
new String[]{ "--experiments=beam_fn_api_data_buffer_limit=20" }).create(),
(BeamFnApi.ApiServiceDescriptor descriptor) -> channel,
this::createStreamForTest);
try (CloseableThrowingConsumer<WindowedValue<String>> consumer =
clientFactory.forOutboundConsumer(apiServiceDescriptor, KEY_A, CODER)) {
consumer.accept(valueInGlobalWindow("ABC"));
consumer.accept(valueInGlobalWindow("DEF"));
consumer.accept(valueInGlobalWindow("GHI"));
}
waitForInboundServerValuesCompletion.await();
assertThat(inboundServerValues, contains(ELEMENTS_A_1, ELEMENTS_A_2));
} finally {
server.shutdownNow();
}
}
private <ReqT, RespT> StreamObserver<RespT> createStreamForTest(
Function<StreamObserver<ReqT>, StreamObserver<RespT>> clientFactory,
StreamObserver<ReqT> handler) {
return clientFactory.apply(handler);
}
}