/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.http.inbound; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import java.io.IOException; import java.io.PrintWriter; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import org.hamcrest.Matchers; import org.junit.Test; import org.springframework.beans.factory.BeanFactory; import org.springframework.expression.common.LiteralExpression; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.converter.AbstractHttpMessageConverter; import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; import org.springframework.http.converter.HttpMessageNotWritableException; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.handler.AbstractReplyProducingMessageHandler; import org.springframework.integration.http.AbstractHttpInboundTests; import org.springframework.integration.http.HttpHeaders; import org.springframework.integration.http.converter.SerializingHttpMessageConverter; import org.springframework.integration.support.MessageBuilder; import org.springframework.messaging.Message; import org.springframework.messaging.support.GenericMessage; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.SerializationUtils; /** * @author Mark Fisher * @author Gary Russell * @author Gunnar Hillert * @author Artem Bilan * @author Biju Kunjummen * @since 2.0 */ public class HttpRequestHandlingMessagingGatewayTests extends AbstractHttpInboundTests { @Test @SuppressWarnings("unchecked") public void getRequestGeneratesMapPayload() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(false); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); request.setParameter("foo", "bar"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(LinkedMultiValueMap.class, message.getPayload().getClass()); LinkedMultiValueMap<String, String> map = (LinkedMultiValueMap<String, String>) message.getPayload(); assertEquals(1, map.get("foo").size()); assertEquals("bar", map.getFirst("foo")); } @Test public void stringExpectedWithoutReply() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(false); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestPayloadType(String.class); gateway.setRequestChannel(requestChannel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.setContentType("text/plain"); request.setContent("hello".getBytes()); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(String.class, message.getPayload().getClass()); assertEquals("hello", message.getPayload()); } @Test public void stringExpectedWithReply() throws Exception { stringExpectedWithReplyGuts(true); } @Test public void stringExpectedWithReplyNoContentType() throws Exception { stringExpectedWithReplyGuts(false); } private void stringExpectedWithReplyGuts(boolean contentType) throws Exception { DirectChannel requestChannel = new DirectChannel(); requestChannel.subscribe(new AbstractReplyProducingMessageHandler() { @Override protected Object handleRequestMessage(Message<?> requestMessage) { return requestMessage.getPayload().toString().toUpperCase(); } }); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setStatusCodeExpression(new LiteralExpression("foo")); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestPayloadType(String.class); gateway.setRequestChannel(requestChannel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.addHeader("Accept", "x-application/octet-stream"); if (contentType) { request.setContentType("text/plain"); } request.setContent("hello".getBytes()); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); assertEquals("HELLO", response.getContentAsString()); } @Test // INT-1767 public void noAcceptHeaderOnRequest() throws Exception { DirectChannel requestChannel = new DirectChannel(); requestChannel.subscribe(new AbstractReplyProducingMessageHandler() { @Override protected Object handleRequestMessage(Message<?> requestMessage) { return requestMessage.getPayload().toString().toUpperCase(); } }); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestPayloadType(String.class); gateway.setRequestChannel(requestChannel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.setContentType("text/plain"); request.setContent("hello".getBytes()); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); assertEquals("HELLO", response.getContentAsString()); //INT-3120 assertNull(response.getHeader("Accept-Charset")); } @Test public void testExceptionConversion() throws Exception { QueueChannel requestChannel = new QueueChannel() { @Override protected boolean doSend(Message<?> message, long timeout) { throw new RuntimeException("Planned"); } }; HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.setConvertExceptions(true); gateway.setMessageConverters(Arrays.<HttpMessageConverter<?>>asList(new TestHttpMessageConverter())); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Accept", "application/x-java-serialized-object"); request.setMethod("GET"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); String content = response.getContentAsString(); assertEquals("Planned", content); } @Test public void multiValueParameterMap() throws Exception { QueueChannel channel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(false); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(channel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest("GET", "/test"); request.setParameter("foo", "123"); request.addParameter("bar", "456"); request.addParameter("bar", "789"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = channel.receive(0); assertNotNull(message); assertNotNull(message.getPayload()); assertEquals(LinkedMultiValueMap.class, message.getPayload().getClass()); @SuppressWarnings("unchecked") LinkedMultiValueMap<String, String> map = (LinkedMultiValueMap<String, String>) message.getPayload(); List<String> fooValues = map.get("foo"); List<String> barValues = map.get("bar"); assertEquals(1, fooValues.size()); assertEquals("123", fooValues.get(0)); assertEquals(2, barValues.size()); assertEquals("456", barValues.get(0)); assertEquals("789", barValues.get(1)); } @Test public void serializableRequestBody() throws Exception { QueueChannel channel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(false); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestPayloadType(TestBean.class); gateway.setRequestChannel(channel); List<HttpMessageConverter<?>> converters = new ArrayList<HttpMessageConverter<?>>(); converters.add(new SerializingHttpMessageConverter()); gateway.setMessageConverters(converters); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/test"); request.setContentType("application/x-java-serialized-object"); TestBean testBean = new TestBean(); testBean.setName("T. Bean"); testBean.setAge(42); request.setContent(SerializationUtils.serialize(testBean)); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); byte[] bytes = response.getContentAsByteArray(); assertNotNull(bytes); Message<?> message = channel.receive(0); assertNotNull(message); assertNotNull(message.getPayload()); assertEquals(TestBean.class, message.getPayload().getClass()); TestBean result = (TestBean) message.getPayload(); assertEquals("T. Bean", result.name); assertEquals(84, result.age); } @Test public void INT2680DuplicateContentTypeHeader() throws Exception { final DirectChannel requestChannel = new DirectChannel(); requestChannel.subscribe(new AbstractReplyProducingMessageHandler() { @Override protected Object handleRequestMessage(Message<?> requestMessage) { return MessageBuilder.withPayload("Cartman".getBytes()) .setHeader("Content-type", "text/plain") .build(); } }); final List<MediaType> supportedMediaTypes = new ArrayList<MediaType>(); supportedMediaTypes.add(MediaType.TEXT_HTML); final ByteArrayHttpMessageConverter messageConverter = new ByteArrayHttpMessageConverter(); messageConverter.setSupportedMediaTypes(supportedMediaTypes); final List<HttpMessageConverter<?>> messageConverters = new ArrayList<HttpMessageConverter<?>>(); messageConverters.add(messageConverter); final HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setMessageConverters(messageConverters); gateway.setRequestChannel(requestChannel); gateway.start(); final MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); request.addHeader("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"); final ContentTypeCheckingMockHttpServletResponse response = new ContentTypeCheckingMockHttpServletResponse(); gateway.handleRequest(request, response); assertEquals("Cartman", response.getContentAsString()); /* Before fixing INT2680, 2 content type headers were being written. */ final List<String> contentTypes = response.getContentTypeList(); assertEquals("Exptecting only 1 content type being set.", Integer.valueOf(1), Integer.valueOf(contentTypes.size())); assertEquals("text/plain", contentTypes.get(0)); } @Test public void timeoutDefault() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.setReplyTimeout(0); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(500, response.getStatus()); } @Test public void timeoutStatusExpression() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.setReplyTimeout(0); gateway.setStatusCodeExpression(new LiteralExpression("501")); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(501, response.getStatus()); } @Test public void timeoutErrorFlow() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.setReplyTimeout(0); DirectChannel errorChannel = new DirectChannel(); errorChannel.subscribe(new AbstractReplyProducingMessageHandler() { @Override protected Object handleRequestMessage(Message<?> requestMessage) { return new GenericMessage<String>("foo", Collections.<String, Object>singletonMap(HttpHeaders.STATUS_CODE, HttpStatus.GATEWAY_TIMEOUT)); } }); gateway.setErrorChannel(errorChannel); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(504, response.getStatus()); } @Test public void timeoutErrorFlowTimeout() throws Exception { QueueChannel requestChannel = new QueueChannel(); HttpRequestHandlingMessagingGateway gateway = new HttpRequestHandlingMessagingGateway(true); gateway.setBeanFactory(mock(BeanFactory.class)); gateway.setRequestChannel(requestChannel); gateway.setReplyTimeout(0); QueueChannel errorChannel = new QueueChannel(); gateway.setErrorChannel(errorChannel); gateway.setStatusCodeExpression(new LiteralExpression("501")); gateway.afterPropertiesSet(); gateway.start(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); MockHttpServletResponse response = new MockHttpServletResponse(); gateway.handleRequest(request, response); Message<?> message = requestChannel.receive(0); assertNotNull(message); assertEquals(501, response.getStatus()); assertThat(response.getContentAsString(), Matchers.containsString("from error channel")); } private class ContentTypeCheckingMockHttpServletResponse extends MockHttpServletResponse { private final List<String> contentTypeList = new ArrayList<String>(); @Override public void addHeader(String name, String value) { if ("Content-Type".equalsIgnoreCase(name)) { this.contentTypeList.add(value); } super.addHeader(name, value); } public List<String> getContentTypeList() { return contentTypeList; } } private static class TestHttpMessageConverter extends AbstractHttpMessageConverter<Exception> { TestHttpMessageConverter() { setSupportedMediaTypes(Arrays.asList(MediaType.ALL)); } @Override protected Exception readInternal(Class<? extends Exception> clazz, HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { return null; } @Override protected boolean supports(Class<?> clazz) { return true; } @Override protected void writeInternal(Exception t, HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException { new PrintWriter(outputMessage.getBody()).append(t.getCause().getMessage()).flush(); } } @SuppressWarnings("serial") public static class TestBean implements Serializable { String name; int age; public void setName(String name) { this.name = name; } public void setAge(int age) { this.age = age * 2; } } }