/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.action.support;

import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.hamcrest.CoreMatchers.*;

/**
 * Tests the interaction between a {@link TransportAction} and the {@link ActionFilter}s
 * registered through {@link ActionFilters}.
 * <p>
 * The tests verify that request-side filters are invoked in ascending {@link ActionFilter#order()},
 * that response-side filters are invoked in descending order, that a filter which does not
 * continue the chain stops further filters from running, and that calling {@code proceed} on a
 * chain more than once per filter is reported to the listener as an {@link IllegalStateException}.
 */
public class TransportActionFilterChainTests extends ESTestCase {

    /** Sorts filters by ascending {@link ActionFilter#order()}. */
    private static final Comparator<ActionFilter> BY_ORDER = new Comparator<ActionFilter>() {
        @Override
        public int compare(ActionFilter o1, ActionFilter o2) {
            return Integer.compare(o1.order(), o2.order());
        }
    };

    /** Shared across the filters of one test; hands out increasing tokens that record execution order. */
    private AtomicInteger counter;

    @Before
    public void init() throws Exception {
        counter = new AtomicInteger();
    }

    /**
     * Registers a random set of request filters with random behaviours and checks that they run
     * in ascending order until the first one that does not continue the chain.
     */
    @Test
    public void testActionFiltersRequest() throws ExecutionException, InterruptedException {
        Set<ActionFilter> filters = new HashSet<>();
        for (Integer order : randomFilterOrders()) {
            filters.add(new RequestTestFilter(order, randomFrom(RequestOperation.values())));
        }

        String actionName = randomAsciiOfLength(randomInt(30));
        ActionFilters actionFilters = new ActionFilters(filters);
        TransportAction<TestRequest, TestResponse> transportAction = newTransportAction(actionName, actionFilters);

        List<ActionFilter> actionFiltersByOrder = new ArrayList<>(filters);
        Collections.sort(actionFiltersByOrder, BY_ORDER);

        // Walk the filters in the order they are expected to run; the chain stops at the
        // first filter that answers the listener itself instead of proceeding.
        List<ActionFilter> expectedActionFilters = new ArrayList<>();
        boolean errorExpected = false;
        for (ActionFilter filter : actionFiltersByOrder) {
            RequestTestFilter testFilter = (RequestTestFilter) filter;
            expectedActionFilters.add(testFilter);
            if (testFilter.callback == RequestOperation.LISTENER_FAILURE) {
                errorExpected = true;
            }
            if (testFilter.callback != RequestOperation.CONTINUE_PROCESSING) {
                break;
            }
        }

        PlainListenableActionFuture<TestResponse> future = new PlainListenableActionFuture<>(null);
        transportAction.execute(new TestRequest(), future);
        try {
            assertThat(future.get(), notNullValue());
            assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false));
        } catch (ExecutionException e) {
            // Only the failure reported through the future is handled here. Catching Throwable
            // would also swallow the AssertionError thrown by the assertions above and turn
            // the "error is expected" check into a no-op.
            assertThat("shouldn't get here if an error is not expected " + e.getMessage(), errorExpected, equalTo(true));
        }

        List<RequestTestFilter> testFiltersByLastExecution = new ArrayList<>();
        for (ActionFilter actionFilter : actionFilters.filters()) {
            testFiltersByLastExecution.add((RequestTestFilter) actionFilter);
        }
        Collections.sort(testFiltersByLastExecution, new Comparator<RequestTestFilter>() {
            @Override
            public int compare(RequestTestFilter o1, RequestTestFilter o2) {
                return Integer.compare(o1.executionToken, o2.executionToken);
            }
        });

        List<RequestTestFilter> finalTestFilters = new ArrayList<>();
        for (RequestTestFilter testFilter : testFiltersByLastExecution) {
            finalTestFilters.add(testFilter);
            if (testFilter.callback != RequestOperation.CONTINUE_PROCESSING) {
                break;
            }
        }

        assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
        for (int i = 0; i < finalTestFilters.size(); i++) {
            RequestTestFilter testFilter = finalTestFilters.get(i);
            assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
            assertThat(testFilter.runs.get(), equalTo(1));
            assertThat(testFilter.lastActionName, equalTo(actionName));
        }
    }

    /**
     * Registers a random set of response filters with random behaviours and checks that they run
     * in descending order until the first one that does not continue the chain.
     */
    @Test
    public void testActionFiltersResponse() throws ExecutionException, InterruptedException {
        Set<ActionFilter> filters = new HashSet<>();
        for (Integer order : randomFilterOrders()) {
            filters.add(new ResponseTestFilter(order, randomFrom(ResponseOperation.values())));
        }

        String actionName = randomAsciiOfLength(randomInt(30));
        ActionFilters actionFilters = new ActionFilters(filters);
        TransportAction<TestRequest, TestResponse> transportAction = newTransportAction(actionName, actionFilters);

        // Response filters are expected to run in the reverse of the request order.
        List<ActionFilter> actionFiltersByOrder = new ArrayList<>(filters);
        Collections.sort(actionFiltersByOrder, Collections.reverseOrder(BY_ORDER));

        List<ActionFilter> expectedActionFilters = new ArrayList<>();
        boolean errorExpected = false;
        for (ActionFilter filter : actionFiltersByOrder) {
            ResponseTestFilter testFilter = (ResponseTestFilter) filter;
            expectedActionFilters.add(testFilter);
            if (testFilter.callback == ResponseOperation.LISTENER_FAILURE) {
                errorExpected = true;
            }
            if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
                break;
            }
        }

        PlainListenableActionFuture<TestResponse> future = new PlainListenableActionFuture<>(null);
        transportAction.execute(new TestRequest(), future);
        try {
            assertThat(future.get(), notNullValue());
            assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false));
        } catch (ExecutionException e) {
            // See testActionFiltersRequest: AssertionError must not be caught here.
            assertThat("shouldn't get here if an error is not expected " + e.getMessage(), errorExpected, equalTo(true));
        }

        List<ResponseTestFilter> testFiltersByLastExecution = new ArrayList<>();
        for (ActionFilter actionFilter : actionFilters.filters()) {
            testFiltersByLastExecution.add((ResponseTestFilter) actionFilter);
        }
        Collections.sort(testFiltersByLastExecution, new Comparator<ResponseTestFilter>() {
            @Override
            public int compare(ResponseTestFilter o1, ResponseTestFilter o2) {
                return Integer.compare(o1.executionToken, o2.executionToken);
            }
        });

        List<ResponseTestFilter> finalTestFilters = new ArrayList<>();
        for (ResponseTestFilter testFilter : testFiltersByLastExecution) {
            finalTestFilters.add(testFilter);
            if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
                break;
            }
        }

        assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
        for (int i = 0; i < finalTestFilters.size(); i++) {
            ResponseTestFilter testFilter = finalTestFilters.get(i);
            assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
            assertThat(testFilter.runs.get(), equalTo(1));
            assertThat(testFilter.lastActionName, equalTo(actionName));
        }
    }

    /**
     * A request filter that calls {@code proceed} several times must produce exactly one
     * response; every additional call must surface as an {@link IllegalStateException}.
     */
    @Test
    public void testTooManyContinueProcessingRequest() throws ExecutionException, InterruptedException {
        final int additionalContinueCount = randomInt(10);

        RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
            @Override
            public void execute(Task task, final String action, final ActionRequest actionRequest,
                                final ActionListener actionListener, final ActionFilterChain actionFilterChain) {
                // one legitimate call plus additionalContinueCount excess calls
                for (int i = 0; i <= additionalContinueCount; i++) {
                    actionFilterChain.proceed(task, action, actionRequest, actionListener);
                }
            }
        });

        Set<ActionFilter> filters = new HashSet<>();
        filters.add(testFilter);

        String actionName = randomAsciiOfLength(randomInt(30));
        ActionFilters actionFilters = new ActionFilters(filters);
        TransportAction<TestRequest, TestResponse> transportAction = newTransportAction(actionName, actionFilters);

        executeAndAssertExcessProceedsFail(transportAction, additionalContinueCount);

        assertThat(testFilter.runs.get(), equalTo(1));
        assertThat(testFilter.lastActionName, equalTo(actionName));
    }

    /**
     * A response filter that calls {@code proceed} several times must produce exactly one
     * response; every additional call must surface as an {@link IllegalStateException}.
     */
    @Test
    public void testTooManyContinueProcessingResponse() throws ExecutionException, InterruptedException {
        final int additionalContinueCount = randomInt(10);

        ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() {
            @Override
            public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
                // one legitimate call plus additionalContinueCount excess calls
                for (int i = 0; i <= additionalContinueCount; i++) {
                    chain.proceed(action, response, listener);
                }
            }
        });

        Set<ActionFilter> filters = new HashSet<>();
        filters.add(testFilter);

        String actionName = randomAsciiOfLength(randomInt(30));
        ActionFilters actionFilters = new ActionFilters(filters);
        TransportAction<TestRequest, TestResponse> transportAction = newTransportAction(actionName, actionFilters);

        executeAndAssertExcessProceedsFail(transportAction, additionalContinueCount);

        assertThat(testFilter.runs.get(), equalTo(1));
        assertThat(testFilter.lastActionName, equalTo(actionName));
    }

    /**
     * Picks between 0 and 10 distinct filter orders, each in the range 0..10.
     */
    private Set<Integer> randomFilterOrders() {
        int numFilters = randomInt(10);
        Set<Integer> orders = new HashSet<>(numFilters);
        while (orders.size() < numFilters) {
            orders.add(randomInt(10));
        }
        return orders;
    }

    /**
     * Creates a transport action whose {@code doExecute} immediately answers the listener
     * with a fresh {@link TestResponse}.
     *
     * @param actionName    the name the action is registered under
     * @param actionFilters the filters to apply around the action
     */
    private TransportAction<TestRequest, TestResponse> newTransportAction(String actionName, ActionFilters actionFilters) {
        return new TransportAction<TestRequest, TestResponse>(Settings.EMPTY, actionName, null, actionFilters, null,
                new TaskManager(Settings.EMPTY)) {
            @Override
            protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) {
                listener.onResponse(new TestResponse());
            }
        };
    }

    /**
     * Executes the action and waits until the listener has been notified
     * {@code additionalContinueCount + 1} times, then asserts that exactly one notification was a
     * response and all others were {@link IllegalStateException} failures.
     *
     * @param transportAction         the action to execute
     * @param additionalContinueCount the number of excess {@code proceed} calls the filter makes
     */
    private void executeAndAssertExcessProceedsFail(TransportAction<TestRequest, TestResponse> transportAction,
                                                    int additionalContinueCount) throws InterruptedException {
        final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1);
        final AtomicInteger responses = new AtomicInteger();
        final List<Throwable> failures = new CopyOnWriteArrayList<>();

        transportAction.execute(new TestRequest(), new ActionListener<TestResponse>() {
            @Override
            public void onResponse(TestResponse testResponse) {
                responses.incrementAndGet();
                latch.countDown();
            }

            @Override
            public void onFailure(Throwable e) {
                failures.add(e);
                latch.countDown();
            }
        });

        if (!latch.await(10, TimeUnit.SECONDS)) {
            fail("timeout waiting for the filter to notify the listener as many times as expected");
        }

        assertThat(responses.get(), equalTo(1));
        assertThat(failures.size(), equalTo(additionalContinueCount));
        for (Throwable failure : failures) {
            assertThat(failure, instanceOf(IllegalStateException.class));
        }
    }

    /**
     * Filter that records its invocation on the request side and then delegates to a
     * {@link RequestCallback}; on the response side it simply continues the chain.
     * Non-static because it draws execution tokens from the enclosing test's {@code counter}.
     */
    private class RequestTestFilter implements ActionFilter {
        private final RequestCallback callback;
        private final int order;
        final AtomicInteger runs = new AtomicInteger();
        volatile String lastActionName;
        // the filters that don't run will go last in the sorted list
        volatile int executionToken = Integer.MAX_VALUE;

        RequestTestFilter(int order, RequestCallback callback) {
            this.order = order;
            this.callback = callback;
        }

        @Override
        public int order() {
            return order;
        }

        @SuppressWarnings("unchecked")
        @Override
        public void apply(Task task, String action, ActionRequest actionRequest, ActionListener actionListener,
                          ActionFilterChain actionFilterChain) {
            this.runs.incrementAndGet();
            this.lastActionName = action;
            this.executionToken = counter.incrementAndGet();
            this.callback.execute(task, action, actionRequest, actionListener, actionFilterChain);
        }

        @Override
        public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
            chain.proceed(action, response, listener);
        }
    }

    /**
     * Filter that simply continues the chain on the request side; on the response side it
     * records its invocation and then delegates to a {@link ResponseCallback}.
     * Non-static because it draws execution tokens from the enclosing test's {@code counter}.
     */
    private class ResponseTestFilter implements ActionFilter {
        private final ResponseCallback callback;
        private final int order;
        final AtomicInteger runs = new AtomicInteger();
        volatile String lastActionName;
        // the filters that don't run will go last in the sorted list
        volatile int executionToken = Integer.MAX_VALUE;

        ResponseTestFilter(int order, ResponseCallback callback) {
            this.order = order;
            this.callback = callback;
        }

        @Override
        public int order() {
            return order;
        }

        @Override
        public void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
            chain.proceed(task, action, request, listener);
        }

        @Override
        public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
            this.runs.incrementAndGet();
            this.lastActionName = action;
            this.executionToken = counter.incrementAndGet();
            this.callback.execute(action, response, listener, chain);
        }
    }

    /** The behaviours a request filter can be given at random. */
    private enum RequestOperation implements RequestCallback {
        /** Passes the request on to the rest of the chain. */
        CONTINUE_PROCESSING {
            @Override
            public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener,
                                ActionFilterChain actionFilterChain) {
                actionFilterChain.proceed(task, action, actionRequest, actionListener);
            }
        },
        /** Answers the listener with a response without continuing the chain. */
        LISTENER_RESPONSE {
            @Override
            @SuppressWarnings("unchecked")
            public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener,
                                ActionFilterChain actionFilterChain) {
                actionListener.onResponse(new TestResponse());
            }
        },
        /** Fails the listener without continuing the chain. */
        LISTENER_FAILURE {
            @Override
            public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener,
                                ActionFilterChain actionFilterChain) {
                actionListener.onFailure(new ElasticsearchTimeoutException(""));
            }
        }
    }

    /** The behaviours a response filter can be given at random. */
    private enum ResponseOperation implements ResponseCallback {
        /** Passes the response on to the rest of the chain. */
        CONTINUE_PROCESSING {
            @Override
            public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
                chain.proceed(action, response, listener);
            }
        },
        /** Answers the listener with a new response without continuing the chain. */
        LISTENER_RESPONSE {
            @Override
            @SuppressWarnings("unchecked")
            public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
                listener.onResponse(new TestResponse());
            }
        },
        /** Fails the listener without continuing the chain. */
        LISTENER_FAILURE {
            @Override
            public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
                listener.onFailure(new ElasticsearchTimeoutException(""));
            }
        }
    }

    /** Behaviour plugged into a {@link RequestTestFilter}. */
    private interface RequestCallback {
        void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener,
                     ActionFilterChain actionFilterChain);
    }

    /** Behaviour plugged into a {@link ResponseTestFilter}. */
    private interface ResponseCallback {
        void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
    }

    /** Minimal request that always passes validation. */
    public static class TestRequest extends ActionRequest {
        @Override
        public ActionRequestValidationException validate() {
            return null;
        }
    }

    /** Minimal response with no payload. */
    private static class TestResponse extends ActionResponse {

    }
}