/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.servlet.mvc.method.annotation; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Properties; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.springframework.beans.DirectFieldAccessor; import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.env.PropertiesPropertySource; import org.springframework.http.HttpHeaders; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.stereotype.Controller; import org.springframework.util.CollectionUtils; import org.springframework.web.bind.annotation.CrossOrigin; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition; import org.springframework.web.servlet.mvc.condition.HeadersRequestCondition; import org.springframework.web.servlet.mvc.condition.ParamsRequestCondition; import org.springframework.web.servlet.mvc.condition.PatternsRequestCondition; import org.springframework.web.servlet.mvc.condition.ProducesRequestCondition; import org.springframework.web.servlet.mvc.condition.RequestMethodsRequestCondition; import org.springframework.web.servlet.mvc.method.RequestMappingInfo; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; /** * Test fixture for {@link CrossOrigin @CrossOrigin} annotated methods. * * @author Sebastien Deleuze * @author Sam Brannen * @author Nicolas Labrot */ public class CrossOriginTests { private final TestRequestMappingInfoHandlerMapping handlerMapping = new TestRequestMappingInfoHandlerMapping(); private final MockHttpServletRequest request = new MockHttpServletRequest(); @Rule public ExpectedException exception = ExpectedException.none(); @Before public void setup() { StaticWebApplicationContext wac = new StaticWebApplicationContext(); Properties props = new Properties(); props.setProperty("myOrigin", "http://example.com"); wac.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props)); wac.registerSingleton("ppc", PropertySourcesPlaceholderConfigurer.class); wac.refresh(); this.handlerMapping.setRemoveSemicolonContent(false); wac.getAutowireCapableBeanFactory().initializeBean(this.handlerMapping, "hm"); this.request.setMethod("GET"); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain.com/"); } @Test public void noAnnotationWithoutOrigin() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); MockHttpServletRequest request = new MockHttpServletRequest("GET", "/no"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); assertNull(getCorsConfiguration(chain, false)); } @Test // SPR-12931 public void noAnnotationWithOrigin() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setRequestURI("/no"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); assertNull(getCorsConfiguration(chain, false)); } @Test // SPR-12931 public void noAnnotationPostWithOrigin() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setMethod("POST"); this.request.setRequestURI("/no"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); assertNull(getCorsConfiguration(chain, false)); } @Test public void defaultAnnotation() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setRequestURI("/default"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[] {"*"}, config.getAllowedHeaders().toArray()); assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @Test public void customized() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setRequestURI("/customized"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"DELETE"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"http://site1.com", "http://site2.com"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[] {"header1", "header2"}, config.getAllowedHeaders().toArray()); assertArrayEquals(new String[] {"header3", "header4"}, config.getExposedHeaders().toArray()); assertEquals(new Long(123), config.getMaxAge()); assertFalse(config.getAllowCredentials()); } @Test public void customOriginDefinedViaValueAttribute() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setRequestURI("/customOrigin"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertEquals(Arrays.asList("http://example.com"), config.getAllowedOrigins()); assertTrue(config.getAllowCredentials()); } @Test public void customOriginDefinedViaPlaceholder() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setRequestURI("/someOrigin"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertEquals(Arrays.asList("http://example.com"), config.getAllowedOrigins()); assertTrue(config.getAllowCredentials()); } @Test public void bogusAllowCredentialsValue() throws Exception { exception.expect(IllegalStateException.class); exception.expectMessage(containsString("@CrossOrigin's allowCredentials")); exception.expectMessage(containsString("current value is [bogus]")); this.handlerMapping.registerHandler(new MethodLevelControllerWithBogusAllowCredentialsValue()); } @Test public void classLevel() throws Exception { this.handlerMapping.registerHandler(new ClassLevelController()); this.request.setRequestURI("/foo"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertFalse(config.getAllowCredentials()); this.request.setRequestURI("/bar"); chain = this.handlerMapping.getHandler(request); config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertFalse(config.getAllowCredentials()); this.request.setRequestURI("/baz"); chain = this.handlerMapping.getHandler(request); config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertTrue(config.getAllowCredentials()); } @Test // SPR-13468 public void classLevelComposedAnnotation() throws Exception { this.handlerMapping.registerHandler(new ClassLevelMappingWithComposedAnnotation()); this.request.setRequestURI("/foo"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"http://foo.com"}, config.getAllowedOrigins().toArray()); assertTrue(config.getAllowCredentials()); } @Test // SPR-13468 public void methodLevelComposedAnnotation() throws Exception { this.handlerMapping.registerHandler(new MethodLevelMappingWithComposedAnnotation()); this.request.setRequestURI("/foo"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"http://foo.com"}, config.getAllowedOrigins().toArray()); assertTrue(config.getAllowCredentials()); } @Test public void preFlightRequest() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setMethod("OPTIONS"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.request.setRequestURI("/default"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, true); assertNotNull(config); assertArrayEquals(new String[] {"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[] {"*"}, config.getAllowedHeaders().toArray()); assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @Test public void ambiguousHeaderPreFlightRequest() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setMethod("OPTIONS"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1"); this.request.setRequestURI("/ambiguous-header"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, true); assertNotNull(config); assertArrayEquals(new String[] {"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedHeaders().toArray()); assertTrue(config.getAllowCredentials()); assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); } @Test public void ambiguousProducesPreFlightRequest() throws Exception { this.handlerMapping.registerHandler(new MethodLevelController()); this.request.setMethod("OPTIONS"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.request.setRequestURI("/ambiguous-produces"); HandlerExecutionChain chain = this.handlerMapping.getHandler(request); CorsConfiguration config = getCorsConfiguration(chain, true); assertNotNull(config); assertArrayEquals(new String[] {"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[] {"*"}, config.getAllowedHeaders().toArray()); assertTrue(config.getAllowCredentials()); assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); } @Test public void preFlightRequestWithoutRequestMethodHeader() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/default"); request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); assertNull(this.handlerMapping.getHandler(request)); } private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) { if (isPreFlightRequest) { Object handler = chain.getHandler(); assertTrue(handler.getClass().getSimpleName().equals("PreFlightHandler")); DirectFieldAccessor accessor = new DirectFieldAccessor(handler); return (CorsConfiguration)accessor.getPropertyValue("config"); } else { HandlerInterceptor[] interceptors = chain.getInterceptors(); if (interceptors != null) { for (HandlerInterceptor interceptor : interceptors) { if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) { DirectFieldAccessor accessor = new DirectFieldAccessor(interceptor); return (CorsConfiguration) accessor.getPropertyValue("config"); } } } } return null; } @Controller private static class MethodLevelController { @RequestMapping(path = "/no", method = RequestMethod.GET) public void noAnnotation() { } @RequestMapping(path = "/no", method = RequestMethod.POST) public void noAnnotationPost() { } @CrossOrigin @RequestMapping(path = "/default", method = RequestMethod.GET) public void defaultAnnotation() { } @CrossOrigin @RequestMapping(path = "/default", method = RequestMethod.GET, params = "q") public void defaultAnnotationWithParams() { } @CrossOrigin @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=a") public void ambigousHeader1a() { } @CrossOrigin @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=b") public void ambigousHeader1b() { } @CrossOrigin @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/xml") public String ambigousProducesXml() { return "<a></a>"; } @CrossOrigin @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/json") public String ambigousProducesJson() { return "{}"; } @CrossOrigin(origins = { "http://site1.com", "http://site2.com" }, allowedHeaders = { "header1", "header2" }, exposedHeaders = { "header3", "header4" }, methods = RequestMethod.DELETE, maxAge = 123, allowCredentials = "false") @RequestMapping(path = "/customized", method = { RequestMethod.GET, RequestMethod.POST }) public void customized() { } @CrossOrigin("http://example.com") @RequestMapping("/customOrigin") public void customOriginDefinedViaValueAttribute() { } @CrossOrigin("${myOrigin}") @RequestMapping("/someOrigin") public void customOriginDefinedViaPlaceholder() { } } @Controller private static class MethodLevelControllerWithBogusAllowCredentialsValue { @CrossOrigin(allowCredentials = "bogus") @RequestMapping("/bogus") public void bogusAllowCredentialsValue() { } } @Controller @CrossOrigin(allowCredentials = "false") private static class ClassLevelController { @RequestMapping(path = "/foo", method = RequestMethod.GET) public void foo() { } @CrossOrigin @RequestMapping(path = "/bar", method = RequestMethod.GET) public void bar() { } @CrossOrigin(allowCredentials = "true") @RequestMapping(path = "/baz", method = RequestMethod.GET) public void baz() { } } @Target({ElementType.METHOD, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) @CrossOrigin private @interface ComposedCrossOrigin { String[] origins() default {}; String allowCredentials() default ""; } @Controller @ComposedCrossOrigin(origins = "http://foo.com", allowCredentials = "true") private static class ClassLevelMappingWithComposedAnnotation { @RequestMapping(path = "/foo", method = RequestMethod.GET) public void foo() { } } @Controller private static class MethodLevelMappingWithComposedAnnotation { @RequestMapping(path = "/foo", method = RequestMethod.GET) @ComposedCrossOrigin(origins = "http://foo.com", allowCredentials = "true") public void foo() { } } private static class TestRequestMappingInfoHandlerMapping extends RequestMappingHandlerMapping { public void registerHandler(Object handler) { super.detectHandlerMethods(handler); } @Override protected boolean isHandler(Class<?> beanType) { return AnnotationUtils.findAnnotation(beanType, Controller.class) != null; } @Override protected RequestMappingInfo getMappingForMethod(Method method, Class<?> handlerType) { RequestMapping annotation = AnnotatedElementUtils.findMergedAnnotation(method, RequestMapping.class); if (annotation != null) { return new RequestMappingInfo( new PatternsRequestCondition(annotation.value(), getUrlPathHelper(), getPathMatcher(), true, true), new RequestMethodsRequestCondition(annotation.method()), new ParamsRequestCondition(annotation.params()), new HeadersRequestCondition(annotation.headers()), new ConsumesRequestCondition(annotation.consumes(), annotation.headers()), new ProducesRequestCondition(annotation.produces(), annotation.headers()), null); } else { return null; } } } }