/*
 * Copyright 2014, Google Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *    * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *    * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following disclaimer
 * in the documentation and/or other materials provided with the
 * distribution.
 *
 *    * Neither the name of Google Inc. nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package io.grpc;

import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import io.grpc.ClientInterceptors.CheckedForwardingClientCall;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.testing.TestMethodDescriptors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/** Unit tests for {@link ClientInterceptors}. */
@RunWith(JUnit4.class)
public class ClientInterceptorsTest {

  /** Mock channel; {@link #setUp()} stubs {@code newCall} to hand back {@link #call}. */
  @Mock
  private Channel channel;

  /** Recording call standing in for the call created by the underlying channel. */
  private final BaseClientCall call = new BaseClientCall();

  private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();

  /**
   * Sets up mocks.
   */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    when(channel.newCall(
            Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
        .thenReturn(call);
  }

  /** A null channel is rejected with a {@link NullPointerException}. */
  @Test(expected = NullPointerException.class)
  public void npeForNullChannel() {
    ClientInterceptors.intercept(null, Arrays.<ClientInterceptor>asList());
  }

  /** A null interceptor list is rejected with a {@link NullPointerException}. */
  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptorList() {
    ClientInterceptors.intercept(channel, (List<ClientInterceptor>) null);
  }

  /** A null interceptor element is rejected with a {@link NullPointerException}. */
  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptor() {
    ClientInterceptors.intercept(channel, (ClientInterceptor) null);
  }

  /** Intercepting with an empty interceptor list returns the very same channel. */
  @Test
  public void noop() {
    assertSame(channel, ClientInterceptors.intercept(channel, Arrays.<ClientInterceptor>asList()));
  }

  /** Each {@code newCall} on the intercepted channel hits the interceptor and the channel once. */
  @Test
  public void channelAndInterceptorCalled() {
    ClientInterceptor interceptor = spy(new NoopInterceptor());
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    CallOptions callOptions = CallOptions.DEFAULT;

    // First call
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));
    verify(interceptor).interceptCall(same(method), same(callOptions), Mockito.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);

