/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.yaml.YamlXContent;
import org.elasticsearch.http.HttpInfo;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpStats;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
public class RestControllerTests extends ESTestCase {
private static final ByteSizeValue BREAKER_LIMIT = new ByteSizeValue(20);
private CircuitBreaker inFlightRequestsBreaker;
private RestController restController;
private HierarchyCircuitBreakerService circuitBreakerService;
/**
 * Creates the fixtures shared by the tests: a circuit breaker service whose in-flight requests limit is
 * {@code BREAKER_LIMIT}, and a {@link RestController} with a {@code GET /} handler that answers 200 OK and a
 * {@code GET /error} handler that always throws.
 */
@Before
public void setup() {
    final Settings breakerSettings = Settings.builder()
        .put(HierarchyCircuitBreakerService.IN_FLIGHT_REQUESTS_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), BREAKER_LIMIT)
        .build();
    final ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
    circuitBreakerService = new HierarchyCircuitBreakerService(breakerSettings, clusterSettings);
    // caching the breaker instance is only safe because no test adjusts the breaker settings dynamically
    inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);

    final HttpServerTransport transport = new TestHttpServerTransport();
    restController = new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService);

    final RestHandler okHandler = (req, chan, nodeClient) ->
        chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
    restController.registerHandler(RestRequest.Method.GET, "/", okHandler);

    final RestHandler failingHandler = (req, chan, nodeClient) -> {
        throw new IllegalArgumentException("test error");
    };
    restController.registerHandler(RestRequest.Method.GET, "/error", failingHandler);

    // the transport is only started here; nothing else in this method uses it
    transport.start();
}
/**
 * Verifies that only the request headers the controller was configured to copy ({@code header.1} and {@code header.2})
 * are placed into the {@link ThreadContext}, while an unconfigured header ({@code header.3}) is not.
 *
 * @throws Exception if dispatching the request fails
 */
public void testApplyRelevantHeaders() throws Exception {
    final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
    Set<String> headers = new HashSet<>(Arrays.asList("header.1", "header.2"));
    final RestController restController = new RestController(Settings.EMPTY, headers, null, null, circuitBreakerService);
    Map<String, List<String>> restHeaders = new HashMap<>();
    restHeaders.put("header.1", Collections.singletonList("true"));
    restHeaders.put("header.2", Collections.singletonList("true"));
    restHeaders.put("header.3", Collections.singletonList("false"));
    // tracks handler invocation so that the assertions inside the handler cannot be skipped silently
    final AtomicBoolean handlerCalled = new AtomicBoolean(false);
    restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build(), null, null,
        threadContext, (RestRequest request, RestChannel channel, NodeClient client) -> {
            handlerCalled.set(true);
            assertEquals("true", threadContext.getHeader("header.1"));
            assertEquals("true", threadContext.getHeader("header.2"));
            assertNull(threadContext.getHeader("header.3"));
        });
    assertTrue(handlerCalled.get());
    // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the
    // context in this test
    assertEquals("true", threadContext.getHeader("header.1"));
    assertEquals("true", threadContext.getHeader("header.2"));
    assertNull(threadContext.getHeader("header.3"));
}
/**
 * Checks what {@code canTripCircuitBreaker} reports for a path whose handler allows tripping, for a path with no
 * registered handler, and for a path whose handler opts out of tripping.
 *
 * @throws Exception declared for parity with the other tests; nothing visible here throws a checked exception
 */
public void testCanTripCircuitBreaker() throws Exception {
    final RestController breakerAwareController =
        new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService);
    breakerAwareController.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
    breakerAwareController.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));

    // a handler that allows tripping
    final FakeRestRequest tripRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/trip").build();
    assertTrue(breakerAwareController.canTripCircuitBreaker(tripRequest));

    // unknown paths are assumed to be able to trip the breaker as well
    final FakeRestRequest unknownRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/unknown-path").build();
    assertTrue(breakerAwareController.canTripCircuitBreaker(unknownRequest));

    // a handler that explicitly opts out
    final FakeRestRequest doNotTripRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/do-not-trip").build();
    assertFalse(breakerAwareController.canTripCircuitBreaker(doNotTripRequest));
}
/**
 * Verifies that {@code registerAsDeprecatedHandler} registers the given method and path with a
 * {@link DeprecationRestHandler} rather than with the raw handler.
 */
