/* * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.boot.actuate.trace; import java.io.IOException; import java.security.Principal; import java.util.Collections; import java.util.Enumeration; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.boot.actuate.trace.TraceProperties.Include; import org.springframework.boot.autoconfigure.web.servlet.error.ErrorAttributes; import org.springframework.core.Ordered; import org.springframework.http.HttpStatus; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.filter.OncePerRequestFilter; /** * Servlet {@link Filter} that logs all requests to a {@link TraceRepository}. * * @author Dave Syer * @author Wallace Wadge * @author Andy Wilkinson * @author Venil Noronha * @author Madhura Bhave */ public class WebRequestTraceFilter extends OncePerRequestFilter implements Ordered { private static final Log logger = LogFactory.getLog(WebRequestTraceFilter.class); private boolean dumpRequests = false; // Not LOWEST_PRECEDENCE, but near the end, so it has a good chance of catching all // enriched headers, but users can add stuff after this if they want to private int order = Ordered.LOWEST_PRECEDENCE - 10; private final TraceRepository repository; private ErrorAttributes errorAttributes; private final TraceProperties properties; /** * Create a new {@link WebRequestTraceFilter} instance. * @param repository the trace repository * @param properties the trace properties */ public WebRequestTraceFilter(TraceRepository repository, TraceProperties properties) { this.repository = repository; this.properties = properties; } /** * Debugging feature. If enabled, and trace logging is enabled then web request * headers will be logged. * @param dumpRequests if requests should be logged */ public void setDumpRequests(boolean dumpRequests) { this.dumpRequests = dumpRequests; } @Override public int getOrder() { return this.order; } public void setOrder(int order) { this.order = order; } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { long startTime = System.nanoTime(); Map<String, Object> trace = getTrace(request); logTrace(request, trace); int status = HttpStatus.INTERNAL_SERVER_ERROR.value(); try { filterChain.doFilter(request, response); status = response.getStatus(); } finally { addTimeTaken(trace, startTime); enhanceTrace(trace, status == response.getStatus() ? response : new CustomStatusResponseWrapper(response, status)); this.repository.add(trace); } } protected Map<String, Object> getTrace(HttpServletRequest request) { HttpSession session = request.getSession(false); Throwable exception = (Throwable) request .getAttribute("javax.servlet.error.exception"); Principal userPrincipal = request.getUserPrincipal(); Map<String, Object> trace = new LinkedHashMap<>(); Map<String, Object> headers = new LinkedHashMap<>(); trace.put("method", request.getMethod()); trace.put("path", request.getRequestURI()); trace.put("headers", headers); if (isIncluded(Include.REQUEST_HEADERS)) { headers.put("request", getRequestHeaders(request)); } add(trace, Include.PATH_INFO, "pathInfo", request.getPathInfo()); add(trace, Include.PATH_TRANSLATED, "pathTranslated", request.getPathTranslated()); add(trace, Include.CONTEXT_PATH, "contextPath", request.getContextPath()); add(trace, Include.USER_PRINCIPAL, "userPrincipal", (userPrincipal == null ? null : userPrincipal.getName())); if (isIncluded(Include.PARAMETERS)) { trace.put("parameters", getParameterMapCopy(request)); } add(trace, Include.QUERY_STRING, "query", request.getQueryString()); add(trace, Include.AUTH_TYPE, "authType", request.getAuthType()); add(trace, Include.REMOTE_ADDRESS, "remoteAddress", request.getRemoteAddr()); add(trace, Include.SESSION_ID, "sessionId", (session == null ? null : session.getId())); add(trace, Include.REMOTE_USER, "remoteUser", request.getRemoteUser()); if (isIncluded(Include.ERRORS) && exception != null && this.errorAttributes != null) { trace.put("error", this.errorAttributes .getErrorAttributes(new ServletRequestAttributes(request), true)); } return trace; } private Map<String, Object> getRequestHeaders(HttpServletRequest request) { Map<String, Object> headers = new LinkedHashMap<>(); Set<String> excludedHeaders = getExcludeHeaders(); Enumeration<String> names = request.getHeaderNames(); while (names.hasMoreElements()) { String name = names.nextElement(); if (!excludedHeaders.contains(name.toLowerCase())) { headers.put(name, getHeaderValue(request, name)); } } postProcessRequestHeaders(headers); return headers; } private Set<String> getExcludeHeaders() { Set<String> excludedHeaders = new HashSet<>(); if (!isIncluded(Include.COOKIES)) { excludedHeaders.add("cookie"); } if (!isIncluded(Include.AUTHORIZATION_HEADER)) { excludedHeaders.add("authorization"); } return excludedHeaders; } private Object getHeaderValue(HttpServletRequest request, String name) { List<String> value = Collections.list(request.getHeaders(name)); if (value.size() == 1) { return value.get(0); } if (value.isEmpty()) { return ""; } return value; } private Map<String, String[]> getParameterMapCopy(HttpServletRequest request) { return new LinkedHashMap<String, String[]>(request.getParameterMap()); } /** * Post process request headers before they are added to the trace. * @param headers a mutable map containing the request headers to trace * @since 1.4.0 */ protected void postProcessRequestHeaders(Map<String, Object> headers) { } private void addTimeTaken(Map<String, Object> trace, long startTime) { long timeTaken = System.nanoTime() - startTime; add(trace, Include.TIME_TAKEN, "timeTaken", "" + TimeUnit.NANOSECONDS.toMillis(timeTaken)); } @SuppressWarnings("unchecked") protected void enhanceTrace(Map<String, Object> trace, HttpServletResponse response) { if (isIncluded(Include.RESPONSE_HEADERS)) { Map<String, Object> headers = (Map<String, Object>) trace.get("headers"); headers.put("response", getResponseHeaders(response)); } } private Map<String, String> getResponseHeaders(HttpServletResponse response) { Map<String, String> headers = new LinkedHashMap<>(); for (String header : response.getHeaderNames()) { String value = response.getHeader(header); headers.put(header, value); } if (!isIncluded(Include.COOKIES)) { headers.remove("Set-Cookie"); } headers.put("status", String.valueOf(response.getStatus())); return headers; } private void logTrace(HttpServletRequest request, Map<String, Object> trace) { if (logger.isTraceEnabled()) { logger.trace("Processing request " + request.getMethod() + " " + request.getRequestURI()); if (this.dumpRequests) { logger.trace("Headers: " + trace.get("headers")); } } } private void add(Map<String, Object> trace, Include include, String name, Object value) { if (isIncluded(include) && value != null) { trace.put(name, value); } } private boolean isIncluded(Include include) { return this.properties.getInclude().contains(include); } public void setErrorAttributes(ErrorAttributes errorAttributes) { this.errorAttributes = errorAttributes; } private static final class CustomStatusResponseWrapper extends HttpServletResponseWrapper { private final int status; private CustomStatusResponseWrapper(HttpServletResponse response, int status) { super(response); this.status = status; } @Override public int getStatus() { return this.status; } } }