/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.action.support; import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import java.util.ArrayList; import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.notNullValue; public class TransportActionFilterChainTests extends ESTestCase { private AtomicInteger counter; @Before public void init() throws Exception { counter = new AtomicInteger(); } public void testActionFiltersRequest() throws ExecutionException, InterruptedException { int numFilters = randomInt(10); Set<Integer> orders = new HashSet<>(numFilters); while (orders.size() < numFilters) { orders.add(randomInt(10)); } Set<ActionFilter> filters = new HashSet<>(); for (Integer order : orders) { filters.add(new RequestTestFilter(order, randomFrom(RequestOperation.values()))); } String actionName = randomAlphaOfLength(randomInt(30)); ActionFilters actionFilters = new ActionFilters(filters); TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(Settings.EMPTY, actionName, null, actionFilters, null, new TaskManager(Settings.EMPTY)) { @Override protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) { listener.onResponse(new TestResponse()); } }; ArrayList<ActionFilter> actionFiltersByOrder = new ArrayList<>(filters); actionFiltersByOrder.sort(Comparator.comparingInt(ActionFilter::order)); List<ActionFilter> expectedActionFilters = new ArrayList<>(); boolean errorExpected = false; for (ActionFilter filter : actionFiltersByOrder) { RequestTestFilter testFilter = (RequestTestFilter) filter; expectedActionFilters.add(testFilter); if (testFilter.callback == RequestOperation.LISTENER_FAILURE) { errorExpected = true; } if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) { break; } } PlainActionFuture<TestResponse> future = PlainActionFuture.newFuture(); transportAction.execute(new TestRequest(), future); try { assertThat(future.get(), notNullValue()); assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false)); } catch (ExecutionException e) { assertThat("shouldn't get here if an error is not expected " + e.getMessage(), errorExpected, equalTo(true)); } List<RequestTestFilter> testFiltersByLastExecution = new ArrayList<>(); for (ActionFilter actionFilter : actionFilters.filters()) { testFiltersByLastExecution.add((RequestTestFilter) actionFilter); } testFiltersByLastExecution.sort(Comparator.comparingInt(o -> o.executionToken)); ArrayList<RequestTestFilter> finalTestFilters = new ArrayList<>(); for (ActionFilter filter : testFiltersByLastExecution) { RequestTestFilter testFilter = (RequestTestFilter) filter; finalTestFilters.add(testFilter); if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) { break; } } assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size())); for (int i = 0; i < finalTestFilters.size(); i++) { RequestTestFilter testFilter = finalTestFilters.get(i); assertThat(testFilter, equalTo(expectedActionFilters.get(i))); assertThat(testFilter.runs.get(), equalTo(1)); assertThat(testFilter.lastActionName, equalTo(actionName)); } } public void testTooManyContinueProcessingRequest() throws ExecutionException, InterruptedException { final int additionalContinueCount = randomInt(10); RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() { @Override public <Request extends ActionRequest, Response extends ActionResponse> void execute(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) { for (int i = 0; i <= additionalContinueCount; i++) { actionFilterChain.proceed(task, action, request, listener); } } }); Set<ActionFilter> filters = new HashSet<>(); filters.add(testFilter); String actionName = randomAlphaOfLength(randomInt(30)); ActionFilters actionFilters = new ActionFilters(filters); TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(Settings.EMPTY, actionName, null, actionFilters, null, new TaskManager(Settings.EMPTY)) { @Override protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) { listener.onResponse(new TestResponse()); } }; final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1); final AtomicInteger responses = new AtomicInteger(); final List<Throwable> failures = new CopyOnWriteArrayList<>(); transportAction.execute(new TestRequest(), new ActionListener<TestResponse>() { @Override public void onResponse(TestResponse testResponse) { responses.incrementAndGet(); latch.countDown(); } @Override public void onFailure(Exception e) { failures.add(e); latch.countDown(); } }); if (!latch.await(10, TimeUnit.SECONDS)) { fail("timeout waiting for the filter to notify the listener as many times as expected"); } assertThat(testFilter.runs.get(), equalTo(1)); assertThat(testFilter.lastActionName, equalTo(actionName)); assertThat(responses.get(), equalTo(1)); assertThat(failures.size(), equalTo(additionalContinueCount)); for (Throwable failure : failures) { assertThat(failure, instanceOf(IllegalStateException.class)); } } private class RequestTestFilter implements ActionFilter { private final RequestCallback callback; private final int order; AtomicInteger runs = new AtomicInteger(); volatile String lastActionName; volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list RequestTestFilter(int order, RequestCallback callback) { this.order = order; this.callback = callback; } @Override public int order() { return order; } @Override public <Request extends ActionRequest, Response extends ActionResponse> void apply(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> chain) { this.runs.incrementAndGet(); this.lastActionName = action; this.executionToken = counter.incrementAndGet(); this.callback.execute(task, action, request, listener, chain); } } private enum RequestOperation implements RequestCallback { CONTINUE_PROCESSING { @Override public <Request extends ActionRequest, Response extends ActionResponse> void execute(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) { actionFilterChain.proceed(task, action, request, listener); } }, LISTENER_RESPONSE { @Override @SuppressWarnings("unchecked") // Safe because its all we test with public <Request extends ActionRequest, Response extends ActionResponse> void execute(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) { ((ActionListener<TestResponse>) listener).onResponse(new TestResponse()); } }, LISTENER_FAILURE { @Override public <Request extends ActionRequest, Response extends ActionResponse> void execute(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) { listener.onFailure(new ElasticsearchTimeoutException("")); } } } private interface RequestCallback { <Request extends ActionRequest, Response extends ActionResponse> void execute(Task task, String action, Request request, ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain); } public static class TestRequest extends ActionRequest { @Override public ActionRequestValidationException validate() { return null; } } private static class TestResponse extends ActionResponse { } }