/*
* Copyright 2015-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.http.inbound;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
/**
 * Verifies which {@link CorsConfiguration} the
 * {@link IntegrationRequestMappingHandlerMapping} attaches to the
 * {@link HandlerExecutionChain} it resolves for a request.
 *
 * <p>The endpoints addressed here ({@code /no}, {@code /default},
 * {@code /customized}, {@code /ambiguous-header}, {@code /ambiguous-produces})
 * are not declared in this class; they are expected to come from the
 * application context loaded through {@link ContextConfiguration}.
 *
 * @author Artem Bilan
 * @since 4.2
 */
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration
@DirtiesContext
public class CrossOriginTests {

    @Autowired
    private IntegrationRequestMappingHandlerMapping handlerMapping;

    /** Request shared by most tests; re-created before each test by {@link #setUp()}. */
    private MockHttpServletRequest request;

    /**
     * Prepares a fresh {@code GET} request that carries an {@code Origin} header,
     * so that each test only has to set the URI (and override the method if needed).
     */
    @Before
    public void setUp() {
        this.request = new MockHttpServletRequest();
        this.request.setMethod("GET");
        this.request.addHeader(HttpHeaders.ORIGIN, "http://domain.com/");
    }

    /**
     * A {@code GET /no} request without an {@code Origin} header must not expose
     * any CORS configuration.
     */
    @Test
    public void noEndpointWithoutOriginHeader() throws Exception {
        // Deliberately not the shared request: that one always has an Origin header.
        MockHttpServletRequest requestWithoutOrigin = new MockHttpServletRequest("GET", "/no");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(requestWithoutOrigin);
        CorsConfiguration config = getCorsConfiguration(chain, false);
        assertNull(config);
    }

