package com.twitter.common.net.http.filters;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.List;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.Context;
import com.google.common.collect.Lists;
import com.sun.jersey.api.core.ExtendedUriInfo;
import com.sun.jersey.api.model.AbstractResourceMethod;
import com.sun.jersey.spi.container.ContainerRequest;
import com.sun.jersey.spi.container.ContainerResponse;
import org.junit.Before;
import org.junit.Test;
import com.twitter.common.collections.Pair;
import com.twitter.common.net.http.filters.HttpStatsFilter.TrackRequestStats;
import com.twitter.common.quantity.Amount;
import com.twitter.common.quantity.Time;
import com.twitter.common.stats.SlidingStats;
import com.twitter.common.testing.easymock.EasyMockTest;
import com.twitter.common.util.testing.FakeClock;
import static org.easymock.EasyMock.anyObject;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.expectLastCall;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class HttpStatsFilterTest extends EasyMockTest {
private FakeClock clock;
private HttpServletRequest request;
private HttpServletResponse response;
private FilterChain filterChain;
private HttpStatsFilter filter;
private ContainerRequest containerRequest;
private ContainerResponse containerResponse;
private ExtendedUriInfo extendedUriInfo;
private HttpServletRequest servletRequest;
private static final Amount<Long, Time> REQUEST_TIME = Amount.of(1000L, Time.NANOSECONDS);
private void injectContextVars() throws Exception {
extendedUriInfo = createMock(ExtendedUriInfo.class);
servletRequest = createMock(HttpServletRequest.class);
List<Object> injectables = Lists.newArrayList(extendedUriInfo, servletRequest);
for (Field f : filter.getClass().getDeclaredFields()) {
if (f.isAnnotationPresent(Context.class)) {
for (Object injectable : injectables) {
if (f.getType().isInstance(injectable)) {
f.setAccessible(true);
f.set(filter, injectable);
}
}
}
}
}
@Before
public void setUp() throws Exception {
clock = new FakeClock();
request = createMock(HttpServletRequest.class);
response = createMock(HttpServletResponse.class);
filterChain = createMock(FilterChain.class);
filter = new HttpStatsFilter(clock);
containerRequest = createMock(ContainerRequest.class);
containerResponse = createMock(ContainerResponse.class);
injectContextVars();
}
@Test
public void testStartTimeIsSetAsRequestAttribute() throws Exception {
request.setAttribute(HttpStatsFilter.REQUEST_START_TIME, REQUEST_TIME.getValue());
filterChain.doFilter(request, response);
control.replay();
clock.advance(REQUEST_TIME);
filter.doFilter(request, response, filterChain);
}
@Test
public void testExceptionStatsCounting() throws Exception {
request.setAttribute(HttpStatsFilter.REQUEST_START_TIME, REQUEST_TIME.getValue());
expectLastCall().times(2);
clock.advance(REQUEST_TIME);
filterChain.doFilter(anyObject(HttpServletRequest.class), anyObject(HttpServletResponse.class));
expectLastCall().andThrow(new IOException());
filterChain.doFilter(anyObject(HttpServletRequest.class), anyObject(HttpServletResponse.class));
expectLastCall().andThrow(new ServletException());
control.replay();
try {
filter.doFilter(request, response, filterChain);
fail("Filter should have re-thrown the exception.");
} catch (IOException e) {
// Exception is expected, but we still want to assert on the stat tracking, so we can't
// just use @Test(expected...)
assertEquals(1, filter.exceptionCount.get());
}
try {
filter.doFilter(request, response, filterChain);
fail("Filter should have re-thrown the exception.");
} catch (ServletException e) {
// See above.
assertEquals(2, filter.exceptionCount.get());
}
}
private void expectAnnotationValue(String value, int times) {
AbstractResourceMethod matchedMethod = createMock(AbstractResourceMethod.class);
expect(extendedUriInfo.getMatchedMethod()).andReturn(matchedMethod).times(times);
TrackRequestStats annotation = createMock(TrackRequestStats.class);
expect(matchedMethod.getAnnotation(TrackRequestStats.class)).andReturn(annotation).times(times);
expect(annotation.value()).andReturn(value).times(times);
}
private void expectAnnotationValue(String value) {
expectAnnotationValue(value, 1);
}
@Test
public void testBasicStatsCounting() throws Exception {
expect(containerResponse.getStatus()).andReturn(HttpServletResponse.SC_OK);
expect(servletRequest.getAttribute(HttpStatsFilter.REQUEST_START_TIME))
.andReturn(clock.nowNanos());
String value = "some_value";
expectAnnotationValue(value);
control.replay();
clock.advance(REQUEST_TIME);
assertEquals(containerResponse, filter.filter(containerRequest, containerResponse));
SlidingStats stat = filter.requestCounters.get(Pair.of(value, HttpServletResponse.SC_OK));
assertEquals(1, stat.getEventCounter().get());
assertEquals(REQUEST_TIME.getValue().longValue(), stat.getTotalCounter().get());
assertEquals(1, filter.statusCounters.get(HttpServletResponse.SC_OK).getEventCounter().get());
}
@Test
public void testMultipleRequests() throws Exception {
int numCalls = 2;
expect(containerResponse.getStatus()).andReturn(HttpServletResponse.SC_OK).times(numCalls);
expect(servletRequest.getAttribute(HttpStatsFilter.REQUEST_START_TIME))
.andReturn(clock.nowNanos()).times(numCalls);
String value = "some_value";
expectAnnotationValue(value, numCalls);
control.replay();
clock.advance(REQUEST_TIME);
for (int i = 0; i < numCalls; i++) {
filter.filter(containerRequest, containerResponse);
}
SlidingStats stat = filter.requestCounters.get(Pair.of(value, HttpServletResponse.SC_OK));
assertEquals(numCalls, stat.getEventCounter().get());
assertEquals(REQUEST_TIME.getValue() * numCalls, stat.getTotalCounter().get());
assertEquals(numCalls,
filter.statusCounters.get(HttpServletResponse.SC_OK).getEventCounter().get());
}
@Test
public void testNoStartTime() throws Exception {
expect(servletRequest.getAttribute(HttpStatsFilter.REQUEST_START_TIME))
.andReturn(null);
expect(containerResponse.getStatus()).andReturn(HttpServletResponse.SC_OK);
control.replay();
assertEquals(containerResponse, filter.filter(containerRequest, containerResponse));
assertEquals(0, filter.statusCounters.asMap().keySet().size());
}
@Test
public void testNoMatchedMethod() throws Exception {
expect(containerResponse.getStatus()).andReturn(HttpServletResponse.SC_NOT_FOUND);
expect(servletRequest.getAttribute(HttpStatsFilter.REQUEST_START_TIME))
.andReturn(clock.nowNanos());
expect(extendedUriInfo.getMatchedMethod()).andReturn(null);
control.replay();
clock.advance(REQUEST_TIME);
assertEquals(containerResponse, filter.filter(containerRequest, containerResponse));
assertEquals(1,
filter.statusCounters.get(HttpServletResponse.SC_NOT_FOUND).getEventCounter().get());
}
}