/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version
 * 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http.cors;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import org.junit.Test;

import java.util.Arrays;
import java.util.concurrent.Callable;

import static io.netty.handler.codec.http.HttpHeaders.Names.*;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.*;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.*;

/**
 * Tests for {@link CorsHandler}.
 *
 * <p>Every test installs a {@link CorsHandler} built from a {@link CorsConfig} into an
 * {@link EmbeddedChannel}, writes a single request inbound and then inspects either the
 * response read from the outbound side or the reference count of the request.
 */
public class CorsHandlerTest {

    /** A request without an {@code Origin} header must not receive an allow-origin header. */
    @Test
    public void nonCorsRequest() {
        final HttpResponse response = simpleRequest(CorsConfig.withAnyOrigin().build(), null);
        assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_ORIGIN), is(false));
    }

    /** With an any-origin config the allow-origin header is expected to be {@code "*"}. */
    @Test
    public void simpleRequestWithAnyOrigin() {
        final HttpResponse response = simpleRequest(CorsConfig.withAnyOrigin().build(), "http://localhost:7777");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("*"));
    }

    /** A request from the single configured origin gets that origin back in the allow-origin header. */
    @Test
    public void simpleRequestWithOrigin() {
        final String origin = "http://localhost:8888";
        final HttpResponse response = simpleRequest(CorsConfig.withOrigin(origin).build(), origin);
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin));
    }

    /** With several configured origins, each one is reflected back when it is the request origin. */
    @Test
    public void simpleRequestWithOrigins() {
        final String origin1 = "http://localhost:8888";
        final String origin2 = "https://localhost:8888";
        final String[] origins = {origin1, origin2};

        final HttpResponse response1 = simpleRequest(CorsConfig.withOrigins(origins).build(), origin1);
        assertThat(response1.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin1));

        final HttpResponse response2 = simpleRequest(CorsConfig.withOrigins(origins).build(), origin2);
        assertThat(response2.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin2));
    }

    /** A request origin that differs from the configured one (http vs. https) gets no allow-origin header. */
    @Test
    public void simpleRequestWithNoMatchingOrigin() {
        final String origin = "http://localhost:8888";
        final HttpResponse response = simpleRequest(CorsConfig.withOrigins("https://localhost:8888").build(), origin);
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
    }

    /** The preflight response lists the configured GET and DELETE methods and varies on {@code Origin}. */
    @Test
    public void preflightDeleteRequestWithCustomHeaders() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .allowedRequestMethods(GET, DELETE)
                .build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
        assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** The preflight response carries the configured allowed methods and allowed request headers. */
    @Test
    public void preflightGetRequestWithCustomHeaders() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .allowedRequestMethods(OPTIONS, GET, DELETE)
                .allowedRequestHeaders("content-type", "xheader1")
                .build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
        assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET"));
        assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /**
     * Without any explicitly configured preflight headers the response is still expected to carry
     * {@code Content-Length: 0}, a {@code Date} header and {@code Vary: Origin}.
     */
    @Test
    public void preflightRequestWithDefaultHeaders() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888").build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().get(CONTENT_LENGTH), is("0"));
        assertThat(response.headers().get(DATE), is(notNullValue()));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** A single-valued custom preflight header is copied onto the preflight response. */
    @Test
    public void preflightRequestWithCustomHeader() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .preflightResponseHeader("CustomHeader", "somevalue")
                .build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().get("CustomHeader"), equalTo("somevalue"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** A custom preflight header configured with varargs values exposes all of them on the response. */
    @Test
    public void preflightRequestWithCustomHeaders() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .preflightResponseHeader("CustomHeader", "value1", "value2")
                .build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** Same as {@link #preflightRequestWithCustomHeaders()} but the values are supplied as a list. */
    @Test
    public void preflightRequestWithCustomHeadersIterable() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
                .build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** A preflight header whose value comes from a {@link Callable} carries the value the callable returns. */
    @Test
    public void preflightRequestWithValueGenerator() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .preflightResponseHeader("GenHeader", new Callable<String>() {
                    @Override
                    public String call() throws Exception {
                        return "generatedValue";
                    }
                }).build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
        assertThat(response.headers().get("GenHeader"), equalTo("generatedValue"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /**
     * Preflight with the literal origin {@code "null"} while both {@code allowNullOrigin()} and
     * {@code allowCredentials()} are configured: the allow-origin header is expected to be
     * {@code "*"} and no allow-credentials header is expected.
     */
    @Test
    public void preflightRequestWithNullOrigin() {
        final String origin = "null";
        final CorsConfig config = CorsConfig.withOrigin(origin)
                .allowNullOrigin()
                .allowCredentials()
                .build();
        final HttpResponse response = preflightRequest(config, origin, "content-type, xheader1");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(equalTo("*")));
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(nullValue()));
    }

    /** With {@code allowCredentials()} configured the preflight response carries allow-credentials {@code "true"}. */
    @Test
    public void preflightRequestAllowCredentials() {
        final String origin = "null";
        final CorsConfig config = CorsConfig.withOrigin(origin).allowCredentials().build();
        final HttpResponse response = preflightRequest(config, origin, "content-type, xheader1");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(equalTo("true")));
    }

    /** Without {@code allowCredentials()} the preflight response must not contain the allow-credentials header. */
    @Test
    public void preflightRequestDoNotAllowCredentials() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888").build();
        final HttpResponse response = preflightRequest(config, "http://localhost:8888", "");
        // the only valid value for Access-Control-Allow-Credentials is true.
        assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false));
    }

    /** Both configured expose-headers values must show up on the response of a simple request. */
    @Test
    public void simpleRequestCustomHeaders() {
        final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*"));
        // Check both configured values; asserting "custom1" twice would never verify "custom2".
        assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom2"));
    }

    /** With {@code allowCredentials()} configured a simple request's response carries allow-credentials {@code "true"}. */
    @Test
    public void simpleRequestAllowCredentials() {
        final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
    }

    /** Without {@code allowCredentials()} a simple request's response has no allow-credentials header. */
    @Test
    public void simpleRequestDoNotAllowCredentials() {
        final CorsConfig config = CorsConfig.withAnyOrigin().build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false));
    }

    /**
     * Any-origin combined with {@code allowCredentials()}: the response is expected to echo the
     * request origin (rather than {@code "*"}) and to carry {@code Vary: Origin}.
     */
    @Test
    public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() {
        final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777"));
        assertThat(response.headers().get(VARY), equalTo(ORIGIN));
    }

    /** The configured expose-headers values are present on the response of a simple request. */
    @Test
    public void simpleRequestExposeHeaders() {
        final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two"));
    }

    /** With {@code shortCurcuit()} configured, a request from a non-configured origin gets 403 Forbidden. */
    @Test
    public void simpleRequestShortCurcuit() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").shortCurcuit().build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.getStatus(), is(FORBIDDEN));
    }

    /** Without short-circuiting, a non-configured origin still gets 200 OK but no allow-origin header. */
    @Test
    public void simpleRequestNoShortCurcuit() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").build();
        final HttpResponse response = simpleRequest(config, "http://localhost:7777");
        assertThat(response.getStatus(), is(OK));
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
    }

    /** A request without an {@code Origin} header is answered with 200 OK even when short-circuiting is on. */
    @Test
    public void shortCurcuitNonCorsRequest() {
        final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build();
        final HttpResponse response = simpleRequest(config, null);
        assertThat(response.getStatus(), is(OK));
        assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
    }

    /** After a preflight request has been written inbound, its reference count must have dropped to zero. */
    @Test
    public void preflightRequestShouldReleaseRequest() {
        final CorsConfig config = CorsConfig.withOrigin("http://localhost:8888")
                .preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
                .build();
        final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
        final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1");
        channel.writeInbound(request);
        assertThat(request.refCnt(), is(0));
    }

    /**
     * After a GET from a non-configured origin has been written inbound with short-circuiting
     * enabled, the request's reference count must have dropped to zero.
     */
    @Test
    public void forbiddenShouldReleaseRequest() {
        final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build();
        final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler());
        final FullHttpRequest request = createHttpRequest(GET);
        request.headers().set(ORIGIN, "http://localhost:8888");
        channel.writeInbound(request);
        assertThat(request.refCnt(), is(0));
    }

    /**
     * Sends a GET request without an {@code Access-Control-Request-Headers} header.
     *
     * @param config the CORS configuration for the handler under test
     * @param origin value of the {@code Origin} header, or {@code null} to omit the header
     * @return the first message read from the outbound side of the channel
     */
    private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
        return simpleRequest(config, origin, null);
    }

    /**
     * Sends a GET request.
     *
     * @param config         the CORS configuration for the handler under test
     * @param origin         value of the {@code Origin} header, or {@code null} to omit the header
     * @param requestHeaders value of the {@code Access-Control-Request-Headers} header, or
     *                       {@code null} to omit the header
     * @return the first message read from the outbound side of the channel
     */
    private static HttpResponse simpleRequest(final CorsConfig config,
                                              final String origin,
                                              final String requestHeaders) {
        return simpleRequest(config, origin, requestHeaders, GET);
    }

    /**
     * Writes one request through a channel made of a {@link CorsHandler} followed by an
     * {@link EchoHandler} and returns what comes back on the outbound side.
     *
     * @param config         the CORS configuration for the handler under test
     * @param origin         value of the {@code Origin} header, or {@code null} to omit the header
     * @param requestHeaders value of the {@code Access-Control-Request-Headers} header, or
     *                       {@code null} to omit the header
     * @param method         the HTTP method of the request
     * @return the first message read from the outbound side of the channel
     */
    private static HttpResponse simpleRequest(final CorsConfig config,
                                              final String origin,
                                              final String requestHeaders,
                                              final HttpMethod method) {
        final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler());
        final FullHttpRequest httpRequest = createHttpRequest(method);
        if (origin != null) {
            httpRequest.headers().set(ORIGIN, origin);
        }
        if (requestHeaders != null) {
            httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
        }
        channel.writeInbound(httpRequest);
        return (HttpResponse) channel.readOutbound();
    }

    /**
     * Writes an OPTIONS request through a channel containing only a {@link CorsHandler}
     * (no {@link EchoHandler}) and returns what comes back on the outbound side.
     *
     * @param config         the CORS configuration for the handler under test
     * @param origin         value of the {@code Origin} header
     * @param requestHeaders value of the {@code Access-Control-Request-Headers} header
     * @return the first message read from the outbound side of the channel
     */
    private static HttpResponse preflightRequest(final CorsConfig config,
                                                 final String origin,
                                                 final String requestHeaders) {
        final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
        channel.writeInbound(optionsRequest(origin, requestHeaders));
        return (HttpResponse) channel.readOutbound();
    }

    /**
     * Builds an OPTIONS request carrying the three headers these tests use for a preflight.
     *
     * @param origin         value of the {@code Origin} header
     * @param requestHeaders value of the {@code Access-Control-Request-Headers} header
     * @return the populated request
     */
    private static FullHttpRequest optionsRequest(final String origin, final String requestHeaders) {
        final FullHttpRequest httpRequest = createHttpRequest(OPTIONS);
        httpRequest.headers().set(ORIGIN, origin);
        // The requested method is simply the request's own method, i.e. OPTIONS.
        httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.getMethod().toString());
        httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
        return httpRequest;
    }

    /**
     * Creates an HTTP/1.1 request for the fixed URI {@code /info}.
     *
     * @param method the HTTP method of the request
     * @return a new request without any headers set by this method
     */
    private static FullHttpRequest createHttpRequest(HttpMethod method) {
        return new DefaultFullHttpRequest(HTTP_1_1, method, "/info");
    }

    /** Stand-in for an application handler: answers every inbound message with an empty 200 OK response. */
    private static class EchoHandler extends SimpleChannelInboundHandler<Object> {
        @Override
        public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
            ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK));
        }
    }
}