/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.action.support.tasks; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.NoSuchNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.NodeShouldNotConnectException; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Consumer; import java.util.function.Supplier; import static java.util.Collections.emptyList; /** * The base class for transport actions that are interacting with currently running tasks. */ public abstract class TransportTasksAction< OperationTask extends Task, TasksRequest extends BaseTasksRequest<TasksRequest>, TasksResponse extends BaseTasksResponse, TaskResponse extends Writeable > extends HandledTransportAction<TasksRequest, TasksResponse> { protected final ClusterService clusterService; protected final TransportService transportService; protected final Supplier<TasksRequest> requestSupplier; protected final Supplier<TasksResponse> responseSupplier; protected final String transportNodeAction; protected TransportTasksAction(Settings settings, String actionName, ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Supplier<TasksRequest> requestSupplier, Supplier<TasksResponse> responseSupplier, String nodeExecutor) { super(settings, actionName, threadPool, transportService, actionFilters, indexNameExpressionResolver, requestSupplier); this.clusterService = clusterService; this.transportService = transportService; this.transportNodeAction = actionName + "[n]"; this.requestSupplier = requestSupplier; this.responseSupplier = responseSupplier; transportService.registerRequestHandler(transportNodeAction, NodeTaskRequest::new, nodeExecutor, new NodeTransportHandler()); } @Override protected final void doExecute(TasksRequest request, ActionListener<TasksResponse> listener) { logger.warn("attempt to execute a transport tasks operation without a task"); throw new UnsupportedOperationException("task parameter is required for this operation"); } @Override protected void doExecute(Task task, TasksRequest request, ActionListener<TasksResponse> listener) { new AsyncAction(task, request, listener).start(); } private void nodeOperation(NodeTaskRequest nodeTaskRequest, ActionListener<NodeTasksResponse> listener) { TasksRequest request = nodeTaskRequest.tasksRequest; List<OperationTask> tasks = new ArrayList<>(); processTasks(request, tasks::add); if (tasks.isEmpty()) { listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList())); return; } AtomicArray<Tuple<TaskResponse, Exception>> responses = new AtomicArray<>(tasks.size()); final AtomicInteger counter = new AtomicInteger(tasks.size()); for (int i = 0; i < tasks.size(); i++) { final int taskIndex = i; ActionListener<TaskResponse> taskListener = new ActionListener<TaskResponse>() { @Override public void onResponse(TaskResponse response) { responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null)); respondIfFinished(); } @Override public void onFailure(Exception e) { responses.setOnce(taskIndex, new Tuple<>(null, e)); respondIfFinished(); } private void respondIfFinished() { if (counter.decrementAndGet() != 0) { return; } List<TaskResponse> results = new ArrayList<>(); List<TaskOperationFailure> exceptions = new ArrayList<>(); for (Tuple<TaskResponse, Exception> response : responses.asList()) { if (response.v1() == null) { assert response.v2() != null; exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2())); } else { assert response.v2() == null; results.add(response.v1()); } } listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions)); } }; try { taskOperation(request, tasks.get(taskIndex), taskListener); } catch (Exception e) { taskListener.onFailure(e); } } } protected String[] filterNodeIds(DiscoveryNodes nodes, String[] nodesIds) { return nodesIds; } protected String[] resolveNodes(TasksRequest request, ClusterState clusterState) { if (request.getTaskId().isSet()) { return new String[]{request.getTaskId().getNodeId()}; } else { return clusterState.nodes().resolveNodes(request.getNodes()); } } protected void processTasks(TasksRequest request, Consumer<OperationTask> operation) { if (request.getTaskId().isSet()) { // we are only checking one task, we can optimize it Task task = taskManager.getTask(request.getTaskId().getId()); if (task != null) { if (request.match(task)) { operation.accept((OperationTask) task); } else { throw new ResourceNotFoundException("task [{}] doesn't support this operation", request.getTaskId()); } } else { throw new ResourceNotFoundException("task [{}] is missing", request.getTaskId()); } } else { for (Task task : taskManager.getTasks().values()) { if (request.match(task)) { operation.accept((OperationTask) task); } } } } protected abstract TasksResponse newResponse(TasksRequest request, List<TaskResponse> tasks, List<TaskOperationFailure> taskOperationFailures, List<FailedNodeException> failedNodeExceptions); @SuppressWarnings("unchecked") protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray responses) { List<TaskResponse> tasks = new ArrayList<>(); List<FailedNodeException> failedNodeExceptions = new ArrayList<>(); List<TaskOperationFailure> taskOperationFailures = new ArrayList<>(); for (int i = 0; i < responses.length(); i++) { Object response = responses.get(i); if (response instanceof FailedNodeException) { failedNodeExceptions.add((FailedNodeException) response); } else { NodeTasksResponse tasksResponse = (NodeTasksResponse) response; if (tasksResponse.results != null) { tasks.addAll(tasksResponse.results); } if (tasksResponse.exceptions != null) { taskOperationFailures.addAll(tasksResponse.exceptions); } } } return newResponse(request, tasks, taskOperationFailures, failedNodeExceptions); } protected abstract TaskResponse readTaskResponse(StreamInput in) throws IOException; /** * Perform the required operation on the task. It is OK start an asynchronous operation or to throw an exception but not both. */ protected abstract void taskOperation(TasksRequest request, OperationTask task, ActionListener<TaskResponse> listener); protected boolean transportCompress() { return false; } protected abstract boolean accumulateExceptions(); private class AsyncAction { private final TasksRequest request; private final String[] nodesIds; private final DiscoveryNode[] nodes; private final ActionListener<TasksResponse> listener; private final AtomicReferenceArray<Object> responses; private final AtomicInteger counter = new AtomicInteger(); private final Task task; private AsyncAction(Task task, TasksRequest request, ActionListener<TasksResponse> listener) { this.task = task; this.request = request; this.listener = listener; ClusterState clusterState = clusterService.state(); String[] nodesIds = resolveNodes(request, clusterState); this.nodesIds = filterNodeIds(clusterState.nodes(), nodesIds); ImmutableOpenMap<String, DiscoveryNode> nodes = clusterState.nodes().getNodes(); this.nodes = new DiscoveryNode[nodesIds.length]; for (int i = 0; i < this.nodesIds.length; i++) { this.nodes[i] = nodes.get(this.nodesIds[i]); } this.responses = new AtomicReferenceArray<>(this.nodesIds.length); } private void start() { if (nodesIds.length == 0) { // nothing to do try { listener.onResponse(newResponse(request, responses)); } catch (Exception e) { logger.debug("failed to generate empty response", e); listener.onFailure(e); } } else { TransportRequestOptions.Builder builder = TransportRequestOptions.builder(); if (request.getTimeout() != null) { builder.withTimeout(request.getTimeout()); } builder.withCompress(transportCompress()); for (int i = 0; i < nodesIds.length; i++) { final String nodeId = nodesIds[i]; final int idx = i; final DiscoveryNode node = nodes[i]; try { if (node == null) { onFailure(idx, nodeId, new NoSuchNodeException(nodeId)); } else { NodeTaskRequest nodeRequest = new NodeTaskRequest(request); nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); transportService.sendRequest(node, transportNodeAction, nodeRequest, builder.build(), new TransportResponseHandler<NodeTasksResponse>() { @Override public NodeTasksResponse newInstance() { return new NodeTasksResponse(); } @Override public void handleResponse(NodeTasksResponse response) { onOperation(idx, response); } @Override public void handleException(TransportException exp) { onFailure(idx, node.getId(), exp); } @Override public String executor() { return ThreadPool.Names.SAME; } }); } } catch (Exception e) { onFailure(idx, nodeId, e); } } } } private void onOperation(int idx, NodeTasksResponse nodeResponse) { responses.set(idx, nodeResponse); if (counter.incrementAndGet() == responses.length()) { finishHim(); } } private void onFailure(int idx, String nodeId, Throwable t) { if (logger.isDebugEnabled() && !(t instanceof NodeShouldNotConnectException)) { logger.debug( (org.apache.logging.log4j.util.Supplier<?>) () -> new ParameterizedMessage("failed to execute on node [{}]", nodeId), t); } if (accumulateExceptions()) { responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)); } if (counter.incrementAndGet() == responses.length()) { finishHim(); } } private void finishHim() { TasksResponse finalResponse; try { finalResponse = newResponse(request, responses); } catch (Exception e) { logger.debug("failed to combine responses from nodes", e); listener.onFailure(e); return; } listener.onResponse(finalResponse); } } class NodeTransportHandler implements TransportRequestHandler<NodeTaskRequest> { @Override public void messageReceived(final NodeTaskRequest request, final TransportChannel channel) throws Exception { nodeOperation(request, new ActionListener<NodeTasksResponse>() { @Override public void onResponse( TransportTasksAction<OperationTask, TasksRequest, TasksResponse, TaskResponse>.NodeTasksResponse response) { try { channel.sendResponse(response); } catch (Exception e) { onFailure(e); } } @Override public void onFailure(Exception e) { try { channel.sendResponse(e); } catch (IOException e1) { e1.addSuppressed(e); logger.warn("Failed to send failure", e1); } } }); } } private class NodeTaskRequest extends TransportRequest { private TasksRequest tasksRequest; protected NodeTaskRequest() { super(); } protected NodeTaskRequest(TasksRequest tasksRequest) { super(); this.tasksRequest = tasksRequest; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); tasksRequest = requestSupplier.get(); tasksRequest.readFrom(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); tasksRequest.writeTo(out); } } private class NodeTasksResponse extends TransportResponse { protected String nodeId; protected List<TaskOperationFailure> exceptions; protected List<TaskResponse> results; NodeTasksResponse() { } NodeTasksResponse(String nodeId, List<TaskResponse> results, List<TaskOperationFailure> exceptions) { this.nodeId = nodeId; this.results = results; this.exceptions = exceptions; } public String getNodeId() { return nodeId; } public List<TaskOperationFailure> getExceptions() { return exceptions; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); nodeId = in.readString(); int resultsSize = in.readVInt(); results = new ArrayList<>(resultsSize); for (; resultsSize > 0; resultsSize--) { final TaskResponse result = in.readBoolean() ? readTaskResponse(in) : null; results.add(result); } if (in.readBoolean()) { int taskFailures = in.readVInt(); exceptions = new ArrayList<>(taskFailures); for (int i = 0; i < taskFailures; i++) { exceptions.add(new TaskOperationFailure(in)); } } else { exceptions = null; } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(nodeId); out.writeVInt(results.size()); for (TaskResponse result : results) { if (result != null) { out.writeBoolean(true); result.writeTo(out); } else { out.writeBoolean(false); } } out.writeBoolean(exceptions != null); if (exceptions != null) { int taskFailures = exceptions.size(); out.writeVInt(taskFailures); for (TaskOperationFailure exception : exceptions) { exception.writeTo(out); } } } } }