package core.framework.impl.web.rate;
import core.framework.api.util.Exceptions;
import core.framework.api.util.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
 * Token-bucket style rate limiter keyed by group and client IP.
 * <p>
 * Each group is registered once via {@link #config(String, int, int, TimeUnit)} with a bucket capacity
 * ({@code maxPermits}) and a refill rate. {@link #acquire(String, String)} keeps one bucket per
 * {@code group + "/" + clientIP} key in a bounded {@code LRUMap}, refills it from the elapsed
 * {@link System#nanoTime()} time, and consumes one permit per successful call.
 * <p>
 * NOTE(review): the group config map is a plain map read without synchronization; this assumes all
 * {@code config} calls complete before {@code acquire} is used concurrently — confirm with callers.
 *
 * @author neo
 */
public class RateLimiter {
    private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiter.class);

    private final Map<String, RateConfig> config = Maps.newHashMap();
    private final LRUMap<String, Rate> rates;

    /**
     * @param maxEntries capacity passed to the {@code LRUMap} holding per-key buckets
     */
    public RateLimiter(int maxEntries) {
        rates = new LRUMap<>(maxEntries);
    }

    /**
     * Registers the rate configuration for a group.
     * <p>
     * If the group is already registered, the error built by {@code Exceptions.error} is thrown and the
     * existing configuration is left unchanged.
     *
     * @param group      group name, used as the first part of each bucket key
     * @param maxPermits bucket capacity; a new bucket also starts with this many permits
     * @param fillRate   number of permits refilled per one {@code unit} of time
     * @param unit       time unit of {@code fillRate}
     */
    public void config(String group, int maxPermits, int fillRate, TimeUnit unit) {
        double fillRatePerNano = ratePerNano(fillRate, unit);
        // putIfAbsent (not put) so a rejected duplicate does not overwrite the existing config before throwing
        RateConfig previous = config.putIfAbsent(group, new RateConfig(maxPermits, fillRatePerNano));
        if (previous != null) throw Exceptions.error("found duplicate group, group={}", group);
    }

    /**
     * Converts a rate expressed per one {@code unit} into a rate per nanosecond.
     */
    double ratePerNano(int rate, TimeUnit unit) {
        return rate / (double) unit.toNanos(1);
    }

    /**
     * Tries to take one permit from the bucket of {@code group + "/" + clientIP}.
     * <p>
     * An unknown group is logged at warn level and always allowed.
     *
     * @return {@code true} if a permit was acquired (or the group is not configured), {@code false} otherwise
     */
    public boolean acquire(String group, String clientIP) {
        RateConfig rateConfig = config.get(group);
        if (rateConfig == null) {
            LOGGER.warn("can not find group, group={}", group);
            return true;    // skip if group is not defined
        }
        String key = group + "/" + clientIP;
        Rate rate;
        // the LRUMap is only touched under this lock; presumably it is not thread-safe on its own
        synchronized (this) {
            rate = rates.computeIfAbsent(key, k -> new Rate(rateConfig.maxPermits));
        }
        return rate.acquire(rateConfig.maxPermits, rateConfig.fillRatePerNano);
    }

    /**
     * Immutable per-group settings: bucket capacity and refill rate in permits per nanosecond.
     */
    static final class RateConfig {
        final int maxPermits;
        final double fillRatePerNano;

        RateConfig(int maxPermits, double fillRatePerNano) {
            this.maxPermits = maxPermits;
            this.fillRatePerNano = fillRatePerNano;
        }
    }

    /**
     * Mutable bucket state for one key; updates are guarded by the instance monitor in {@link #acquire(int, double)}.
     */
    static final class Rate {
        double currentPermits;
        long lastUpdateTime;

        Rate(int currentPermits) {
            this.currentPermits = currentPermits;
            this.lastUpdateTime = System.nanoTime();
        }

        boolean acquire(int maxPermits, double fillRatePerNano) {
            long currentTime = System.nanoTime();
            synchronized (this) {
                return acquire(currentTime, maxPermits, fillRatePerNano);
            }
        }

        /**
         * Refills the bucket up to {@code maxPermits} for the time elapsed since the last update, then consumes one
         * permit if at least one is available.
         *
         * @param currentTime a {@link System#nanoTime()} reading
         * @return {@code true} if a permit was consumed
         */
        boolean acquire(long currentTime, int maxPermits, double fillRatePerNano) {
            long timeElapsed = currentTime - lastUpdateTime;
            // currentTime is sampled before the lock is taken, so a thread may arrive with a timestamp older than
            // lastUpdateTime; only refill (and advance the clock) on forward progress, never drain on negative elapsed
            if (timeElapsed > 0) {
                currentPermits = Math.min(maxPermits, currentPermits + fillRatePerNano * timeElapsed);
                lastUpdateTime = currentTime;
            }
            if (currentPermits >= 1) {
                currentPermits -= 1;
                return true;
            }
            return false;
        }
    }
}