/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.rest; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.FakeRestChannel; import org.elasticsearch.test.rest.FakeRestRequest; import org.junit.Test; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.CoreMatchers.equalTo; public class RestFilterChainTests extends ESTestCase { @Test public void testRestFilters() throws Exception { RestController restController = new RestController(Settings.EMPTY); int numFilters = randomInt(10); Set<Integer> orders = new HashSet<>(numFilters); while (orders.size() < numFilters) { orders.add(randomInt(10)); } List<RestFilter> filters = new ArrayList<>(); for (Integer order : orders) { TestFilter testFilter = new TestFilter(order, randomFrom(Operation.values())); filters.add(testFilter); restController.registerFilter(testFilter); } ArrayList<RestFilter> restFiltersByOrder = new ArrayList<>(filters); Collections.sort(restFiltersByOrder, new Comparator<RestFilter>() { @Override public int compare(RestFilter o1, RestFilter o2) { return Integer.compare(o1.order(), o2.order()); } }); List<RestFilter> expectedRestFilters = new ArrayList<>(); for (RestFilter filter : restFiltersByOrder) { TestFilter testFilter = (TestFilter) filter; expectedRestFilters.add(testFilter); if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { break; } } restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() { @Override public void handleRequest(RestRequest request, RestChannel channel) throws Exception { channel.sendResponse(new TestResponse()); } @Override public boolean canTripCircuitBreaker() { return true; } }); FakeRestRequest fakeRestRequest = new FakeRestRequest(); FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, randomBoolean(), 1); restController.dispatchRequest(fakeRestRequest, fakeRestChannel); assertThat(fakeRestChannel.await(), equalTo(true)); List<TestFilter> testFiltersByLastExecution = new ArrayList<>(); for (RestFilter restFilter : filters) { testFiltersByLastExecution.add((TestFilter)restFilter); } Collections.sort(testFiltersByLastExecution, new Comparator<TestFilter>() { @Override public int compare(TestFilter o1, TestFilter o2) { return Long.compare(o1.executionToken, o2.executionToken); } }); ArrayList<TestFilter> finalTestFilters = new ArrayList<>(); for (RestFilter filter : testFiltersByLastExecution) { TestFilter testFilter = (TestFilter) filter; finalTestFilters.add(testFilter); if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { break; } } assertThat(finalTestFilters.size(), equalTo(expectedRestFilters.size())); for (int i = 0; i < finalTestFilters.size(); i++) { TestFilter testFilter = finalTestFilters.get(i); assertThat(testFilter, equalTo(expectedRestFilters.get(i))); assertThat(testFilter.runs.get(), equalTo(1)); } } @Test public void testTooManyContinueProcessing() throws Exception { final int additionalContinueCount = randomInt(10); TestFilter testFilter = new TestFilter(randomInt(), new Callback() { @Override public void execute(final RestRequest request, final RestChannel channel, final RestFilterChain filterChain) throws Exception { for (int i = 0; i <= additionalContinueCount; i++) { filterChain.continueProcessing(request, channel); } } }); RestController restController = new RestController(Settings.EMPTY); restController.registerFilter(testFilter); restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() { @Override public void handleRequest(RestRequest request, RestChannel channel) throws Exception { channel.sendResponse(new TestResponse()); } @Override public boolean canTripCircuitBreaker() { return true; } }); FakeRestRequest fakeRestRequest = new FakeRestRequest(); FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, randomBoolean(), additionalContinueCount + 1); restController.dispatchRequest(fakeRestRequest, fakeRestChannel); fakeRestChannel.await(); assertThat(testFilter.runs.get(), equalTo(1)); assertThat(fakeRestChannel.responses().get(), equalTo(1)); assertThat(fakeRestChannel.errors().get(), equalTo(additionalContinueCount)); } private static enum Operation implements Callback { CONTINUE_PROCESSING { @Override public void execute(RestRequest request, RestChannel channel, RestFilterChain filterChain) throws Exception { filterChain.continueProcessing(request, channel); } }, CHANNEL_RESPONSE { @Override public void execute(RestRequest request, RestChannel channel, RestFilterChain filterChain) throws Exception { channel.sendResponse(new TestResponse()); } } } private static interface Callback { void execute(RestRequest request, RestChannel channel, RestFilterChain filterChain) throws Exception; } private final AtomicInteger counter = new AtomicInteger(); private class TestFilter extends RestFilter { private final int order; private final Callback callback; AtomicInteger runs = new AtomicInteger(); volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list TestFilter(int order, Callback callback) { this.order = order; this.callback = callback; } @Override public void process(RestRequest request, RestChannel channel, RestFilterChain filterChain) throws Exception { this.runs.incrementAndGet(); this.executionToken = counter.incrementAndGet(); this.callback.execute(request, channel, filterChain); } @Override public int order() { return order; } @Override public String toString() { return "[order:" + order + ", executionToken:" + executionToken + "]"; } } private static class TestResponse extends RestResponse { @Override public String contentType() { return null; } @Override public BytesReference content() { return null; } @Override public RestStatus status() { return RestStatus.OK; } } }