package com.linkedin.restli.internal.server.filter; import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.restli.common.attachments.RestLiAttachmentReader; import com.linkedin.restli.internal.server.RestLiCallback; import com.linkedin.restli.internal.server.RoutingResult; import com.linkedin.restli.internal.server.filter.testfilters.CountFilter; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterRequestErrorOnError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterRequestErrorThrowsError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterRequestOnError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterRequestThrowsError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterResponseErrorFixesError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterResponseErrorOnError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterResponseErrorThrowsError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterResponseOnError; import com.linkedin.restli.internal.server.filter.testfilters.CountFilterResponseThrowsError; import com.linkedin.restli.internal.server.filter.testfilters.TestFilterException; import com.linkedin.restli.internal.server.response.RestLiResponseDataImpl; import com.linkedin.restli.internal.server.response.RestLiResponseHandler; import com.linkedin.restli.server.RestLiRequestData; import com.linkedin.restli.server.RestLiResponseAttachments; import com.linkedin.restli.server.RestLiResponseData; import com.linkedin.restli.server.filter.FilterRequestContext; import com.linkedin.restli.server.filter.FilterResponseContext; import java.util.Arrays; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeClass; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyList; import static org.mockito.Matchers.anyMap; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isNull; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; /** * Tests the propagation of the RestLiFilterChain * Test will call the filter chain with and without errors. Tests include filters which correct, propagate, and create * new errors. Based on the expected behavior, the tests use the number of times each method was invoked as the * determining factor as to whether the chain is propagating correctly or not * * @author gye */ public class TestRestLiFilterChain { @Mock private RestLiRequestData _mockRestLiRequestData; @Mock private FilterChainCallback _mockFilterChainCallback; @Mock private RestLiResponseDataImpl _mockRestLiResponseData; @Mock private FilterRequestContext _mockFilterRequestContext; @Mock private FilterResponseContext _mockFilterResponseContext; @Mock private RestLiAttachmentReader _mockRequestAttachmentReader; @Mock private RestLiResponseAttachments _mockResponseAttachments; @Mock private RestLiFilterResponseContextFactory<Object> _mockFilterResponseContextFactory; @Mock private RestRequest _request; @Mock private RoutingResult _method; @Mock private RestLiResponseHandler _responseHandler; private CountFilter[] _filters; private RestLiFilterChain _restLiFilterChain; @BeforeClass protected void setUp() { MockitoAnnotations.initMocks(this); } @BeforeMethod protected void init() { _filters = new CountFilter[] { new CountFilter(), new CountFilter(), new CountFilter() }; } @AfterMethod protected void resetMocks() { reset(_mockFilterChainCallback, _mockFilterRequestContext, _mockFilterResponseContext, _mockRestLiRequestData, _mockRestLiResponseData); _filters = new CountFilter[] { new CountFilter(), new CountFilter(), new CountFilter() }; } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationSuccess() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); for(CountFilter filter : _filters) { assertEquals(filter.getNumRequests(), 1); assertEquals(filter.getNumResponses(), 1); assertEquals(filter.getNumErrors(), 0); } verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onResponseSuccess(_mockRestLiResponseData, _mockResponseAttachments); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings("unchecked") @Test public void testFilterInvocationRequestOnError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterRequestOnError(); when(_responseHandler.buildExceptionResponseData(eq(_request), eq(_method), any(Object.class), anyMap(), anyList())) .thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, new RestLiFilterResponseContextFactory<Object>(_request, _method, _responseHandler)); verifySecondFilterRequestException(); } @SuppressWarnings("unchecked") @Test public void testFilterInvocationRequestErrorOnError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterRequestErrorOnError(); when(_responseHandler.buildExceptionResponseData(eq(_request), eq(_method), any(Object.class), anyMap(), anyList())) .thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, new RestLiFilterResponseContextFactory<Object>(_request, _method, _responseHandler)); verifySecondFilterRequestException(); } @SuppressWarnings("unchecked") @Test public void testFilterInvocationRequestThrowsError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterRequestThrowsError(); when(_responseHandler.buildExceptionResponseData(eq(_request), eq(_method), any(Object.class), anyMap(), anyList())) .thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, new RestLiFilterResponseContextFactory<Object>(_request, _method, _responseHandler)); verifySecondFilterRequestException(); } @SuppressWarnings("unchecked") @Test public void testFilterInvocationRequestErrorThrowsError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterRequestErrorThrowsError(); when(_responseHandler.buildExceptionResponseData(eq(_request), eq(_method), any(Object.class), anyMap(), anyList())) .thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, new RestLiFilterResponseContextFactory<Object>(_request, _method, _responseHandler)); verifySecondFilterRequestException(); } private void verifySecondFilterRequestException() { assertFilterCounts(_filters[0], 1, 0, 1); assertFilterCounts(_filters[1], 1, 0, 1); assertFilterCounts(_filters[2], 0, 0, 0); verify(_mockFilterChainCallback).onError(any(TestFilterException.class), any(RestLiResponseData.class), isNull(RestLiResponseAttachments.class)); verify(_mockRestLiResponseData, times(2)).setException(any(Throwable.class)); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationResponseOnError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterResponseOnError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); verifySecondFilterResponseException(); } @SuppressWarnings("unchecked") @Test public void testFilterInvocationResponseErrorOnError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterResponseErrorOnError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_responseHandler.buildExceptionResponseData(eq(_request), eq(_method), any(Object.class), anyMap(), anyList())) .thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); verifySecondFilterResponseException(); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationResponseThrowsError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterResponseThrowsError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); verifySecondFilterResponseException(); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationResponseErrorThrowsError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterResponseErrorThrowsError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); verifySecondFilterResponseException(); } @SuppressWarnings(value="unchecked") private void verifySecondFilterResponseException() { assertFilterCounts(_filters[0], 1, 0, 1); assertFilterCounts(_filters[1], 1, 1, 0); assertFilterCounts(_filters[2], 1, 1, 0); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onError(any(TestFilterException.class), eq(_mockRestLiResponseData), eq(_mockResponseAttachments)); verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterResponseContext, times(5)).getResponseData(); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationResponseErrorFixesError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[1] = new CountFilterResponseErrorFixesError(); _filters[2] = new CountFilterResponseErrorOnError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); assertFilterCounts(_filters[0], 1, 1, 0); assertFilterCounts(_filters[1], 1, 0, 1); assertFilterCounts(_filters[2], 1, 1, 0); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onResponseSuccess(eq(_mockRestLiResponseData), eq(_mockResponseAttachments)); verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterResponseContext, times(3)).getResponseData(); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationLastResponseErrorFixesError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); _filters[0] = new CountFilterResponseErrorFixesError(); _filters[1] = new CountFilterResponseErrorOnError(); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); assertFilterCounts(_filters[0], 1, 0, 1); assertFilterCounts(_filters[1], 1, 1, 0); assertFilterCounts(_filters[2], 1, 1, 0); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onResponseSuccess(eq(_mockRestLiResponseData), eq(_mockResponseAttachments)); verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterResponseContext, times(3)).getResponseData(); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings(value="unchecked") @Test public void testFilterInvocationOnError() throws Exception { _restLiFilterChain = new RestLiFilterChain(Arrays.asList(_filters), _mockFilterChainCallback); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { _restLiFilterChain.onError(new TestFilterException(), _mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); _restLiFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); assertFilterCounts(_filters[0], 1, 0, 1); assertFilterCounts(_filters[1], 1, 0, 1); assertFilterCounts(_filters[2], 1, 0, 1); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onError(any(TestFilterException.class), eq(_mockRestLiResponseData), eq(_mockResponseAttachments)); verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterResponseContext, times(7)).getResponseData(); verifyNoMoreInteractions(_mockFilterChainCallback, _mockRequestAttachmentReader, _mockFilterRequestContext, _mockRestLiRequestData); } @SuppressWarnings(value="unchecked") @Test public void testNoFilters() throws Exception { final RestLiFilterChain emptyFilterChain = new RestLiFilterChain(_mockFilterChainCallback); doAnswer(new Answer<Object>() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { emptyFilterChain.onResponse(_mockFilterRequestContext, _mockFilterResponseContext, _mockResponseAttachments); return null; } }).when(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); when(_mockFilterRequestContext.getRequestData()).thenReturn(_mockRestLiRequestData); when(_mockFilterResponseContext.getResponseData()).thenReturn(_mockRestLiResponseData); when(_mockFilterResponseContextFactory.fromThrowable(any(Throwable.class))).thenReturn(_mockFilterResponseContext); emptyFilterChain.onRequest(_mockFilterRequestContext, _mockFilterResponseContextFactory); verify(_mockFilterChainCallback).onRequestSuccess(eq(_mockRestLiRequestData), any(RestLiCallback.class)); verify(_mockFilterChainCallback).onResponseSuccess(_mockRestLiResponseData, _mockResponseAttachments); verify(_mockFilterRequestContext).getRequestData(); verify(_mockFilterResponseContext).getResponseData(); verifyNoMoreInteractions(_mockFilterChainCallback, _mockFilterRequestContext, _mockRestLiRequestData); } private void assertFilterCounts(CountFilter filter, int expectedNumRequests, int expectedNumResponses, int expectedNumErrors) { assertEquals(filter.getNumRequests(), expectedNumRequests); assertEquals(filter.getNumResponses(), expectedNumResponses); assertEquals(filter.getNumErrors(), expectedNumErrors); } }