/*
* Copyright 2011-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.glowroot.agent.plugin.servlet;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.glowroot.agent.it.harness.Container;
import org.glowroot.agent.it.harness.Containers;
import org.glowroot.wire.api.model.TraceOuterClass.Trace;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for the servlet plugin's {@code captureRequestHeaders} property.
 *
 * <p>Each test configures the property, executes one of the nested test servlets inside the
 * shared {@link Container}, and then inspects the "Request headers" detail map of the
 * resulting {@link Trace}.
 */
public class RequestHeaderIT {

    private static final String PLUGIN_ID = "servlet";

    // single container shared by all tests in this class; created once and closed once
    private static Container container;

    /** Creates the shared container before any test runs. */
    @BeforeClass
    public static void setUp() throws Exception {
        container = Containers.create();
    }

    /** Closes the shared container after all tests have run. */
    @AfterClass
    public static void tearDown() throws Exception {
        container.close();
    }

    /** Calls {@code checkAndReset()} on the shared container after every test. */
    @After
    public void afterEachTest() throws Exception {
        container.checkAndReset();
    }

    /**
     * Headers named in the property are present in the detail map; a header that is sent but
     * not named in the property ("Extra") is absent.
     */
    @Test
    public void testStandardRequestHeaders() throws Exception {
        // given
        container.getConfigService().setPluginProperty(PLUGIN_ID, "captureRequestHeaders",
                "Content-Type, Content-Length");
        // when
        Trace trace = container.execute(SetStandardRequestHeaders.class);
        // then
        Map<String, Object> requestHeaders =
                ResponseHeaderIT.getDetailMap(trace, "Request headers");
        assertThat(requestHeaders.get("Content-Type")).isEqualTo("text/plain;charset=UTF-8");
        assertThat(requestHeaders.get("Content-Length")).isEqualTo("1");
        assertThat(requestHeaders.get("Extra")).isNull();
    }

    /**
     * Same as {@link #testStandardRequestHeaders()} but the servlet sends the header names in
     * lowercase while the property still uses mixed case.
     */
    @Test
    public void testStandardRequestHeadersLowercase() throws Exception {
        // given
        container.getConfigService().setPluginProperty(PLUGIN_ID, "captureRequestHeaders",
                "Content-Type, Content-Length");
        // when
        Trace trace = container.execute(SetStandardRequestHeadersLowercase.class);
        // then
        Map<String, Object> requestHeaders =
                ResponseHeaderIT.getDetailMap(trace, "Request headers");
        // all lookups use the lowercase names that SetStandardRequestHeadersLowercase sends,
        // so the three assertions in this test agree with each other
        assertThat(requestHeaders.get("content-type")).isEqualTo("text/plain;charset=UTF-8");
        assertThat(requestHeaders.get("content-length")).isEqualTo("1");
        assertThat(requestHeaders.get("extra")).isNull();
    }

    /**
     * A header added twice ("One") is exposed as a list of both values in order, a header
     * added once ("Two") as a single string, and a header not named in the property ("Three")
     * is absent.
     */
    @Test
    public void testLotsOfRequestHeaders() throws Exception {
        // given
        container.getConfigService().setPluginProperty(PLUGIN_ID, "captureRequestHeaders",
                "One,Two");
        // when
        Trace trace = container.execute(SetOtherRequestHeaders.class);
        // then
        Map<String, Object> requestHeaders =
                ResponseHeaderIT.getDetailMap(trace, "Request headers");
        @SuppressWarnings("unchecked")
        List<String> one = (List<String>) requestHeaders.get("One");
        assertThat(one).containsExactly("ab", "xy");
        assertThat(requestHeaders.get("Two")).isEqualTo("1");
        assertThat(requestHeaders.get("Three")).isNull();
    }

    /**
     * When the request's header-name enumeration contains only a {@code null} entry, no
     * "Request headers" detail map is expected on the trace.
     */
    @Test
    public void testBadRequestHeaders() throws Exception {
        // given
        container.getConfigService().setPluginProperty(PLUGIN_ID, "captureRequestHeaders",
                "Content-Type, Content-Length");
        // when
        Trace trace = container.execute(GetBadRequestHeaders.class);
        // then
        Map<String, Object> requestHeaders =
                ResponseHeaderIT.getDetailMap(trace, "Request headers");
        assertThat(requestHeaders).isNull();
    }