    // Second call
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel, times(2)).newCall(same(method), same(callOptions));
    verify(interceptor, times(2))
        .interceptCall(same(method), same(callOptions), Mockito.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);
  }

  /** An interceptor may call {@code next.newCall} more than once for a single intercepted call. */
  @Test
  public void callNextTwice() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        // Calling next twice is permitted, although should only rarely be useful.
        assertSame(call, next.newCall(method, callOptions));
        return next.newCall(method, callOptions);
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT));
    verifyNoMoreInteractions(channel);
  }

  /** With {@code intercept}, the last interceptor given runs first and the channel runs last. */
  @Test
  public void ordered() {
    final List<String> order = new ArrayList<String>();
    Channel recordingChannel = newRecordingChannel(order);
    ClientInterceptor interceptor1 = newRecordingInterceptor("i1", order);
    ClientInterceptor interceptor2 = newRecordingInterceptor("i2", order);

    Channel intercepted = ClientInterceptors.intercept(recordingChannel, interceptor1, interceptor2);

    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    assertEquals(Arrays.asList("i2", "i1", "channel"), order);
  }

  /** With {@code interceptForward}, interceptors run in the order given, then the channel. */
  @Test
  public void orderedForward() {
    final List<String> order = new ArrayList<String>();
    Channel recordingChannel = newRecordingChannel(order);
    ClientInterceptor interceptor1 = newRecordingInterceptor("i1", order);
    ClientInterceptor interceptor2 = newRecordingInterceptor("i2", order);

    Channel intercepted =
        ClientInterceptors.interceptForward(recordingChannel, interceptor1, interceptor2);

    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    assertEquals(Arrays.asList("i1", "i2", "channel"), order);
  }

  /** Call options substituted by an interceptor are the ones that reach the channel. */
  @Test
  public void callOptions() {
    final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS);
    final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS);
    assertNotSame(initialCallOptions, newCallOptions);
    ClientInterceptor interceptor = spy(new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        return next.newCall(method, newCallOptions);
      }
    });
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

    intercepted.newCall(method, initialCallOptions);

    verify(interceptor).interceptCall(
        same(method), same(initialCallOptions), Mockito.<Channel>any());
    verify(channel).newCall(same(method), same(newCallOptions));
  }

  /** Headers added by an interceptor in {@code start} are seen by the underlying call. */
  @Test
  public void addOutboundHeaders() {
    final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER);
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(call) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            headers.put(credKey, "abcd");
            super.start(responseListener, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);

    // start() on the intercepted call will eventually reach the call created by the real channel
    interceptedCall.start(listener, new Metadata());

    // The headers passed to the real channel call will contain the information inserted by the
    // interceptor.
    assertSame(listener, call.listener);
    assertEquals("abcd", call.headers.get(credKey));
  }

  /** A listener wrapped by an interceptor observes headers delivered to the underlying call. */
  @Test
  public void examineInboundHeaders() {
    final List<Metadata> examinedHeaders = new ArrayList<Metadata>();
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(call) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) {
              @Override
              public void onHeaders(Metadata headers) {
                examinedHeaders.add(headers);
                super.onHeaders(headers);
              }
            }, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    interceptedCall.start(listener, new Metadata());

    // The underlying call recorded the listener it was started with; deliver headers to that
    // listener to simulate headers arriving on the underlying call.
    Metadata inboundHeaders = new Metadata();
    call.listener.onHeaders(inboundHeaders);

    assertThat(examinedHeaders).contains(inboundHeaders);
  }

  /** A plain forwarding call passes start, sendMessage, halfClose and request straight through. */
  @Test
  public void normalCall() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(call) { };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    Metadata headers = new Metadata();

    interceptedCall.start(listener, headers);
    assertSame(listener, call.listener);
    assertSame(headers, call.headers);

    interceptedCall.sendMessage(null /*request*/);
    assertThat(call.messages).containsExactly((Void) null /*request*/);

    interceptedCall.halfClose();
    assertTrue(call.halfClosed);

    interceptedCall.request(1);
    assertThat(call.requests).containsExactly(1);
  }

  /**
   * An exception thrown from {@code checkedStart} is reported to the listener through
   * {@code onClose}, and later operations on the call do not throw.
   */
  @Test
  public void exceptionInStart() {
    final Exception error = new Exception("emulated error");
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
        return new CheckedForwardingClientCall<ReqT, RespT>(call) {
          @Override
          protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadata headers)
              throws Exception {
            throw error;
            // delegate().start will not be called
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);

    interceptedCall.start(listener, new Metadata());
    interceptedCall.sendMessage(null /*request*/);
    interceptedCall.halfClose();
    interceptedCall.request(1);

    // From here on, any method invoked on the recording call throws IllegalStateException.
    call.done = true;

    ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
    verify(listener).onClose(captor.capture(), any(Metadata.class));
    assertSame(error, captor.getValue().getCause());

    // Make sure nothing bad happens after the exception.
    ClientCall<?, ?> noop = ((CheckedForwardingClientCall<?, ?>) interceptedCall).delegate();
    // Should not throw, even on bad input
    noop.cancel("Cancel for test", null);
    noop.start(null, null);
    noop.request(-1);
    noop.halfClose();
    noop.sendMessage(null);
    assertFalse(noop.isReady());
  }

  /** {@code authority()} on the intercepted channel returns the underlying channel's value. */
  @Test
  public void authorityIsDelegated() {
    ClientInterceptor interceptor = new NoopInterceptor();
    when(channel.authority()).thenReturn("auth");

    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

    assertEquals("auth", intercepted.authority());
  }

  /** A custom call option set by the caller is readable from the options the interceptor gets. */
  @Test
  public void customOptionAccessible() {
    CallOptions.Key<String> customOption = CallOptions.Key.of("custom", null);
    CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
    ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
    ClientInterceptor interceptor = spy(new NoopInterceptor());

    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));
    verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
    // Compare by value: identity would only hold because string literals are interned.
    assertEquals("value", passedOptions.getValue().getOption(customOption));
  }

  /**
   * Creates a channel that appends {@code "channel"} to {@code order} on every {@code newCall}
   * and returns the shared {@link #call}. Its {@code authority()} returns {@code null}.
   */
  private Channel newRecordingChannel(final List<String> order) {
    return new Channel() {
      @SuppressWarnings("unchecked")
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
        order.add("channel");
        return (ClientCall<ReqT, RespT>) call;
      }

      @Override
      public String authority() {
        return null;
      }
    };
  }

  /**
   * Creates an interceptor that appends {@code name} to {@code order} and then forwards to
   * {@code next} unchanged.
   */
  private static ClientInterceptor newRecordingInterceptor(
      final String name, final List<String> order) {
    return new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        order.add(name);
        return next.newCall(method, callOptions);
      }
    };
  }

  /** Interceptor that forwards to {@code next} without modifying anything. */
  private static class NoopInterceptor implements ClientInterceptor {
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
        CallOptions callOptions, Channel next) {
      return next.newCall(method, callOptions);
    }
  }

  /**
   * Call that records what was invoked on it. Every method throws
   * {@link IllegalStateException} once {@code done} is set, and {@code request},
   * {@code halfClose} and {@code sendMessage} also throw if {@code start} has not been called.
   */
  private static class BaseClientCall extends ClientCall<String, Integer> {
    private boolean started;
    private boolean done;
    private ClientCall.Listener<Integer> listener;
    private Metadata headers;
    private final List<Integer> requests = new ArrayList<Integer>();
    private final List<String> messages = new ArrayList<String>();
    private boolean halfClosed;
    // The cancel arguments are recorded, but no test in this class reads them.
    private Throwable cancelCause;
    private String cancelMessage;

    @Override
    public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
      checkNotDone();
      started = true;
      this.listener = listener;
      this.headers = headers;
    }

    @Override
    public void request(int numMessages) {
      checkNotDone();
      checkStarted();
      requests.add(numMessages);
    }

    @Override
    public void cancel(String message, Throwable cause) {
      checkNotDone();
      this.cancelMessage = message;
      this.cancelCause = cause;
    }

    @Override
    public void halfClose() {
      checkNotDone();
      checkStarted();
      this.halfClosed = true;
    }

    @Override
    public void sendMessage(String message) {
      checkNotDone();
      checkStarted();
      messages.add(message);
    }

    /** Throws if the test has marked this call as done. */
    private void checkNotDone() {
      if (done) {
        throw new IllegalStateException("no more methods should be called");
      }
    }

    /** Throws if {@code start} has not been invoked yet. */
    private void checkStarted() {
      if (!started) {
        throw new IllegalStateException("should have called start");
      }
    }
  }
}