public void testRegisterAsDeprecatedHandler() {
    final RestController mockedController = mock(RestController.class);

    final RestRequest.Method httpMethod = randomFrom(RestRequest.Method.values());
    final String route = "/_" + randomAlphaOfLengthBetween(1, 6);
    final RestHandler delegate = mock(RestHandler.class);
    final String message = randomAlphaOfLengthBetween(1, 10);
    final DeprecationLogger deprecationLogger = mock(DeprecationLogger.class);

    // only the method under test runs for real; the rest of the controller stays mocked so that we merely
    // check that the handler gets wrapped
    doCallRealMethod().when(mockedController).registerAsDeprecatedHandler(httpMethod, route, delegate, message, deprecationLogger);

    mockedController.registerAsDeprecatedHandler(httpMethod, route, delegate, message, deprecationLogger);

    verify(mockedController).registerHandler(eq(httpMethod), eq(route), any(DeprecationRestHandler.class));
}
/**
 * Verifies that {@code registerWithDeprecatedHandler} registers the handler under the current method and path, and
 * additionally registers it as a deprecated handler under the deprecated method and path with the expected message.
 */
public void testRegisterWithDeprecatedHandler() {
    final RestController mockedController = mock(RestController.class);

    final RestRequest.Method currentMethod = randomFrom(RestRequest.Method.values());
    final String currentPath = "/_" + randomAlphaOfLengthBetween(1, 6);
    final RestHandler delegate = mock(RestHandler.class);
    final RestRequest.Method oldMethod = randomFrom(RestRequest.Method.values());
    final String oldPath = "/_" + randomAlphaOfLengthBetween(1, 6);
    final DeprecationLogger deprecationLogger = mock(DeprecationLogger.class);

    final String deprecatedRoute = oldMethod.name() + " " + oldPath;
    final String replacementRoute = currentMethod.name() + " " + currentPath;
    final String expectedMessage = "[" + deprecatedRoute + "] is deprecated! Use [" + replacementRoute + "] instead.";

    // only the method under test runs for real; the rest of the controller stays mocked so that we merely
    // check that both registrations happen
    doCallRealMethod().when(mockedController)
        .registerWithDeprecatedHandler(currentMethod, currentPath, delegate, oldMethod, oldPath, deprecationLogger);

    mockedController.registerWithDeprecatedHandler(currentMethod, currentPath, delegate, oldMethod, oldPath, deprecationLogger);

    verify(mockedController).registerHandler(currentMethod, currentPath, delegate);
    verify(mockedController).registerAsDeprecatedHandler(oldMethod, oldPath, delegate, expectedMessage, deprecationLogger);
}
/**
 * Verifies that a handler wrapper passed to the controller receives the dispatched handler, and that the handler it
 * returns is the one invoked instead of the original.
 *
 * @throws Exception if dispatching the request fails
 */
public void testRestHandlerWrapper() throws Exception {
    final AtomicBoolean originalInvoked = new AtomicBoolean(false);
    final AtomicBoolean replacementInvoked = new AtomicBoolean(false);

    final RestHandler original = (request, channel, client) -> originalInvoked.set(true);
    final RestHandler replacement = (request, channel, client) -> replacementInvoked.set(true);
    final UnaryOperator<RestHandler> wrapper = candidate -> {
        assertSame(original, candidate);
        return replacement;
    };

    final RestController wrappingController =
        new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null, circuitBreakerService);
    final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
    final FakeRestRequest emptyRequest = new FakeRestRequest.Builder(xContentRegistry()).build();

    wrappingController.dispatchRequest(emptyRequest, null, null, threadContext, original);

    assertTrue(replacementInvoked.get());
    assertFalse(originalInvoked.get());
}
/**
 * No-op {@link RestHandler} whose only configurable behavior is the value it reports from
 * {@link #canTripCircuitBreaker()}.
 */
private static class FakeRestHandler implements RestHandler {

