/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.http.netty.cors.CorsHandler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.cache.recycler.MockPageCacheRecycler;
import org.elasticsearch.threadpool.ThreadPool;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.*;
import org.jboss.netty.handler.codec.http.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import static org.elasticsearch.http.netty.NettyHttpServerTransport.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.netty.NettyHttpServerTransport.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.netty.NettyHttpServerTransport.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.netty.NettyHttpServerTransport.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
/**
 * Tests for {@link NettyHttpChannel}: checks which CORS response headers are produced for
 * various {@code SETTING_CORS_*} configurations, and that custom headers, content length and
 * content type of a {@link RestResponse} are copied onto the Netty {@link HttpResponse}.
 */
public class NettyHttpChannelTests extends ESTestCase {

    private NetworkService networkService;
    private ThreadPool threadPool;
    private MockBigArrays bigArrays;
    // Most recently created transport. createTransport() closes the previous one before
    // replacing it, and shutdown() closes whatever is left, so no transport is leaked.
    private NettyHttpServerTransport httpServerTransport;

    @Before
    public void setup() throws Exception {
        networkService = new NetworkService(Settings.EMPTY);
        threadPool = new ThreadPool("test");
        MockPageCacheRecycler mockPageCacheRecycler = new MockPageCacheRecycler(Settings.EMPTY, threadPool);
        bigArrays = new MockBigArrays(mockPageCacheRecycler, new NoneCircuitBreakerService());
    }

    @After
    public void shutdown() throws Exception {
        // Release in reverse order of creation (transport after thread pool), and make sure the
        // thread pool is stopped even if closing the transport fails.
        try {
            closeTransport();
        } finally {
            if (threadPool != null) {
                threadPool.shutdownNow();
                threadPool = null;
            }
        }
    }

    @Test
    public void testCorsEnabledWithoutAllowOrigins() {
        // Set up a HTTP transport with only the CORS enabled setting
        Settings settings = Settings.builder()
                .put(SETTING_CORS_ENABLED, true)
                .build();
        HttpResponse response = execRequestWithCors(settings, "remote-host", "request-host");
        // origin differs from the host and no allow-origin is configured: no header expected
        assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
    }

    @Test
    public void testCorsEnabledWithAllowOrigins() {
        final String originValue = "remote-host";
        // create a http transport with CORS enabled and allow origin configured
        Settings settings = Settings.builder()
                .put(SETTING_CORS_ENABLED, true)
                .put(SETTING_CORS_ALLOW_ORIGIN, originValue)
                .build();
        HttpResponse response = execRequestWithCors(settings, originValue, "request-host");
        assertAllowOrigin(response, originValue);
    }

    @Test
    public void testCorsAllowOriginWithSameHost() {
        String originValue = "remote-host";
        String host = "remote-host";
        // create a http transport with CORS enabled but no explicit allow origin
        Settings settings = Settings.builder()
                .put(SETTING_CORS_ENABLED, true)
                .build();

        // bare host name
        HttpResponse response = execRequestWithCors(settings, originValue, host);
        assertAllowOrigin(response, originValue);

        // origin with an http scheme
        originValue = "http://" + originValue;
        response = execRequestWithCors(settings, originValue, host);
        assertAllowOrigin(response, originValue);

        // origin and host both carrying the same port
        originValue = originValue + ":5555";
        host = host + ":5555";
        response = execRequestWithCors(settings, originValue, host);
        assertAllowOrigin(response, originValue);

        // origin with an https scheme
        originValue = originValue.replace("http", "https");
        response = execRequestWithCors(settings, originValue, host);
        assertAllowOrigin(response, originValue);
    }

    @Test
    public void testThatStringLiteralWorksOnMatch() {
        final String originValue = "remote-host";
        Settings settings = Settings.builder()
                .put(SETTING_CORS_ENABLED, true)
                .put(SETTING_CORS_ALLOW_ORIGIN, originValue)
                .put(SETTING_CORS_ALLOW_METHODS, "get, options, post")
                .put(SETTING_CORS_ALLOW_CREDENTIALS, true)
                .build();
        HttpResponse response = execRequestWithCors(settings, originValue, "request-host");
        assertAllowOrigin(response, originValue);
        assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
    }

