/* * Licensed to Crate under one or more contributor license agreements. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. Crate licenses this file * to you under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. * * However, if you have executed another commercial license agreement * with Crate these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial * agreement. */ package io.crate.executor.transport.distributed; import com.google.common.annotations.VisibleForTesting; import io.crate.Streamer; import io.crate.data.*; import io.crate.exceptions.SQLExceptions; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; /** * Consumer which sends requests to downstream nodes every {@link #pageSize} rows. * * The rows from the source {@link BatchIterator} are "bucketed" using a {@link MultiBucketBuilder}. So a downstream * can either receive a part of the data or all data. * * Every time requests to the downstreams are made consumption of the source BatchIterator is stopped until a response * from all downstreams is received. */ public class DistributingConsumer implements BatchConsumer { private final Logger logger; private final UUID jobId; private final int targetPhaseId; private final byte inputId; private final int bucketIdx; private final TransportDistributedResultAction distributedResultAction; private final Streamer<?>[] streamers; private final int pageSize; private final Bucket[] buckets; private final List<Downstream> downstreams; private final boolean traceEnabled; @VisibleForTesting final MultiBucketBuilder multiBucketBuilder; private volatile Throwable failure; public DistributingConsumer(Logger logger, UUID jobId, MultiBucketBuilder multiBucketBuilder, int targetPhaseId, byte inputId, int bucketIdx, Collection<String> downstreamNodeIds, TransportDistributedResultAction distributedResultAction, Streamer<?>[] streamers, int pageSize) { this.traceEnabled = logger.isTraceEnabled(); this.logger = logger; this.jobId = jobId; this.multiBucketBuilder = multiBucketBuilder; this.targetPhaseId = targetPhaseId; this.inputId = inputId; this.bucketIdx = bucketIdx; this.distributedResultAction = distributedResultAction; this.streamers = streamers; this.pageSize = pageSize; this.buckets = new Bucket[downstreamNodeIds.size()]; downstreams = new ArrayList<>(downstreamNodeIds.size()); for (String downstreamNodeId : downstreamNodeIds) { downstreams.add(new Downstream(downstreamNodeId)); } } @Override public void accept(BatchIterator iterator, @Nullable Throwable failure) { if (failure == null) { consumeIt(iterator); } else { forwardFailure(null, failure); } } private void consumeIt(BatchIterator it) { Row row = RowBridging.toRow(it.rowData()); try { while (it.moveNext()) { multiBucketBuilder.add(row); if (multiBucketBuilder.size() >= pageSize) { forwardResults(it, false); return; } } } catch (Throwable t) { forwardFailure(it, t); return; } if (it.allLoaded()) { forwardResults(it, true); } else { it.loadNextBatch().whenComplete((r, t) -> { if (t == null) { consumeIt(it); } else { forwardFailure(it, t); } }); } } private void forwardFailure(@Nullable final BatchIterator it, final Throwable f) { Throwable failure = SQLExceptions.unwrap(f); // make sure it's streamable AtomicInteger numActiveRequests = new AtomicInteger(downstreams.size()); DistributedResultRequest request = new DistributedResultRequest(jobId, targetPhaseId, inputId, bucketIdx, failure, false); for (int i = 0; i < downstreams.size(); i++) { Downstream downstream = downstreams.get(i); if (downstream.needsMoreData == false) { countdownAndMaybeCloseIt(numActiveRequests, it); } else { if (traceEnabled) { logger.trace("forwardFailure targetNode={} targetPhase={}/{} bucket={} failure={}", downstream.nodeId, targetPhaseId, inputId, bucketIdx, failure); } distributedResultAction.pushResult(downstream.nodeId, request, new ActionListener<DistributedResultResponse>() { @Override public void onResponse(DistributedResultResponse response) { downstream.needsMoreData = false; countdownAndMaybeCloseIt(numActiveRequests, it); } @Override public void onFailure(Exception e) { if (traceEnabled) { logger.trace("Error sending failure to downstream={} targetPhase={}/{} bucket={}", e, downstream.nodeId, targetPhaseId, inputId, bucketIdx); } countdownAndMaybeCloseIt(numActiveRequests, it); } }); } } } private void countdownAndMaybeCloseIt(AtomicInteger numActiveRequests, @Nullable BatchIterator it) { if (numActiveRequests.decrementAndGet() == 0) { if (it != null) { it.close(); } } } private void forwardResults(BatchIterator it, boolean isLast) { multiBucketBuilder.build(buckets); AtomicInteger numActiveRequests = new AtomicInteger(downstreams.size()); for (int i = 0; i < downstreams.size(); i++) { Downstream downstream = downstreams.get(i); if (downstream.needsMoreData == false) { countdownAndMaybeContinue(it, numActiveRequests); continue; } if (traceEnabled) { logger.trace("forwardResults targetNode={} targetPhase={}/{} bucket={} isLast={}", downstream.nodeId, targetPhaseId, inputId, bucketIdx, isLast); } distributedResultAction.pushResult( downstream.nodeId, new DistributedResultRequest(jobId, targetPhaseId, inputId, bucketIdx, streamers, buckets[i], isLast), new ActionListener<DistributedResultResponse>() { @Override public void onResponse(DistributedResultResponse response) { downstream.needsMoreData = response.needMore(); countdownAndMaybeContinue(it, numActiveRequests); } @Override public void onFailure(Exception e) { failure = e; downstream.needsMoreData = false; // continue because it's necessary to send something to downstreams still waiting for data countdownAndMaybeContinue(it, numActiveRequests); } } ); } } private void countdownAndMaybeContinue(BatchIterator it, AtomicInteger numActiveRequests) { if (numActiveRequests.decrementAndGet() == 0) { if (downstreams.stream().anyMatch(Downstream::needsMoreData)) { if (failure == null) { consumeIt(it); } else { forwardFailure(it, failure); } } else { it.close(); } } } private static class Downstream { private final String nodeId; private boolean needsMoreData = true; Downstream(String nodeId) { this.nodeId = nodeId; } boolean needsMoreData() { return needsMoreData; } } }