    /** Value returned from {@link #canTripCircuitBreaker()}. */
    private final boolean canTripCircuitBreaker;

    private FakeRestHandler(boolean tripsBreaker) {
        this.canTripCircuitBreaker = tripsBreaker;
    }

    @Override
    public boolean canTripCircuitBreaker() {
        return this.canTripCircuitBreaker;
    }

    @Override
    public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
        // intentionally does nothing
    }
}
/**
 * Dispatches a request whose body is exactly {@code BREAKER_LIMIT} bytes long to the {@code /} handler and checks that
 * an OK response is sent, the in-flight requests breaker did not trip, and no bytes remain reserved afterwards.
 */
public void testDispatchRequestAddsAndFreesBytesOnSuccess() {
    int contentLength = BREAKER_LIMIT.bytesAsInt();
    String content = randomAlphaOfLength(contentLength);
    TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON);
    AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK);
    restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
    // the expected status is only checked inside sendResponse, so make sure a response was actually sent
    assertTrue(channel.getSendResponseCalled());
    assertEquals(0, inFlightRequestsBreaker.getTrippedCount());
    assertEquals(0, inFlightRequestsBreaker.getUsed());
}
/**
 * Dispatches a request whose body is exactly {@code BREAKER_LIMIT} bytes long to the always-throwing {@code /error}
 * handler and checks that a BAD_REQUEST response is sent, the in-flight requests breaker did not trip, and no bytes
 * remain reserved afterwards.
 */
public void testDispatchRequestAddsAndFreesBytesOnError() {
    int contentLength = BREAKER_LIMIT.bytesAsInt();
    String content = randomAlphaOfLength(contentLength);
    TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON);
    AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST);
    restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
    // the expected status is only checked inside sendResponse, so make sure a response was actually sent
    assertTrue(channel.getSendResponseCalled());
    assertEquals(0, inFlightRequestsBreaker.getTrippedCount());
    assertEquals(0, inFlightRequestsBreaker.getUsed());
}
/**
 * Dispatches a {@code BREAKER_LIMIT}-sized request to the always-throwing {@code /error} handler over a channel that
 * also throws when a response is sent, then checks that the breaker did not trip and that no bytes remain reserved.
 */
public void testDispatchRequestAddsAndFreesBytesOnlyOnceOnError() {
    final String body = randomAlphaOfLength(BREAKER_LIMIT.bytesAsInt());
    // two failures are produced: one in the rest handler and one more when the error response is sent
    final TestRestRequest failingRequest = new TestRestRequest("/error", body, XContentType.JSON);
    final ExceptionThrowingChannel throwingChannel = new ExceptionThrowingChannel(failingRequest, true);

    restController.dispatchRequest(failingRequest, throwingChannel, new ThreadContext(Settings.EMPTY));

    assertEquals(0, inFlightRequestsBreaker.getTrippedCount());
    assertEquals(0, inFlightRequestsBreaker.getUsed());
}
/**
 * Dispatches a request whose body is one byte larger than {@code BREAKER_LIMIT} and checks that a
 * SERVICE_UNAVAILABLE response is sent, the in-flight requests breaker tripped exactly once, and no bytes remain
 * reserved afterwards.
 */
public void testDispatchRequestLimitsBytes() {
    int contentLength = BREAKER_LIMIT.bytesAsInt() + 1;
    String content = randomAlphaOfLength(contentLength);
    TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON);
    AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE);
    restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
    // the expected status is only checked inside sendResponse, so make sure a response was actually sent
    assertTrue(channel.getSendResponseCalled());
    assertEquals(1, inFlightRequestsBreaker.getTrippedCount());
    assertEquals(0, inFlightRequestsBreaker.getUsed());
}
/**
 * With the content-type-required setting enabled, a request that carries content but no {@code Content-Type} header
 * is expected to be answered with NOT_ACCEPTABLE.
 */