    @Test
    public void testThatAnyOriginWorks() {
        final String originValue = CorsHandler.ANY_ORIGIN;
        Settings settings = Settings.builder()
                .put(SETTING_CORS_ENABLED, true)
                .put(SETTING_CORS_ALLOW_ORIGIN, originValue)
                .build();
        HttpResponse response = execRequestWithCors(settings, originValue, "request-host");
        assertAllowOrigin(response, originValue);
        assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
    }

    @Test
    public void testHeadersSet() {
        Settings settings = Settings.builder().build();
        NettyHttpServerTransport transport = createTransport(settings);
        HttpRequest httpRequest = new TestHttpRequest();
        httpRequest.headers().add(HttpHeaders.Names.ORIGIN, "remote");
        WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
        NettyHttpRequest request = new NettyHttpRequest(httpRequest, writeCapturingChannel);

        // send a response
        NettyHttpChannel channel = new NettyHttpChannel(transport, request, null, randomBoolean());
        TestResponse resp = new TestResponse();
        final String customHeader = "custom-header";
        final String customHeaderValue = "xyz";
        resp.addHeader(customHeader, customHeaderValue);
        channel.sendResponse(resp);

        // inspect what was written
        List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
        assertThat(writtenObjects.size(), is(1));
        HttpResponse response = (HttpResponse) writtenObjects.get(0);
        assertThat(response.headers().get("non-existent-header"), nullValue());
        assertThat(response.headers().get(customHeader), equalTo(customHeaderValue));
        assertThat(response.headers().get(HttpHeaders.Names.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length())));
        assertThat(response.headers().get(HttpHeaders.Names.CONTENT_TYPE), equalTo(resp.contentType()));
    }

    /**
     * Sends a {@link TestResponse} for a request carrying the given Origin and Host headers
     * through a fresh transport built from {@code settings}.
     *
     * @return the single {@link HttpResponse} written to the channel
     */
    private HttpResponse execRequestWithCors(final Settings settings, final String originValue, final String host) {
        // construct request and send it over the transport layer
        NettyHttpServerTransport transport = createTransport(settings);
        HttpRequest httpRequest = new TestHttpRequest();
        httpRequest.headers().add(HttpHeaders.Names.ORIGIN, originValue);
        httpRequest.headers().add(HttpHeaders.Names.HOST, host);
        WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
        NettyHttpRequest request = new NettyHttpRequest(httpRequest, writeCapturingChannel);
        NettyHttpChannel channel = new NettyHttpChannel(transport, request, null, randomBoolean());
        channel.sendResponse(new TestResponse());

        // get the response
        List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
        assertThat(writtenObjects.size(), is(1));
        return (HttpResponse) writtenObjects.get(0);
    }

    /**
     * Creates a transport for {@code settings}, first closing any transport created earlier in
     * the same test so that repeated calls do not leak instances.
     */
    private NettyHttpServerTransport createTransport(Settings settings) {
        closeTransport();
        httpServerTransport = new NettyHttpServerTransport(settings, networkService, bigArrays);
        return httpServerTransport;
    }

    /** Closes the current transport, if any, and clears the field. */
    private void closeTransport() {
        if (httpServerTransport != null) {
            try {
                httpServerTransport.close();
            } finally {
                httpServerTransport = null;
            }
        }
    }

    /** Asserts that the response has an Access-Control-Allow-Origin header equal to {@code expectedOrigin}. */
    private void assertAllowOrigin(HttpResponse response, String expectedOrigin) {
        String allowedOrigin = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN);
        assertThat(allowedOrigin, notNullValue());
        assertThat(allowedOrigin, is(expectedOrigin));
    }

    /**
     * Minimal {@link Channel} stub that records every object passed to {@code write} and
     * otherwise returns null/false/0 defaults.
     */
    private static class WriteCapturingChannel implements Channel {

        private final List<Object> writtenObjects = new ArrayList<>();

        @Override
        public Integer getId() {
            return null;
        }

        @Override
        public ChannelFactory getFactory() {
            return null;
        }

        @Override
        public Channel getParent() {
            return null;
        }

        @Override
        public ChannelConfig getConfig() {
            return null;
        }

        @Override
        public ChannelPipeline getPipeline() {
            return null;
        }

        @Override
        public boolean isOpen() {
            return false;
        }

        @Override
        public boolean isBound() {
            return false;
        }

        @Override
        public boolean isConnected() {
            return false;
        }

        @Override
        public SocketAddress getLocalAddress() {
            return null;
        }

        @Override
        public SocketAddress getRemoteAddress() {
            return null;
        }

        @Override
        public ChannelFuture write(Object message) {
            writtenObjects.add(message);
            return null;
        }

        @Override
        public ChannelFuture write(Object message, SocketAddress remoteAddress) {
            writtenObjects.add(message);
            return null;
        }

        @Override
        public ChannelFuture bind(SocketAddress localAddress) {
            return null;
        }

        @Override
        public ChannelFuture connect(SocketAddress remoteAddress) {
            return null;
        }

        @Override
        public ChannelFuture disconnect() {
            return null;
        }

        @Override
        public ChannelFuture unbind() {
            return null;
        }

        @Override
        public ChannelFuture close() {
            return null;
        }

        @Override
        public ChannelFuture getCloseFuture() {
            return null;
        }

        @Override
        public int getInterestOps() {
            return 0;
        }

        @Override
        public boolean isReadable() {
            return false;
        }

        @Override
        public boolean isWritable() {
            return false;
        }

        @Override
        public ChannelFuture setInterestOps(int interestOps) {
            return null;
        }

        @Override
        public ChannelFuture setReadable(boolean readable) {
            return null;
        }

        @Override
        public boolean getUserDefinedWritability(int index) {
            return false;
        }

        @Override
        public void setUserDefinedWritability(int index, boolean isWritable) {
        }

        @Override
        public Object getAttachment() {
            return null;
        }

        @Override
        public void setAttachment(Object attachment) {
        }

        @Override
        public int compareTo(Channel o) {
            return 0;
        }

        /** Returns the live list of objects written so far, in write order. */
        public List<Object> getWrittenObjects() {
            return writtenObjects;
        }
    }

    /**
     * Minimal HTTP/1.1 {@link HttpRequest} stub with mutable headers and content, an empty URI
     * and a null method.
     */
    private static class TestHttpRequest implements HttpRequest {

        private final HttpHeaders headers = new DefaultHttpHeaders();
        private ChannelBuffer content = ChannelBuffers.EMPTY_BUFFER;

        @Override
        public HttpMethod getMethod() {
            return null;
        }

        @Override
        public void setMethod(HttpMethod method) {
        }

        @Override
        public String getUri() {
            return "";
        }

        @Override
        public void setUri(String uri) {
        }

        @Override
        public HttpVersion getProtocolVersion() {
            return HttpVersion.HTTP_1_1;
        }

        @Override
        public void setProtocolVersion(HttpVersion version) {
        }

        @Override
        public HttpHeaders headers() {
            return headers;
        }

        @Override
        public ChannelBuffer getContent() {
            return content;
        }

        @Override
        public void setContent(ChannelBuffer content) {
            this.content = content;
        }

        @Override
        public boolean isChunked() {
            return false;
        }

        @Override
        public void setChunked(boolean chunked) {
        }
    }

    /** An empty {@code 200 OK} {@link RestResponse} with content type {@code "text"}. */
    private static class TestResponse extends RestResponse {

        @Override
        public String contentType() {
            return "text";
        }

        @Override
        public BytesReference content() {
            return BytesArray.EMPTY;
        }

        @Override
        public RestStatus status() {
            return RestStatus.OK;
        }
    }
}