/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.cluster.service;

import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.Is.is;

/**
 * Tests that submit {@link TestTask}s to a prioritizing executor and check timeout handling
 * and priority ordering. Task submission goes through {@link #submitTask(String, TestTask)},
 * which subclasses may override to route the same tests through a different submission path.
 */
public class TaskExecutorTests extends ESTestCase {

    /** Thread pool shared by all tests of the class; created once and torn down after the class. */
    protected static ThreadPool threadPool;

    /** Executor the tasks are submitted to; recreated for every test. */
    protected PrioritizedEsThreadPoolExecutor threadExecutor;

    /** Creates the shared thread pool, named after the running test class. */
    @BeforeClass
    public static void createThreadPool() {
        threadPool = new TestThreadPool(getTestClass().getName());
    }

    /** Shuts the shared thread pool down, if one was created, and clears the static reference. */
    @AfterClass
    public static void stopThreadPool() {
        if (threadPool != null) {
            threadPool.shutdownNow();
            threadPool = null;
        }
    }

    /** Builds a fresh prioritizing executor backed by the shared pool's thread context and scheduler. */
    @Before
    public void setUpExecutor() {
        threadExecutor = EsExecutors.newSinglePrioritizing("test_thread", daemonThreadFactory(Settings.EMPTY, "test_thread"),
            threadPool.getThreadContext(), threadPool.scheduler());
    }

    /** Terminates the per-test executor, waiting up to ten seconds. */
    @After
    public void shutDownThreadExecutor() {
        ThreadPool.terminate(threadExecutor, 10, TimeUnit.SECONDS);
    }

    /** Callbacks through which a task learns about its outcome. */
    protected interface TestListener {

        /**
         * Called when the task could not be processed.
         *
         * @param source the source string the task was submitted with
         * @param e      the failure
         */
        void onFailure(String source, Exception e);

        /**
         * Called after the task has been executed.
         *
         * @param source the source string the task was submitted with
         */
        default void processed(String source) {
            // do nothing by default
        }
    }

    /**
     * Executes a list of tasks.
     *
     * @param <T> the task type
     */
    protected interface TestExecutor<T> {

        /**
         * Executes the given tasks.
         *
         * @param tasks the tasks to execute
         */
        void execute(List<T> tasks);

        /**
         * Joins the {@code toString()} of every task with {@code ", "}, skipping empty strings.
         *
         * @param tasks the tasks to describe
         * @return the joined description, or the empty string when {@code tasks} is empty
         */
        default String describeTasks(List<T> tasks) {
            return tasks.stream().map(T::toString).reduce((s1, s2) -> {
                if (s1.isEmpty()) {
                    return s2;
                } else if (s2.isEmpty()) {
                    return s1;
                } else {
                    return s1 + ", " + s2;
                }
            }).orElse("");
        }
    }

    /**
     * Task class that works for single tasks as well as batching (see {@link TaskBatcherTests})
     */
    protected abstract static class TestTask implements TestExecutor<TestTask>, TestListener, ClusterStateTaskConfig {

        /** Runs every task of the list in order. */
        @Override
        public void execute(List<TestTask> tasks) {
            tasks.forEach(TestTask::run);
        }

        /** @return {@code null}, meaning the task is submitted without a timeout */
        @Nullable
        @Override
        public TimeValue timeout() {
            return null;
        }

        /** @return {@link Priority#NORMAL} unless overridden */
        @Override
        public Priority priority() {
            return Priority.NORMAL;
        }

        /** The work performed by this task. */
        public abstract void run();
    }

    /** Runnable that executes a single {@link TestTask} and then notifies it via {@code processed}. */
    class UpdateTask extends SourcePrioritizedRunnable {
        final TestTask testTask;

        UpdateTask(String source, TestTask testTask) {
            super(testTask.priority(), source);
            this.testTask = testTask;
        }

        @Override
        public void run() {
            logger.trace("will process {}", source);
            testTask.execute(Collections.singletonList(testTask));
            testTask.processed(source);
        }
    }