    /**
     * A {@code GET /no} request with an {@code Origin} header must not expose
     * any CORS configuration.
     */
    @Test
    public void noEndpointWithOriginHeader() throws Exception {
        this.request.setRequestURI("/no");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, false);
        assertNull(config);
    }

    /**
     * A {@code POST /no} request with an {@code Origin} header must not expose
     * any CORS configuration.
     */
    @Test
    public void noEndpointPostWithOriginHeader() throws Exception {
        this.request.setMethod("POST");
        this.request.setRequestURI("/no");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, false);
        assertNull(config);
    }

    /**
     * A {@code GET /default} request must be intercepted with the configuration
     * checked by {@link #assertDefaultCorsConfiguration(CorsConfiguration)}.
     */
    @Test
    public void defaultEndpointWithCrossOrigin() throws Exception {
        this.request.setRequestURI("/default");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, false);
        assertDefaultCorsConfiguration(config);
    }

    /**
     * A {@code GET /customized} request must be intercepted with the explicitly
     * customized methods, origins, headers, max age and credentials flag.
     */
    @Test
    public void customized() throws Exception {
        this.request.setRequestURI("/customized");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, false);
        assertNotNull(config);
        assertArrayEquals(new String[] { "DELETE" }, config.getAllowedMethods().toArray());
        assertArrayEquals(new String[] { "http://site1.com", "http://site2.com" },
                config.getAllowedOrigins().toArray());
        assertArrayEquals(new String[] { "header1", "header2" }, config.getAllowedHeaders().toArray());
        assertArrayEquals(new String[] { "header3", "header4" }, config.getExposedHeaders().toArray());
        assertEquals(Long.valueOf(123), config.getMaxAge());
        assertEquals(Boolean.FALSE, config.getAllowCredentials());
    }

    /**
     * A pre-flight ({@code OPTIONS} + {@code Access-Control-Request-Method}) request
     * for {@code /default} must be handled with the same configuration as the
     * actual {@code GET} request.
     */
    @Test
    public void preFlightRequest() throws Exception {
        this.request.setMethod("OPTIONS");
        this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
        this.request.setRequestURI("/default");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, true);
        assertDefaultCorsConfiguration(config);
    }

    /**
     * A pre-flight request for {@code /ambiguous-header} that also announces a
     * request header must be handled with the wildcard configuration checked by
     * {@link #assertPermitAllCorsConfiguration(CorsConfiguration)}.
     */
    @Test
    public void ambiguousHeaderPreFlightRequest() throws Exception {
        this.request.setMethod("OPTIONS");
        this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
        this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1");
        this.request.setRequestURI("/ambiguous-header");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, true);
        assertPermitAllCorsConfiguration(config);
    }

    /**
     * A pre-flight request for {@code /ambiguous-produces} must be handled with
     * the wildcard configuration checked by
     * {@link #assertPermitAllCorsConfiguration(CorsConfiguration)}.
     */
    @Test
    public void ambiguousProducesPreFlightRequest() throws Exception {
        this.request.setMethod("OPTIONS");
        this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
        this.request.setRequestURI("/ambiguous-produces");
        HandlerExecutionChain chain = this.handlerMapping.getHandler(this.request);
        CorsConfiguration config = getCorsConfiguration(chain, true);
        assertPermitAllCorsConfiguration(config);
    }

    /**
     * An {@code OPTIONS /default} request that has an {@code Origin} header but no
     * {@code Access-Control-Request-Method} header must resolve to a
     * {@link HandlerMethod} whose bean type name contains {@code HttpOptionsHandler}.
     */
    @Test
    public void testOptionsHeaderHandling() throws Exception {
        MockHttpServletRequest optionsRequest = new MockHttpServletRequest("OPTIONS", "/default");
        optionsRequest.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
        HandlerExecutionChain handler = this.handlerMapping.getHandler(optionsRequest);
        assertNotNull(handler);
        Object handlerMethod = handler.getHandler();
        assertNotNull(handlerMethod);
        assertThat(handlerMethod, instanceOf(HandlerMethod.class));
        assertThat(((HandlerMethod) handlerMethod).getBeanType().getName(),
                containsString("HttpOptionsHandler"));
    }

    /**
     * Asserts the configuration expected for the {@code /default} endpoint:
     * only {@code GET} allowed, any origin, any header, credentials allowed,
     * no exposed headers and a max age of 1800.
     *
     * @param config the configuration extracted from the handler chain
     */
    private static void assertDefaultCorsConfiguration(CorsConfiguration config) {
        assertNotNull(config);
        assertArrayEquals(new String[] { "GET" }, config.getAllowedMethods().toArray());
        assertArrayEquals(new String[] { "*" }, config.getAllowedOrigins().toArray());
        assertTrue(config.getAllowCredentials());
        assertArrayEquals(new String[] { "*" }, config.getAllowedHeaders().toArray());
        assertNull(config.getExposedHeaders());
        assertEquals(Long.valueOf(1800), config.getMaxAge());
    }

    /**
     * Asserts the wildcard configuration expected for the ambiguous endpoints:
     * any method, any origin, any header, credentials allowed, no exposed
     * headers and no max age.
     *
     * @param config the configuration extracted from the handler chain
     */
    private static void assertPermitAllCorsConfiguration(CorsConfiguration config) {
        assertNotNull(config);
        assertArrayEquals(new String[] { "*" }, config.getAllowedMethods().toArray());
        assertArrayEquals(new String[] { "*" }, config.getAllowedOrigins().toArray());
        assertArrayEquals(new String[] { "*" }, config.getAllowedHeaders().toArray());
        assertTrue(config.getAllowCredentials());
        assertNull(config.getExposedHeaders());
        assertNull(config.getMaxAge());
    }

    /**
     * Extracts the {@link CorsConfiguration} from a resolved handler chain.
     *
     * <p>For a pre-flight request the chain's handler itself must be a
     * {@code PreFlightHandler} and its {@code config} property is returned.
     * Otherwise the chain's interceptors are searched for a
     * {@code CorsInterceptor} and that interceptor's {@code config} property is
     * returned. Both types are matched by simple class name.
     *
     * @param chain the chain returned by the handler mapping; must not be {@code null}
     * @param isPreFlightRequest whether the chain was resolved for a pre-flight request
     * @return the CORS configuration, or {@code null} if the chain of a
     * non-pre-flight request has no {@code CorsInterceptor}
     */
    private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) {
        // Fail with a readable message instead of a NullPointerException when nothing was mapped.
        assertNotNull("No handler execution chain was resolved for the request", chain);
        if (isPreFlightRequest) {
            Object handler = chain.getHandler();
            // assertEquals reports the actual handler type on failure.
            assertEquals("PreFlightHandler", handler.getClass().getSimpleName());
            return TestUtils.getPropertyValue(handler, "config", CorsConfiguration.class);
        }
        HandlerInterceptor[] interceptors = chain.getInterceptors();
        if (interceptors != null) {
            for (HandlerInterceptor interceptor : interceptors) {
                if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) {
                    return TestUtils.getPropertyValue(interceptor, "config", CorsConfiguration.class);
                }
            }
        }
        return null;
    }

}