package me.dinowernli.grpc.polyglot.testing;
import java.util.ArrayList;
import java.util.List;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import polyglot.test.TestProto.TestRequest;
import polyglot.test.TestProto.TestResponse;
import polyglot.test.TestServiceGrpc.TestServiceImplBase;
/**
* An implementation of {@link TestServiceImplBase} which records the calls and produces a constant
* response.
*/
public class RecordingTestService extends TestServiceImplBase {
private static final Logger logger = LoggerFactory.getLogger(RecordingTestService.class);
/** The number of messages to wait for from client stream before ending the rpc. */
private static final int NUM_MESSAGES_FROM_CLIENT = 3;
/** We fake some expensive computation to shake out race conditions. */
private static final long RESPONSE_COMPUTATION_TIME_MS = 1000L;
private final TestResponse unaryResponse;
private final TestResponse streamResponse;
private final TestResponse clientStreamResponse;
private final TestResponse bidiStreamResponse;
private final List<TestRequest> recordedRequests;
public RecordingTestService(
TestResponse unaryResponse,
TestResponse streamResponse,
TestResponse clientStreamResponse,
TestResponse bidiStreamResponse) {
this.unaryResponse = unaryResponse;
this.streamResponse = streamResponse;
this.clientStreamResponse = clientStreamResponse;
this.bidiStreamResponse = bidiStreamResponse;
this.recordedRequests = new ArrayList<>();
}
@Override
public void testMethod(TestRequest request, StreamObserver<TestResponse> responseStream) {
logger.info("Handling unary method call");
recordedRequests.add(request);
fakeExpensiveComputation();
responseStream.onNext(unaryResponse);
responseStream.onCompleted();
}
public int numRequests() {
return recordedRequests.size();
}
public TestRequest getRequest(int index) {
return recordedRequests.get(index);
}
@Override
public StreamObserver<TestRequest> testMethodBidi(StreamObserver<TestResponse> responseStream) {
logger.info("Handling bidi method call");
return new StreamObserver<TestRequest>() {
int numRequestMessages = 0;
@Override
public void onNext(TestRequest testRequest) {
++numRequestMessages;
recordedRequests.add(testRequest);
responseStream.onNext(bidiStreamResponse);
if (numRequestMessages >= NUM_MESSAGES_FROM_CLIENT) {
fakeExpensiveComputation();
responseStream.onCompleted();
}
}
@Override
public void onError(Throwable t) {
logger.error("Got incoming error", t);
}
@Override
public void onCompleted() {
// Do nothing.
}
};
}
@Override
public void testMethodStream(TestRequest request, StreamObserver<TestResponse> responseStream) {
logger.info("Handling server streaming method call");
recordedRequests.add(request);
fakeExpensiveComputation();
responseStream.onNext(streamResponse);
responseStream.onCompleted();
}
@Override
public StreamObserver<TestRequest> testMethodClientStream(
StreamObserver<TestResponse> responseStream) {
logger.info("Handling client streaming method call");
return new StreamObserver<TestRequest>() {
int numRequestMessages = 0;
@Override
public void onNext(TestRequest testRequest) {
++numRequestMessages;
recordedRequests.add(testRequest);
if (numRequestMessages >= NUM_MESSAGES_FROM_CLIENT) {
fakeExpensiveComputation();
responseStream.onNext(clientStreamResponse);
responseStream.onCompleted();
}
}
@Override
public void onError(Throwable t) {
logger.error("Got incoming error", t);
}
@Override
public void onCompleted() {
// Do nothing.
}
};
}
private static void fakeExpensiveComputation() {
try {
Thread.sleep(RESPONSE_COMPUTATION_TIME_MS);
} catch (InterruptedException e) {
logger.error("Interrupted while sleeping", e);
}
}
}