public void testDispatchRequiresContentTypeForRequestsWithContent() {
    final String body = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
    final TestRestRequest requestWithoutContentType = new TestRestRequest("/", body, null);
    final AssertingChannel assertingChannel = new AssertingChannel(requestWithoutContentType, true, RestStatus.NOT_ACCEPTABLE);

    // replace the shared controller with one that insists on a content type
    final Settings contentTypeRequired = Settings.builder()
        .put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true)
        .build();
    restController = new RestController(contentTypeRequired, Collections.emptySet(), null, null, circuitBreakerService);
    restController.registerHandler(RestRequest.Method.GET, "/", (req, chan, client) ->
        chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(requestWithoutContentType, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request built without content and without a {@code Content-Type} header is expected to reach the handler and be
 * answered with OK.
 */
public void testDispatchDoesNotRequireContentTypeForRequestsWithoutContent() {
    final FakeRestRequest emptyRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
    final AssertingChannel assertingChannel = new AssertingChannel(emptyRequest, true, RestStatus.OK);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(emptyRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request with content and a {@code text/plain} content type is expected to be answered with NOT_ACCEPTABLE, even
 * though a handler answering OK is registered for its path.
 */
public void testDispatchFailsWithPlainText() {
    final String body = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
    final Map<String, List<String>> plainTextHeader =
        Collections.singletonMap("Content-Type", Collections.singletonList("text/plain"));
    final FakeRestRequest plainTextRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray(body), null)
        .withPath("/foo")
        .withHeaders(plainTextHeader)
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(plainTextRequest, true, RestStatus.NOT_ACCEPTABLE);

    restController.registerHandler(RestRequest.Method.GET, "/foo", (req, chan, client) ->
        chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(plainTextRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request to {@code /} with a JSON-looking body but an {@code application/x-www-form-urlencoded} content type is
 * expected to be answered with NOT_ACCEPTABLE.
 */
public void testDispatchUnsupportedContentType() {
    final Map<String, List<String>> formEncodedHeader =
        Collections.singletonMap("Content-Type", Collections.singletonList("application/x-www-form-urlencoded"));
    final FakeRestRequest formEncodedRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray("{}"), null)
        .withPath("/")
        .withHeaders(formEncodedHeader)
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(formEncodedRequest, true, RestStatus.NOT_ACCEPTABLE);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(formEncodedRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request with an {@code application/x-ndjson} content type sent to a handler that declares content stream support
 * is expected to be answered with OK.
 */
public void testDispatchWorksWithNewlineDelimitedJson() {
    final String ndjsonMimeType = "application/x-ndjson";
    final String body = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
    final Map<String, List<String>> contentTypeHeader =
        Collections.singletonMap("Content-Type", Collections.singletonList(ndjsonMimeType));
    final FakeRestRequest ndjsonRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray(body), null)
        .withPath("/foo")
        .withHeaders(contentTypeHeader)
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(ndjsonRequest, true, RestStatus.OK);

    // handler that answers OK and advertises content stream support
    final RestHandler streamingHandler = new RestHandler() {
        @Override
        public boolean supportsContentStream() {
            return true;
        }

        @Override
        public void handleRequest(RestRequest req, RestChannel chan, NodeClient nodeClient) throws Exception {
            chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
        }
    };
    restController.registerHandler(RestRequest.Method.GET, "/foo", streamingHandler);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(ndjsonRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request with an {@code application/json} or {@code application/smile} content type (picked at random) sent to a
 * handler that declares content stream support is expected to be answered with OK.
 */
public void testDispatchWithContentStream() {
    final String streamableMimeType = randomFrom("application/json", "application/smile");
    final String body = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
    final Map<String, List<String>> contentTypeHeader =
        Collections.singletonMap("Content-Type", Collections.singletonList(streamableMimeType));
    final FakeRestRequest streamRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray(body), null)
        .withPath("/foo")
        .withHeaders(contentTypeHeader)
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(streamRequest, true, RestStatus.OK);

    // handler that answers OK and advertises content stream support
    final RestHandler streamingHandler = new RestHandler() {
        @Override
        public boolean supportsContentStream() {
            return true;
        }

        @Override
        public void handleRequest(RestRequest req, RestChannel chan, NodeClient nodeClient) throws Exception {
            chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
        }
    };
    restController.registerHandler(RestRequest.Method.GET, "/foo", streamingHandler);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(streamRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request with content but no {@code Content-Type} header sent to a handler that declares content stream support
 * is expected to be answered with NOT_ACCEPTABLE.
 */
public void testDispatchWithContentStreamNoContentType() {
    final FakeRestRequest untypedRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray("{}"), null)
        .withPath("/foo")
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(untypedRequest, true, RestStatus.NOT_ACCEPTABLE);

    // handler that answers OK and advertises content stream support
    final RestHandler streamingHandler = new RestHandler() {
        @Override
        public boolean supportsContentStream() {
            return true;
        }

        @Override
        public void handleRequest(RestRequest req, RestChannel chan, NodeClient nodeClient) throws Exception {
            chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
        }
    };
    restController.registerHandler(RestRequest.Method.GET, "/foo", streamingHandler);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(untypedRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request carrying YAML content sent to a handler that declares content stream support is expected to be answered
 * with NOT_ACCEPTABLE.
 *
 * @throws IOException if building the YAML request body fails
 */
public void testNonStreamingXContentCausesErrorResponse() throws IOException {
    final BytesReference yamlBody = YamlXContent.contentBuilder().startObject().endObject().bytes();
    final FakeRestRequest yamlRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(yamlBody, XContentType.YAML)
        .withPath("/foo")
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(yamlRequest, true, RestStatus.NOT_ACCEPTABLE);

    // handler that answers OK and advertises content stream support
    final RestHandler streamingHandler = new RestHandler() {
        @Override
        public boolean supportsContentStream() {
            return true;
        }

        @Override
        public void handleRequest(RestRequest req, RestChannel chan, NodeClient nodeClient) throws Exception {
            chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
        }
    };
    restController.registerHandler(RestRequest.Method.GET, "/foo", streamingHandler);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(yamlRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * A request with an unrecognized {@code foo/bar} content type sent to a handler that declares content stream support
 * is expected to be answered with NOT_ACCEPTABLE.
 */
public void testUnknownContentWithContentStream() {
    final Map<String, List<String>> unknownTypeHeader =
        Collections.singletonMap("Content-Type", Collections.singletonList("foo/bar"));
    final FakeRestRequest unknownTypeRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
        .withContent(new BytesArray("aaaabbbbb"), null)
        .withPath("/foo")
        .withHeaders(unknownTypeHeader)
        .build();
    final AssertingChannel assertingChannel = new AssertingChannel(unknownTypeRequest, true, RestStatus.NOT_ACCEPTABLE);

    // handler that answers OK and advertises content stream support
    final RestHandler streamingHandler = new RestHandler() {
        @Override
        public boolean supportsContentStream() {
            return true;
        }

        @Override
        public void handleRequest(RestRequest req, RestChannel chan, NodeClient nodeClient) throws Exception {
            chan.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
        }
    };
    restController.registerHandler(RestRequest.Method.GET, "/foo", streamingHandler);

    assertFalse(assertingChannel.getSendResponseCalled());
    restController.dispatchRequest(unknownTypeRequest, assertingChannel, new ThreadContext(Settings.EMPTY));
    assertTrue(assertingChannel.getSendResponseCalled());
}
/**
 * Checks that {@code dispatchBadRequest} sends a BAD_REQUEST response whose body contains the message of the supplied
 * cause, whether that cause is an {@link IllegalStateException} or a plain {@link Throwable}.
 */
public void testDispatchBadRequest() {
    final FakeRestRequest emptyRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
    final AssertingChannel assertingChannel = new AssertingChannel(emptyRequest, true, RestStatus.BAD_REQUEST);
    final Throwable cause = randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request");

    restController.dispatchBadRequest(emptyRequest, assertingChannel, new ThreadContext(Settings.EMPTY), cause);

    assertTrue(assertingChannel.getSendResponseCalled());
    final String responseBody = assertingChannel.getRestResponse().content().utf8ToString();
    assertThat(responseBody, containsString("bad request"));
}
/**
 * Checks that {@code dispatchBadRequest} called with a {@code null} cause sends a BAD_REQUEST response whose body
 * contains the text "unknown cause".
 */
public void testDispatchBadRequestUnknownCause() {
    final FakeRestRequest emptyRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
    final AssertingChannel assertingChannel = new AssertingChannel(emptyRequest, true, RestStatus.BAD_REQUEST);

    restController.dispatchBadRequest(emptyRequest, assertingChannel, new ThreadContext(Settings.EMPTY), null);

    assertTrue(assertingChannel.getSendResponseCalled());
    final String responseBody = assertingChannel.getRestResponse().content().utf8ToString();
    assertThat(responseBody, containsString("unknown cause"));
}
/**
 * Stub {@link HttpServerTransport} with empty lifecycle hooks. It reports a freshly built fake transport address as
 * its bound address and returns {@code null} for both info and stats.
 */
private static final class TestHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport {

    TestHttpServerTransport() {
        super(Settings.EMPTY);
    }

    @Override
    protected void doStart() {
        // nothing to start
    }

    @Override
    protected void doStop() {
        // nothing to stop
    }

    @Override
    protected void doClose() {
        // nothing to close
    }

    @Override
    public BoundTransportAddress boundAddress() {
        // the same fake address serves as the only bound address and as the publish address
        final TransportAddress fakeAddress = buildNewFakeTransportAddress();
        final TransportAddress[] boundAddresses = new TransportAddress[] { fakeAddress };
        return new BoundTransportAddress(boundAddresses, fakeAddress);
    }

    @Override
    public HttpInfo info() {
        return null;
    }

    @Override
    public HttpStats stats() {
        return null;
    }
}
/**
 * Channel that asserts every response sent through it carries the expected status and remembers the most recent
 * response so that tests can inspect it.
 */
private static final class AssertingChannel extends AbstractRestChannel {

    /** Status every response sent through this channel must have. */
    private final RestStatus expectedStatus;

    /** Most recent response passed to {@link #sendResponse}; holds {@code null} until the first one arrives. */
    private final AtomicReference<RestResponse> lastResponse = new AtomicReference<>();

    protected AssertingChannel(RestRequest request, boolean detailedErrorsEnabled, RestStatus expectedStatus) {
        super(request, detailedErrorsEnabled);
        this.expectedStatus = expectedStatus;
    }

    @Override
    public void sendResponse(RestResponse response) {
        assertEquals(expectedStatus, response.status());
        lastResponse.set(response);
    }

    /** Returns the most recent response, or {@code null} if none has been sent. */
    RestResponse getRestResponse() {
        return lastResponse.get();
    }

    /** Returns whether a response has been sent through this channel. */
    boolean getSendResponseCalled() {
        return lastResponse.get() != null;
    }
}
/**
 * Channel whose {@link #sendResponse} always fails with an {@link IllegalStateException}; used to exercise the case
 * where sending the response itself throws.
 */
private static final class ExceptionThrowingChannel extends AbstractRestChannel {

    protected ExceptionThrowingChannel(RestRequest request, boolean detailedErrorsEnabled) {
        super(request, detailedErrorsEnabled);
    }

    @Override
    public void sendResponse(RestResponse ignoredResponse) {
        throw new IllegalStateException("always throwing an exception for testing");
    }
}
/**
 * GET request with a fixed path and string body that always reports having content. When an {@link XContentType} is
 * given, its media type is sent as the {@code Content-Type} header; otherwise no headers are set.
 */
private static final class TestRestRequest extends RestRequest {

    /** Request body, wrapped once at construction time. */
    private final BytesReference content;

    private TestRestRequest(String path, String content, XContentType xContentType) {
        super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, headersFor(xContentType));
        this.content = new BytesArray(content);
    }

    /** Builds the header map: empty for a {@code null} type, otherwise a single {@code Content-Type} entry. */
    private static Map<String, List<String>> headersFor(XContentType xContentType) {
        if (xContentType == null) {
            return Collections.emptyMap();
        }
        return Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType()));
    }

    @Override
    public Method method() {
        return Method.GET;
    }

    @Override
    public String uri() {
        return null;
    }

    @Override
    public boolean hasContent() {
        return true;
    }

    @Override
    public BytesReference content() {
        return this.content;
    }
}
}