    /**
     * When a header name is reported ("h1") but its value enumeration is empty, the detail map
     * is expected to contain exactly that header mapped to the empty string.
     */
    @Test
    public void testBadRequestHeaders2() throws Exception {
        // given
        container.getConfigService().setPluginProperty(PLUGIN_ID, "captureRequestHeaders",
                "Content-Type, Content-Length, h1");
        // when
        Trace trace = container.execute(GetBadRequestHeaders2.class);
        // then
        Map<String, Object> requestHeaders =
                ResponseHeaderIT.getDetailMap(trace, "Request headers");
        assertThat(requestHeaders).hasSize(1);
        assertThat(requestHeaders.get("h1")).isEqualTo("");
    }

    /** Test servlet that adds Content-Type, Content-Length and Extra headers to the request. */
    @SuppressWarnings("serial")
    public static class SetStandardRequestHeaders extends TestServlet {
        @Override
        protected void before(HttpServletRequest request, HttpServletResponse response) {
            // the cast is needed because addHeader() is only available on the mock type
            MockHttpServletRequest mockRequest = (MockHttpServletRequest) request;
            mockRequest.addHeader("Content-Type", "text/plain;charset=UTF-8");
            mockRequest.addHeader("Content-Length", "1");
            mockRequest.addHeader("Extra", "abc");
        }
    }

    /** Same headers as {@link SetStandardRequestHeaders}, but with all-lowercase names. */
    @SuppressWarnings("serial")
    public static class SetStandardRequestHeadersLowercase extends TestServlet {
        @Override
        protected void before(HttpServletRequest request, HttpServletResponse response) {
            MockHttpServletRequest mockRequest = (MockHttpServletRequest) request;
            mockRequest.addHeader("content-type", "text/plain;charset=UTF-8");
            mockRequest.addHeader("content-length", "1");
            mockRequest.addHeader("extra", "abc");
        }
    }

    /** Test servlet that adds header "One" twice, plus single-valued "Two" and "Three". */
    @SuppressWarnings("serial")
    public static class SetOtherRequestHeaders extends TestServlet {
        @Override
        protected void before(HttpServletRequest request, HttpServletResponse response) {
            MockHttpServletRequest mockRequest = (MockHttpServletRequest) request;
            mockRequest.addHeader("One", "ab");
            mockRequest.addHeader("One", "xy");
            mockRequest.addHeader("Two", "1");
            mockRequest.addHeader("Three", "xyz");
        }
    }

    /** Test servlet that services a {@link BadMockHttpServletRequest}. */
    @SuppressWarnings("serial")
    public static class GetBadRequestHeaders extends TestServlet {
        @Override
        public void executeApp() throws Exception {
            MockHttpServletRequest request = new BadMockHttpServletRequest("GET", "/testservlet");
            MockHttpServletResponse response = new PatchedMockHttpServletResponse();
            // NOTE(review): the explicit casts presumably select the
            // service(ServletRequest, ServletResponse) overload -- keep them
            service((ServletRequest) request, (ServletResponse) response);
        }
    }

    /** Test servlet that services a {@link BadMockHttpServletRequest2}. */
    @SuppressWarnings("serial")
    public static class GetBadRequestHeaders2 extends TestServlet {
        @Override
        public void executeApp() throws Exception {
            MockHttpServletRequest request = new BadMockHttpServletRequest2("GET", "/testservlet");
            MockHttpServletResponse response = new PatchedMockHttpServletResponse();
            // NOTE(review): the explicit casts presumably select the
            // service(ServletRequest, ServletResponse) overload -- keep them
            service((ServletRequest) request, (ServletResponse) response);
        }
    }

    /** Mock request whose header-name enumeration consists of a single {@code null} element. */
    public static class BadMockHttpServletRequest extends MockHttpServletRequest {

        public BadMockHttpServletRequest(String method, String requestURI) {
            super(method, requestURI);
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            // the (String) cast makes this a one-element list containing null
            return Collections.enumeration(Lists.newArrayList((String) null));
        }
    }

    /**
     * Mock request that reports a header named "h1" but returns an empty enumeration of values
     * for it; every other header name is delegated to the superclass.
     */
    public static class BadMockHttpServletRequest2 extends MockHttpServletRequest {

        public BadMockHttpServletRequest2(String method, String requestURI) {
            super(method, requestURI);
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            return Collections.enumeration(Lists.newArrayList("h1"));
        }

        @Override
        public Enumeration<String> getHeaders(String name) {
            if (name.equals("h1")) {
                return Collections.enumeration(Collections.<String>emptyList());
            } else {
                return super.getHeaders(name);
            }
        }
    }
}