package com.github.kristofa.brave.okhttp; import com.github.kristofa.brave.Brave; import com.github.kristofa.brave.ClientRequestInterceptor; import com.github.kristofa.brave.ClientResponseInterceptor; import com.github.kristofa.brave.ClientTracer; import com.github.kristofa.brave.SpanId; import com.github.kristofa.brave.http.BraveHttpHeaders; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.Answers; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import zipkin.TraceKeys; import java.io.IOException; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; public class BraveOkHttpRequestResponseInterceptorTest { private static final Long SPAN_ID = 151864L; private static final Long TRACE_ID = 8494864L; private static final String TRACE_ID_STRING = SpanId.builder().spanId(TRACE_ID).build().traceIdString(); private static final String HTTP_METHOD_GET = "GET"; @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockWebServer server = new MockWebServer(); @Mock private Brave brave; @Mock(answer = Answers.RETURNS_SMART_NULLS) private ClientTracer clientTracer; private SpanId spanId; private OkHttpClient client; @Before public void setup() throws IOException { MockitoAnnotations.initMocks(this); this.spanId = SpanId.builder().spanId(SPAN_ID).traceId(TRACE_ID).parentId(null).build(); when(brave.clientRequestInterceptor()) .thenReturn(new ClientRequestInterceptor(clientTracer)); when(brave.clientResponseInterceptor()) .thenReturn(new ClientResponseInterceptor(clientTracer)); this.client = new OkHttpClient.Builder() .addInterceptor(BraveOkHttpRequestResponseInterceptor.create(brave)) .build(); } @Test public void testTracingTrue() throws IOException, InterruptedException { when(clientTracer.startNewSpan(HTTP_METHOD_GET)).thenReturn(spanId); String url = "http://localhost:" + server.getPort() + "/foo"; Request request = new Request.Builder() .url(url) .build(); server.enqueue(new MockResponse() .setBody("bar") .setResponseCode(200) ); Response response = client.newCall(request).execute(); assertEquals(200, response.code()); InOrder inOrder = inOrder(clientTracer); inOrder.verify(clientTracer).startNewSpan(HTTP_METHOD_GET); inOrder.verify(clientTracer).submitBinaryAnnotation(TraceKeys.HTTP_URL, url); inOrder.verify(clientTracer).setClientSent(); inOrder.verify(clientTracer).setClientReceived(); verifyNoMoreInteractions(clientTracer); RecordedRequest serverRequest = server.takeRequest(); assertEquals(HTTP_METHOD_GET, serverRequest.getMethod()); assertEquals("1", serverRequest.getHeader(BraveHttpHeaders.Sampled.getName())); assertEquals(TRACE_ID_STRING, serverRequest.getHeader(BraveHttpHeaders.TraceId.getName())); assertEquals(Long.toString(SPAN_ID, 16), serverRequest.getHeader(BraveHttpHeaders.SpanId.getName())); } @Test public void testTracingTrueHttpNoOk() throws IOException, InterruptedException { when(clientTracer.startNewSpan(HTTP_METHOD_GET)).thenReturn(spanId); String url = "http://localhost:" + server.getPort() + "/foo"; Request request = new Request.Builder() .url(url) .build(); server.enqueue(new MockResponse() .setBody("bar") .setResponseCode(400) ); Response response = client.newCall(request).execute(); assertEquals(400, response.code()); InOrder inOrder = inOrder(clientTracer); inOrder.verify(clientTracer).startNewSpan(HTTP_METHOD_GET); inOrder.verify(clientTracer).submitBinaryAnnotation(TraceKeys.HTTP_URL, url); inOrder.verify(clientTracer).setClientSent(); inOrder.verify(clientTracer).submitBinaryAnnotation(TraceKeys.HTTP_STATUS_CODE, "400"); inOrder.verify(clientTracer).setClientReceived(); verifyNoMoreInteractions(clientTracer); RecordedRequest serverRequest = server.takeRequest(); assertEquals(HTTP_METHOD_GET, serverRequest.getMethod()); assertEquals("1", serverRequest.getHeader(BraveHttpHeaders.Sampled.getName())); assertEquals(TRACE_ID_STRING, serverRequest.getHeader(BraveHttpHeaders.TraceId.getName())); assertEquals(Long.toString(SPAN_ID, 16), serverRequest.getHeader(BraveHttpHeaders.SpanId.getName())); } @Test public void testTracingFalse() throws IOException, InterruptedException { when(clientTracer.startNewSpan(HTTP_METHOD_GET)).thenReturn(null); String url = "http://localhost:" + server.getPort() + "/foo"; Request request = new Request.Builder() .url(url) .build(); server.enqueue(new MockResponse() .setBody("bar") ); Response response = client.newCall(request).execute(); assertEquals(200, response.code()); InOrder inOrder = inOrder(clientTracer); inOrder.verify(clientTracer).startNewSpan(HTTP_METHOD_GET); inOrder.verify(clientTracer).setClientReceived(); verifyNoMoreInteractions(clientTracer); RecordedRequest serverRequest = server.takeRequest(); assertEquals(HTTP_METHOD_GET, serverRequest.getMethod()); assertEquals("0", serverRequest.getHeader(BraveHttpHeaders.Sampled.getName())); } @Test public void testQueryParams() throws IOException, InterruptedException { when(clientTracer.startNewSpan(HTTP_METHOD_GET)).thenReturn(spanId); String url = "http://localhost:" + server.getPort() + "/foo?z=2&yAA"; Request request = new Request.Builder() .url(url) .build(); server.enqueue(new MockResponse() .setBody("bar") ); Response response = client.newCall(request).execute(); assertEquals(200, response.code()); InOrder inOrder = inOrder(clientTracer); inOrder.verify(clientTracer).startNewSpan(HTTP_METHOD_GET); inOrder.verify(clientTracer).submitBinaryAnnotation(TraceKeys.HTTP_URL, url); inOrder.verify(clientTracer).setClientSent(); inOrder.verify(clientTracer).setClientReceived(); verifyNoMoreInteractions(clientTracer); RecordedRequest serverRequest = server.takeRequest(); assertEquals(HTTP_METHOD_GET, serverRequest.getMethod()); assertEquals("1", serverRequest.getHeader(BraveHttpHeaders.Sampled.getName())); assertEquals(TRACE_ID_STRING, serverRequest.getHeader(BraveHttpHeaders.TraceId.getName())); assertEquals(Long.toString(SPAN_ID, 16), serverRequest.getHeader(BraveHttpHeaders.SpanId.getName())); } }