package com.twitter.common.net.http.filters; import java.io.IOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Logger; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.Context; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.inject.Inject; import com.sun.jersey.api.core.ExtendedUriInfo; import com.sun.jersey.api.model.AbstractResourceMethod; import com.sun.jersey.spi.container.ContainerRequest; import com.sun.jersey.spi.container.ContainerResponse; import com.sun.jersey.spi.container.ContainerResponseFilter; import com.twitter.common.collections.Pair; import com.twitter.common.stats.SlidingStats; import com.twitter.common.stats.Stats; import com.twitter.common.util.Clock; /** * An HTTP filter that exports counts and timing for requests based on response code. */ public class HttpStatsFilter extends AbstractHttpFilter implements ContainerResponseFilter { /** * Methods tagged with this annotation will be intercepted and stats will be tracked accordingly. */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) public @interface TrackRequestStats { /** * Indicates the identifier to use when tracking requests with this annotation. */ String value(); } private static final Logger LOG = Logger.getLogger(HttpStatsFilter.class.getName()); @VisibleForTesting static final String REQUEST_START_TIME = "request_start_time"; private final Clock clock; @Context private ExtendedUriInfo extendedUriInfo; @VisibleForTesting final LoadingCache<Pair<String, Integer>, SlidingStats> requestCounters = CacheBuilder.newBuilder() .build(new CacheLoader<Pair<String, Integer>, SlidingStats>() { @Override public SlidingStats load(Pair<String, Integer> identifierAndStatus) { return new SlidingStats("http_" + identifierAndStatus.getFirst() + "_" + identifierAndStatus.getSecond() + "_responses", "nanos"); } }); @Context private HttpServletRequest servletRequest; @VisibleForTesting final LoadingCache<Integer, SlidingStats> statusCounters = CacheBuilder.newBuilder() .build(new CacheLoader<Integer, SlidingStats>() { @Override public SlidingStats load(Integer status) { return new SlidingStats("http_" + status + "_responses", "nanos"); } }); @VisibleForTesting final AtomicLong exceptionCount = Stats.exportLong("http_request_exceptions"); @Inject public HttpStatsFilter(Clock clock) { this.clock = Preconditions.checkNotNull(clock); } private void trackStats(int status) { long endTime = clock.nowNanos(); Object startTimeAttribute = servletRequest.getAttribute(REQUEST_START_TIME); if (startTimeAttribute == null) { LOG.fine("No start time attribute was found on the request, this filter should be wired" + " as both a servlet filter and a container filter."); return; } long elapsed = endTime - ((Long) startTimeAttribute).longValue(); statusCounters.getUnchecked(status).accumulate(elapsed); AbstractResourceMethod matchedMethod = extendedUriInfo.getMatchedMethod(); // It's possible for no method to have matched, e.g. in the case of a 404, don't let those // cases lead to an exception and a 500 response. if (matchedMethod == null) { return; } TrackRequestStats trackRequestStats = matchedMethod.getAnnotation(TrackRequestStats.class); if (trackRequestStats == null) { Method method = matchedMethod.getMethod(); LOG.fine("The method that handled this request (" + method.getDeclaringClass() + "#" + method.getName() + ") is not annotated with " + TrackRequestStats.class.getSimpleName() + ". No request stats will recorded."); return; } requestCounters.getUnchecked(Pair.of(trackRequestStats.value(), status)).accumulate(elapsed); } @Override public ContainerResponse filter(ContainerRequest request, ContainerResponse response) { trackStats(response.getStatus()); return response; } @Override public void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { long startTime = clock.nowNanos(); request.setAttribute(REQUEST_START_TIME, startTime); try { chain.doFilter(request, response); } catch (IOException e) { exceptionCount.incrementAndGet(); throw e; } catch (ServletException e) { exceptionCount.incrementAndGet(); throw e; } } }