/* (c) 2014 Open Source Geospatial Foundation - all rights reserved * This code is licensed under the GPL 2.0 license, available at the root * application directory. */ package org.geoserver.flow.controller; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import javax.servlet.http.HttpServletResponse; import org.geoserver.flow.ControlFlowCallback; import org.geoserver.flow.FlowController; import org.geoserver.ows.HttpErrorCodeException; import org.geoserver.ows.Request; import org.geotools.util.CanonicalSet; import org.geotools.util.logging.Logging; import com.google.common.base.Predicate; /** * Limits the rate of requests, and slows them down after the number of requests per unit of time is * filled, or throws a HTTP 429 if no delay if configured * * @author Andrea Aime - GeoSolutions */ public class RateFlowController implements FlowController { /** * The next epoc at which the counter will reset */ public static final String X_RATE_LIMIT_RESET = "X-Rate-Limit-Reset"; /** * How many request remain in this time slot before the rate limiting occurs */ public static final String X_RATE_LIMIT_REMAINING = "X-Rate-Limit-Remaining"; /** * How many requests per time slot before the rate limiting kicks in */ public static final String X_RATE_LIMIT_LIMIT = "X-Rate-Limit-Limit"; /** * The context in which the rate limiting occurs */ public static final String X_RATE_LIMIT_CONTEXT = "X-Rate-Limit-Context"; static final Logger LOGGER = Logging.getLogger(ControlFlowCallback.class); /** * The minimum number of counters we have need to have around before a cleanup is initiated */ static int COUNTERS_CLEANUP_THRESHOLD = Integer.parseInt(System.getProperty( "org.geoserver.flow.countersCleanupThreshold", "200")); /** * The cleanup interval before a cleanup is initiated */ static int COUNTERS_CLEANUP_INTERVAL = Integer.parseInt(System.getProperty( "org.geoserver.flow.countersCleanupInterval", "10000")); final class Counter { volatile long timePeriodId; AtomicInteger requests = new AtomicInteger(0); public int addRequest(long currPeriodId) { if (currPeriodId != timePeriodId) { synchronized (this) { if (currPeriodId != timePeriodId) { timePeriodId = currPeriodId; requests.set(0); } } } // increment and return if we have gone above the limit return requests.incrementAndGet(); } public synchronized long getTimePeriodId() { return timePeriodId; } } /** * Thread local holding the current user id */ static ThreadLocal<String> USER_ID = new ThreadLocal<String>(); /** * Generates a unique key identifying the user making the request */ KeyGenerator keyGenerator; /** * Contains all active counters */ Map<String, Counter> counters = new ConcurrentHashMap<>(); /** * Used to make user keys unique before using them as synchronization locks */ CanonicalSet<String> canonicalizer = CanonicalSet.newInstance(String.class); /** * Checks if we should apply this request rate limit to the request */ Predicate<Request> matcher; int maxRequests; long timeInterval; long delay; String action; /** * Last time we've performed a queue cleanup */ volatile long lastCleanup = System.currentTimeMillis(); /** * Builds a UserFlowController that will trigger stale queue expiration once 100 queues have * been accumulated and */ public RateFlowController(Predicate<Request> matcher, int maxRequests, long timeInterval, long delay, KeyGenerator keyGenerator) { this.matcher = matcher; this.maxRequests = maxRequests; this.timeInterval = timeInterval; this.delay = delay; this.keyGenerator = keyGenerator; if (delay > 0) { this.action = "Delay excess requests " + delay + "ms"; } else { this.action = "Reject excess requests"; } } @Override public void requestComplete(Request request) { // nothing to do } public boolean requestIncoming(Request request, long timeout) { if (!matcher.apply(request)) { return true; } boolean retval = true; long now = System.currentTimeMillis(); long currPeriodId = now / timeInterval; String userKey = keyGenerator.getUserKey(request); // grab/generate the counter Counter counter = counters.get(userKey); if (counter == null) { userKey = canonicalizer.unique(userKey); synchronized (userKey) { counter = counters.get(userKey); if (counter == null) { counter = new Counter(); counters.put(userKey, counter); } } } // update the counters int requests = counter.addRequest(currPeriodId); int residual = maxRequests - requests; // set the headers HttpServletResponse response = request.getHttpResponse(); response.addHeader(X_RATE_LIMIT_CONTEXT, matcher.toString()); response.addIntHeader(X_RATE_LIMIT_LIMIT, maxRequests); response.addIntHeader(X_RATE_LIMIT_REMAINING, Math.max(residual, 0)); response.addDateHeader(X_RATE_LIMIT_RESET, ((currPeriodId + 1) * timeInterval)); response.addHeader("X-Rate-Limit-Action", action); // counter cleanup handling if (LOGGER.isLoggable(Level.FINE)) { LOGGER.fine(this + ", residual in current time period " + residual); } if (residual < 0) { if (delay <= 0) { throw new HttpErrorCodeException(429, "Too many requests requests in the current time period, check X-Rate-Limit HTTP response headers"); } else if(delay > timeout) { // no point in waiting return false; } else { if (LOGGER.isLoggable(Level.FINE)) { LOGGER.fine(this + ", delaying current request"); } try { Thread.sleep(delay); } catch (InterruptedException e) { LOGGER.log(Level.WARNING, this + ", the delay was abruptly interrupted", e); } } } // cleanup stale counters if necessary long elapsed = now - lastCleanup; if (counters.size() > COUNTERS_CLEANUP_THRESHOLD && (elapsed > (timeInterval) || (elapsed > 10000))) { int cleanupCount = 0; synchronized (counters) { for(Map.Entry<String, Counter> entry : counters.entrySet()) { Counter c = entry.getValue(); long timePeriodId = c.getTimePeriodId(); long age = (currPeriodId - timePeriodId) * timeInterval; if(age > COUNTERS_CLEANUP_THRESHOLD) { counters.remove(entry.getKey()); } } lastCleanup = now; if (LOGGER.isLoggable(Level.FINE)) { LOGGER.fine(this + ", purged " + cleanupCount + " stale counters"); } } } return retval; } public KeyGenerator getKeyGenerator() { return keyGenerator; } public Predicate<Request> getMatcher() { return matcher; } public int getMaxRequests() { return maxRequests; } public long getTimeInterval() { return timeInterval; } public long getDelay() { return delay; } @Override public int getPriority() { // higher priority, we want to go thought the rate limiters before going through // the concurrency ones, as the rate limiters can delay the request and are user specific return Integer.MIN_VALUE + maxRequests * (int) (86400 / timeInterval); } @Override public String toString() { return getClass().getSimpleName() + " [" + matcher + ", action=" + action + "]"; } }