/*
 * Copyright 2016 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.linecorp.armeria.client.limit;

import static java.util.Objects.requireNonNull;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import com.linecorp.armeria.client.Client;
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.ResponseTimeoutException;
import com.linecorp.armeria.client.SimpleDecoratingClient;
import com.linecorp.armeria.common.Request;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.Response;
import com.linecorp.armeria.common.util.SafeCloseable;

import io.netty.util.concurrent.ScheduledFuture;

/**
 * An abstract {@link Client} decorator that limits the concurrent number of active requests.
 *
 * <p>{@link #numActiveRequests()} increases when {@link Client#execute(ClientRequestContext, Request)} is
 * invoked and decreases when the {@link Response} returned by the
 * {@link Client#execute(ClientRequestContext, Request)} is closed. When {@link #numActiveRequests()} reaches
 * the configured {@code maxConcurrency}, the {@link Request}s are deferred until the currently active
 * {@link Request}s are completed.
 *
 * @param <I> the {@link Request} type
 * @param <O> the {@link Response} type
 */
public abstract class ConcurrencyLimitingClient<I extends Request, O extends Response>
        extends SimpleDecoratingClient<I, O> {

    private static final long DEFAULT_TIMEOUT_MILLIS = 10000L;

    private final int maxConcurrency;
    private final long timeoutMillis;
    private final AtomicInteger numActiveRequests = new AtomicInteger();
    private final Queue<PendingTask> pendingRequests = new ConcurrentLinkedQueue<>();

    /**
     * Creates a new instance that decorates the specified {@code delegate} to limit the concurrent number of
     * active requests to {@code maxConcurrency}, with the default timeout of {@value #DEFAULT_TIMEOUT_MILLIS}
     * milliseconds.
     *
     * @param delegate the delegate {@link Client}
     * @param maxConcurrency the maximum number of concurrent active requests. {@code 0} to disable the limit.
     */
    protected ConcurrencyLimitingClient(Client<? super I, ? extends O> delegate, int maxConcurrency) {
        this(delegate, maxConcurrency, DEFAULT_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    }

    /**
     * Creates a new instance that decorates the specified {@code delegate} to limit the concurrent number of
     * active requests to {@code maxConcurrency}.
     *
     * @param delegate the delegate {@link Client}
     * @param maxConcurrency the maximum number of concurrent active requests. {@code 0} to disable the limit.
     * @param timeout the amount of time until this decorator fails the request if the request was not
     *                delegated to the {@code delegate} before then. {@code 0} to disable the timeout.
     * @param unit the {@link TimeUnit} of {@code timeout}
     *
     * @throws IllegalArgumentException if {@code maxConcurrency} or {@code timeout} is negative
     * @throws NullPointerException if {@code unit} is {@code null}
     */
    protected ConcurrencyLimitingClient(Client<? super I, ? extends O> delegate,
                                        int maxConcurrency, long timeout, TimeUnit unit) {
        super(delegate);

        validateAll(maxConcurrency, timeout, unit);

        this.maxConcurrency = maxConcurrency;
        timeoutMillis = unit.toMillis(timeout);
    }

    /**
     * Validates all constructor parameters.
     *
     * @throws IllegalArgumentException if {@code maxConcurrency} or {@code timeout} is negative
     * @throws NullPointerException if {@code unit} is {@code null}
     */
    static void validateAll(int maxConcurrency, long timeout, TimeUnit unit) {
        validateMaxConcurrency(maxConcurrency);
        if (timeout < 0) {
            throw new IllegalArgumentException("timeout: " + timeout + " (expected: >= 0)");
        }
        requireNonNull(unit, "unit");
    }

    /**
     * Validates {@code maxConcurrency}.
     *
     * @throws IllegalArgumentException if {@code maxConcurrency} is negative
     */
    static void validateMaxConcurrency(int maxConcurrency) {
        if (maxConcurrency < 0) {
            throw new IllegalArgumentException("maxConcurrency: " + maxConcurrency + " (expected: >= 0)");
        }
    }

    /**
     * Returns the number of the {@link Request}s that are being executed.
     */
    public int numActiveRequests() {
        return numActiveRequests.get();
    }

    @Override
    public O execute(ClientRequestContext ctx, I req) throws Exception {
        // maxConcurrency == 0 means "no limit"; only the active request counter is maintained then.
        return maxConcurrency == 0 ? unlimitedExecute(ctx, req)
                                   : limitedExecute(ctx, req);
    }

    /**
     * Enqueues the request, tries to drain the queue and, if the request could not be delegated
     * immediately, arms a timeout which fails the deferred response with
     * {@link ResponseTimeoutException}.
     */
    private O limitedExecute(ClientRequestContext ctx, I req) throws Exception {
        final Deferred<O> deferred = defer(ctx, req);
        final PendingTask currentTask = new PendingTask(ctx, req, deferred);

        pendingRequests.add(currentTask);
        drain();

        if (!currentTask.isRun() && timeoutMillis != 0) {
            // Current request was not delegated. Schedule a timeout.
            final ScheduledFuture<?> timeoutFuture = ctx.eventLoop().schedule(
                    () -> deferred.close(ResponseTimeoutException.get()),
                    timeoutMillis, TimeUnit.MILLISECONDS);
            currentTask.set(timeoutFuture);

            if (currentTask.isRun()) {
                // drain() on another thread picked up the task after the check above, possibly before
                // it could observe timeoutFuture. Cancel the timeout here so that it cannot close a
                // response which has been delegated already. PendingTask.run() tolerates a future
                // cancelled this way.
                // NOTE: With a very small timeout the timeout task may still have fired before this
                //       point; that window is not closed by this check.
                timeoutFuture.cancel(false);
            }
        }

        return deferred.response();
    }

    /**
     * Delegates the request right away while keeping {@link #numActiveRequests()} up to date.
     */
    private O unlimitedExecute(ClientRequestContext ctx, I req) throws Exception {
        numActiveRequests.incrementAndGet();
        boolean success = false;
        try {
            final O res = delegate().execute(ctx, req);
            res.closeFuture().whenComplete((unused, cause) -> numActiveRequests.decrementAndGet());
            success = true;
            return res;
        } finally {
            if (!success) {
                // The delegate threw before a response could be observed; undo the increment.
                numActiveRequests.decrementAndGet();
            }
        }
    }

    /**
     * Runs the pending tasks in FIFO order for as long as a concurrency slot can be acquired.
     * May be invoked concurrently from multiple threads.
     */
    void drain() {
        while (!pendingRequests.isEmpty()) {
            final int currentActiveRequests = numActiveRequests.get();
            if (currentActiveRequests >= maxConcurrency) {
                break;
            }

            // Acquire a slot first; on CAS failure, retry from the top with a fresh count.
            if (numActiveRequests.compareAndSet(currentActiveRequests, currentActiveRequests + 1)) {
                final PendingTask task = pendingRequests.poll();
                if (task == null) {
                    // Another thread took the last task; give the slot back.
                    numActiveRequests.decrementAndGet();
                    if (!pendingRequests.isEmpty()) {
                        // Another request might have been added to the queue while numActiveRequests
                        // reached its limit.
                        continue;
                    } else {
                        break;
                    }
                }

                task.run();
            }
        }
    }

    /**
     * Defers the specified {@link Request}.
     *
     * @return a new {@link Deferred} which provides the interface for updating the result of
     *         {@link Request} execution later.
     */
    protected abstract Deferred<O> defer(ClientRequestContext ctx, I req) throws Exception;

    /**
     * Provides the interface for updating the result of a {@link Request} execution when its {@link Response}
     * is ready.
     *
     * @param <O> the {@link Response} type
     */
    public interface Deferred<O extends Response> {
        /**
         * Returns the {@link Response} which will delegate to the {@link Response} set by
         * {@link #delegate(Response)}.
         */
        O response();

        /**
         * Delegates the {@link #response() response} to the specified {@link Response}.
         */
        void delegate(O response);

        /**
         * Closes the {@link #response()} without delegating.
         */
        void close(Throwable cause);
    }

    /**
     * A queued request. The {@link AtomicReference} value is the timeout future scheduled by
     * {@link #limitedExecute(ClientRequestContext, Request)}, or {@code null} if none was scheduled.
     */
    private final class PendingTask extends AtomicReference<ScheduledFuture<?>> implements Runnable {

        private static final long serialVersionUID = -7092037489640350376L;

        private final ClientRequestContext ctx;
        private final I req;
        private final Deferred<O> deferred;

        // Written by run(), which drain() may invoke on a thread other than the one that reads it in
        // limitedExecute(); volatile makes the write visible to that reader.
        private volatile boolean isRun;

        PendingTask(ClientRequestContext ctx, I req, Deferred<O> deferred) {
            this.ctx = ctx;
            this.req = req;
            this.deferred = deferred;
        }

        /**
         * Returns whether {@link #run()} has been entered.
         */
        boolean isRun() {
            return isRun;
        }

        /**
         * Cancels the timeout (if any) and delegates the request. The caller must have acquired a
         * concurrency slot already; the slot is released here on every path that does not end up with
         * an active response.
         */
        @Override
        public void run() {
            // Must be set before reading the timeout future so that limitedExecute() can detect a task
            // which started running before the future was published.
            isRun = true;

            final ScheduledFuture<?> timeoutFuture = get();
            if (timeoutFuture != null &&
                !timeoutFuture.cancel(false) && !timeoutFuture.isCancelled()) {
                // The future could not be cancelled and was not cancelled by limitedExecute() either:
                // timeout task ran already or is determined to run.
                numActiveRequests.decrementAndGet();
                return;
            }

            try (SafeCloseable ignored = RequestContext.push(ctx)) {
                try {
                    final O actualRes = delegate().execute(ctx, req);
                    // Release the slot and serve the next pending request once the response is closed.
                    actualRes.closeFuture().whenCompleteAsync((unused, cause) -> {
                        numActiveRequests.decrementAndGet();
                        drain();
                    }, ctx.eventLoop());
                    deferred.delegate(actualRes);
                } catch (Throwable t) {
                    numActiveRequests.decrementAndGet();
                    deferred.close(t);
                }
            }
        }
    }
}