    /**
     * Submits a task to {@link #threadExecutor}. When the task declares a timeout, a timeout
     * callback is registered that reports a {@link ProcessClusterEventTimeoutException} to the
     * task's {@code onFailure} on the generic thread pool.
     *
     * @param source   the source string describing the task
     * @param testTask the task to submit
     */
    // can be overridden by TaskBatcherTests
    protected void submitTask(String source, TestTask testTask) {
        SourcePrioritizedRunnable task = new UpdateTask(source, testTask);
        TimeValue timeout = testTask.timeout();
        if (timeout != null) {
            threadExecutor.execute(task, timeout, () -> threadPool.generic().execute(() -> {
                logger.debug("task [{}] timed out after [{}]", task, timeout);
                testTask.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source));
            }));
        } else {
            threadExecutor.execute(task);
        }
    }

    /**
     * Submits a blocking task followed by a task with a zero timeout, releases the blocker and
     * waits until the second task has either run or failed and the blocker has completed.
     */
    public void testTimedOutTaskCleanedUp() throws Exception {
        final CountDownLatch block = new CountDownLatch(1);
        final CountDownLatch blockCompleted = new CountDownLatch(1);
        TestTask blockTask = new TestTask() {

            @Override
            public void run() {
                try {
                    block.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                blockCompleted.countDown();
            }

            @Override
            public void onFailure(String source, Exception e) {
                throw new RuntimeException(e);
            }
        };
        submitTask("block-task", blockTask);

        final CountDownLatch block2 = new CountDownLatch(1);
        TestTask unblockTask = new TestTask() {

            @Override
            public void run() {
                block2.countDown();
            }

            @Override
            public void onFailure(String source, Exception e) {
                block2.countDown();
            }

            @Override
            public TimeValue timeout() {
                return TimeValue.ZERO;
            }
        };
        submitTask("unblock-task", unblockTask);

        block.countDown();
        block2.await();
        blockCompleted.await();
    }

    /**
     * Queues a task with a 2ms timeout behind a blocking task, waits for its failure callback,
     * then runs one more task to completion and asserts the timed-out task was never run.
     */
    public void testTimeoutTask() throws Exception {
        final CountDownLatch block = new CountDownLatch(1);
        TestTask test1 = new TestTask() {
            @Override
            public void run() {
                try {
                    block.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }

            @Override
            public void onFailure(String source, Exception e) {
                throw new RuntimeException(e);
            }
        };
        submitTask("block-task", test1);

        final CountDownLatch timedOut = new CountDownLatch(1);
        final AtomicBoolean executeCalled = new AtomicBoolean();
        TestTask test2 = new TestTask() {

            @Override
            public TimeValue timeout() {
                return TimeValue.timeValueMillis(2);
            }

            @Override
            public void run() {
                executeCalled.set(true);
            }

            @Override
            public void onFailure(String source, Exception e) {
                timedOut.countDown();
            }
        };
        submitTask("block-task", test2);

        timedOut.await();
        block.countDown();

        final CountDownLatch allProcessed = new CountDownLatch(1);
        TestTask test3 = new TestTask() {
            @Override
            public void run() {
                allProcessed.countDown();
            }

            @Override
            public void onFailure(String source, Exception e) {
                throw new RuntimeException(e);
            }
        };
        submitTask("block-task", test3);
        allProcessed.await(); // executed another task to double check that execute on the timed out update task is not called...
        assertThat(executeCalled.get(), equalTo(false));
    }

    /** Executor over integers that records every task it is asked to execute. */
    static class TaskExecutor implements TestExecutor<Integer> {
        List<Integer> tasks = new ArrayList<>();

        @Override
        public void execute(List<Integer> tasks) {
            this.tasks.addAll(tasks);
        }
    }

    /**
     * Note, this test can only work as long as we have a single thread executor executing the state update tasks!
     */
    public void testPrioritizedTasks() throws Exception {
        BlockingTask block = new BlockingTask(Priority.IMMEDIATE);
        submitTask("test", block);
        int taskCount = randomIntBetween(5, 20);

        // will hold all the tasks in the order in which they were executed
        List<PrioritizedTask> tasks = new ArrayList<>(taskCount);
        CountDownLatch latch = new CountDownLatch(taskCount);
        for (int i = 0; i < taskCount; i++) {
            Priority priority = randomFrom(Priority.values());
            PrioritizedTask task = new PrioritizedTask(priority, latch, tasks);
            submitTask("test", task);
        }

        block.close();
        latch.await();

        Priority prevPriority = null;
        for (PrioritizedTask task : tasks) {
            if (prevPriority != null) {
                assertThat(task.priority().sameOrAfter(prevPriority), is(true));
            }
            // advance on every iteration so each task is compared with its immediate predecessor,
            // not only with the first executed task
            prevPriority = task.priority();
        }
    }

    /** Task that blocks in {@link #run()} until {@link #close()} is called. */
    protected static class BlockingTask extends TestTask implements Releasable {
        private final CountDownLatch latch = new CountDownLatch(1);
        private final Priority priority;

        BlockingTask(Priority priority) {
            super();
            this.priority = priority;
        }

        @Override
        public void run() {
            try {
                latch.await();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public void onFailure(String source, Exception e) {
        }

        @Override
        public Priority priority() {
            return priority;
        }

        /** Releases the task so that {@link #run()} can return. */
        @Override
        public void close() {
            latch.countDown();
        }
    }

    /**
     * Task with a fixed priority that appends itself to a shared list when run and counts down
     * a shared latch when it is run or fails.
     */
    protected static class PrioritizedTask extends TestTask {
        private final CountDownLatch latch;
        private final List<PrioritizedTask> tasks;
        private final Priority priority;

        private PrioritizedTask(Priority priority, CountDownLatch latch, List<PrioritizedTask> tasks) {
            super();
            this.latch = latch;
            this.tasks = tasks;
            this.priority = priority;
        }

        @Override
        public void run() {
            tasks.add(this);
            latch.countDown();
        }

        @Override
        public Priority priority() {
            return priority;
        }

        @Override
        public void onFailure(String source, Exception e) {
            latch.countDown();
        }
    }
}