package brave.grpc;
import brave.Tracing;
import brave.internal.HexCodec;
import brave.internal.StrictCurrentTraceContext;
import brave.sampler.Sampler;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloRequest;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import zipkin.Constants;
import zipkin.Span;
import zipkin.internal.Util;
import static brave.grpc.GreeterImpl.HELLO_REQUEST;
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown;
import static org.assertj.core.api.Assertions.tuple;
/**
 * Integration tests for the interceptor returned by {@code GrpcTracing.newServerInterceptor()},
 * exercised through a real gRPC server and a plaintext client channel on localhost.
 */
public class ITTracingServerInterceptor {
  @Rule public ExpectedException thrown = ExpectedException.none();

  /** How long {@link #takeSpan()} waits for the server to report a span. */
  static final long SPAN_TIMEOUT_SECONDS = 3;

  /**
   * Spans reported by {@link #tracing}. This is a blocking queue because the server can finish
   * (and report) its span after the blocking client call has already returned. Tests therefore
   * call {@link #takeSpan()} rather than inspecting a possibly-incomplete snapshot.
   */
  BlockingQueue<Span> spans = new LinkedBlockingQueue<>();

  Tracing tracing;
  Server server;
  ManagedChannel client;

  @Before public void setup() throws Exception {
    tracing = tracingBuilder(Sampler.ALWAYS_SAMPLE).build();
    init();
  }

  /**
   * (Re)starts the server and client channel. The server's Greeter service is wrapped with an
   * interceptor built from the current {@link #tracing}, and any previous server or client is
   * stopped first.
   */
  void init() throws Exception {
    stop();
    server = ServerBuilder.forPort(PickUnusedPort.get())
        .addService(ServerInterceptors.intercept(new GreeterImpl(),
            GrpcTracing.create(tracing).newServerInterceptor()))
        .build().start();
    client = ManagedChannelBuilder.forAddress("localhost", server.getPort())
        .usePlaintext(true)
        .build();
  }

  /** Shuts down the client channel and server, if started, waiting a bounded time for each. */
  @After
  public void stop() throws Exception {
    if (client != null) {
      client.shutdown();
      client.awaitTermination(1, TimeUnit.SECONDS);
    }
    if (server != null) {
      server.shutdown();
      // Bounded like the client wait, so a stuck call cannot hang the whole suite.
      server.awaitTermination(1, TimeUnit.SECONDS);
    }
  }

  /** The server should join the B3 trace context propagated in the request headers. */
  @Test
  public void usesExistingTraceId() throws Exception {
    final String traceId = "463ac35c9f6413ad";
    final String parentId = traceId;
    final String spanId = "48485a3953bb6124";

    // Client-side interceptor that injects fixed B3 headers into every outgoing call.
    Channel channel = ClientInterceptors.intercept(client, new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
          @Override
          public void start(Listener<RespT> responseListener, Metadata headers) {
            headers.put(Key.of("X-B3-TraceId", ASCII_STRING_MARSHALLER), traceId);
            headers.put(Key.of("X-B3-ParentSpanId", ASCII_STRING_MARSHALLER), parentId);
            headers.put(Key.of("X-B3-SpanId", ASCII_STRING_MARSHALLER), spanId);
            super.start(responseListener, headers);
          }
        };
      }
    });

    GreeterGrpc.newBlockingStub(channel).sayHello(HELLO_REQUEST);

    // takeSpan() fails if nothing was reported, unlike allSatisfy on an empty collection.
    Span span = takeSpan();
    assertThat(HexCodec.toLowerHex(span.traceId)).isEqualTo(traceId);
    assertThat(span.parentId).isNotNull(); // avoids an unboxing NPE on the next line
    assertThat(HexCodec.toLowerHex(span.parentId)).isEqualTo(parentId);
    assertThat(HexCodec.toLowerHex(span.id)).isEqualTo(spanId);
  }

  /** With {@link Sampler#NEVER_SAMPLE}, no span should be reported. */
  @Test
  public void samplingDisabled() throws Exception {
    tracing = tracingBuilder(Sampler.NEVER_SAMPLE).build();
    init();

    GreeterGrpc.newBlockingStub(client).sayHello(HELLO_REQUEST);

    assertThat(spans)
        .isEmpty();
  }

  /** The server span should carry exactly the "sr" and "ss" annotations, in that order. */
  @Test
  public void reportsServerAnnotationsToZipkin() throws Exception {
    GreeterGrpc.newBlockingStub(client).sayHello(HELLO_REQUEST);

    assertThat(takeSpan().annotations)
        .extracting(a -> a.value)
        .containsExactly("sr", "ss");
  }

  /** The span name should default to the lower-cased full gRPC method name. */
  @Test
  public void defaultSpanNameIsMethodName() throws Exception {
    GreeterGrpc.newBlockingStub(client).sayHello(HELLO_REQUEST);

    assertThat(takeSpan().name)
        .isEqualTo("helloworld.greeter/sayhello");
  }

  /** A failed call should be tagged with {@code Constants.ERROR} set to the status code name. */
  @Test
  public void addsErrorTagOnException() throws Exception {
    try {
      GreeterGrpc.newBlockingStub(client)
          .sayHello(HelloRequest.newBuilder().setName("bad").build());
      failBecauseExceptionWasNotThrown(StatusRuntimeException.class);
    } catch (StatusRuntimeException e) {
      assertThat(takeSpan().binaryAnnotations)
          .extracting(b -> tuple(b.key, new String(b.value, Util.UTF_8)))
          .contains(tuple(Constants.ERROR, e.getStatus().getCode().toString()));
    }
  }

  /**
   * Removes and returns the next reported span, waiting up to {@link #SPAN_TIMEOUT_SECONDS}.
   *
   * @return the next reported span, never null
   * @throws AssertionError if no span is reported within the timeout
   * @throws InterruptedException if interrupted while waiting
   */
  Span takeSpan() throws InterruptedException {
    Span result = spans.poll(SPAN_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    if (result == null) {
      throw new AssertionError("Span was not reported within " + SPAN_TIMEOUT_SECONDS + "s");
    }
    return result;
  }

  /** Builds a tracer that reports into {@link #spans} and uses a strict current-context check. */
  Tracing.Builder tracingBuilder(Sampler sampler) {
    return Tracing.newBuilder()
        .reporter(spans::add)
        .currentTraceContext(new StrictCurrentTraceContext())
        .sampler(sampler);
  }
}