/**
* Dianping.com Inc.
* Copyright (c) 2003-2013 All Rights Reserved.
*/
package com.dianping.pigeon.remoting.provider.process.filter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import com.dianping.pigeon.remoting.common.codec.json.JacksonSerializer;
import com.google.common.collect.Maps;
import org.apache.commons.lang.StringUtils;
import com.dianping.pigeon.log.Logger;
import com.dianping.pigeon.config.ConfigChangeListener;
import com.dianping.pigeon.config.ConfigManager;
import com.dianping.pigeon.config.ConfigManagerLoader;
import com.dianping.pigeon.log.LoggerLoader;
import com.dianping.pigeon.remoting.common.domain.Disposable;
import com.dianping.pigeon.remoting.common.domain.InvocationContext.TimePhase;
import com.dianping.pigeon.remoting.common.domain.InvocationContext.TimePoint;
import com.dianping.pigeon.remoting.common.domain.InvocationRequest;
import com.dianping.pigeon.remoting.common.domain.InvocationResponse;
import com.dianping.pigeon.remoting.common.exception.RejectedException;
import com.dianping.pigeon.remoting.common.process.ServiceInvocationFilter;
import com.dianping.pigeon.remoting.common.process.ServiceInvocationHandler;
import com.dianping.pigeon.remoting.common.util.Constants;
import com.dianping.pigeon.remoting.provider.config.ProviderConfig;
import com.dianping.pigeon.remoting.provider.domain.ProviderContext;
import com.dianping.pigeon.remoting.provider.process.statistics.ProviderStatisticsChecker;
import com.dianping.pigeon.remoting.provider.process.statistics.ProviderStatisticsHolder;
import com.dianping.pigeon.remoting.provider.publish.ServiceChangeListener;
import com.dianping.pigeon.remoting.provider.publish.ServiceChangeListenerContainer;
import com.dianping.pigeon.remoting.provider.service.method.ServiceMethodCache;
import com.dianping.pigeon.remoting.provider.service.method.ServiceMethodFactory;
import com.dianping.pigeon.threadpool.DefaultThreadPool;
import com.dianping.pigeon.threadpool.ThreadPool;
import com.dianping.pigeon.util.CollectionUtils;
import com.dianping.pigeon.util.ThreadPoolUtils;
/**
 * Provider-side filter that decides whether an incoming request may proceed to the next handler.
 *
 * <p>For service messages ({@code Constants.MESSAGE_TYPE_SERVICE}) it applies the following checks, each
 * guarded by its own runtime switch:
 * <ul>
 * <li>a per-method in-flight request limit derived from {@code pigeon.provider.pool.method.maxthreads};</li>
 * <li>a per-second limit for each (service#method, caller app) pair ({@code pigeon.provider.methodapplimit});</li>
 * <li>a per-second limit for each caller app ({@code pigeon.provider.applimit});</li>
 * <li>a global per-second limit ({@code pigeon.provider.globallimit}).</li>
 * </ul>
 * A violated limit is reported by throwing {@link RejectedException}.
 *
 * <p>All state is static and therefore shared by every instance of this filter. Limits and switches are
 * reloaded from {@link ConfigManager} change notifications. Each reload builds a brand-new map and publishes
 * it with a single volatile write, so concurrent readers never see a cleared or half-built map.
 *
 * @author xiangwu
 */
public class GatewayProcessFilter implements ServiceInvocationFilter<ProviderContext>, Disposable {

    private static final Logger logger = LoggerLoader.getLogger(GatewayProcessFilter.class);
    private static final ConfigManager configManager = ConfigManagerLoader.getConfigManager();

    private static final String KEY_APPLIMIT_ENABLE = "pigeon.provider.applimit.enable";
    // NOTE(review): unlike the other switches this key ends in ".active" rather than ".enable";
    // it is kept as-is because deployed configuration may already rely on it.
    private static final String KEY_METHODAPPLIMIT_ENABLE = "pigeon.provider.methodapplimit.active";
    private static final String KEY_GLOBALLIMIT_ENABLE = "pigeon.provider.globallimit.enable";
    private static final String KEY_METHODTHREADSLIMIT_ENABLE = "pigeon.provider.methodthreadslimit.enable";
    private static final String KEY_APPLIMIT = "pigeon.provider.applimit";
    private static final String KEY_METHODAPPLIMIT = "pigeon.provider.methodapplimit";
    private static final String KEY_GLOBALLIMIT = "pigeon.provider.globallimit";

    /** Caller app -> per-second request limit. Replaced wholesale on reload, never mutated in place. */
    private static volatile Map<String, Long> appLimitMap = new ConcurrentHashMap<String, Long>();
    /** "service#method" -> {caller app -> per-second request limit}. Replaced wholesale on reload. */
    private static volatile Map<String, Map<String, Long>> methodAppLimitMap = Maps.newConcurrentMap();
    /** Global per-second request limit; configuration only ever sets it to a non-negative value. */
    private static volatile Long globalLimit = Long.MAX_VALUE;

    private static final JacksonSerializer jacksonSerializer = new JacksonSerializer();
    private static final ThreadPool statisticsCheckerPool = new DefaultThreadPool(
            "Pigeon-Server-Statistics-Checker");

    /**
     * "service#method" -> number of in-flight requests counted for that method. Entries exist only for
     * methods registered through {@link InnerServiceChangeListener#notifyServiceAdded}.
     */
    private static final ConcurrentHashMap<String, AtomicInteger> methodActives = new ConcurrentHashMap<String, AtomicInteger>();
    /** Number of in-flight service requests counted across all methods while the threads limit is on. */
    private static final AtomicInteger total = new AtomicInteger();
    private static final int MAX_THREADS = configManager.getIntValue("pigeon.provider.pool.method.maxthreads", 100);

    private static volatile boolean enableMethodThreadsLimit = configManager
            .getBooleanValue(KEY_METHODTHREADSLIMIT_ENABLE, true);
    private static volatile boolean enableAppLimit = configManager.getBooleanValue(KEY_APPLIMIT_ENABLE, false);
    private static volatile boolean enableMethodAppLimit = configManager
            .getBooleanValue(KEY_METHODAPPLIMIT_ENABLE, false);
    private static volatile boolean enableGlobalLimit = configManager.getBooleanValue(KEY_GLOBALLIMIT_ENABLE,
            false);

    // Must stay after the field declarations above: the parse* methods overwrite the limit fields, and
    // static initializers run in textual order.
    static {
        parseGlobalLimitConfig(configManager.getStringValue(KEY_GLOBALLIMIT));
        parseMethodAppLimitConfig(configManager.getStringValue(KEY_METHODAPPLIMIT));
        parseAppLimitConfig(configManager.getStringValue(KEY_APPLIMIT));
        configManager.registerConfigChangeListener(new InnerConfigChangeListener());
        statisticsCheckerPool.execute(new ProviderStatisticsChecker());
        ServiceChangeListenerContainer.addServiceChangeListener(new InnerServiceChangeListener());
    }

    /**
     * Shuts down the statistics-checker pool.
     *
     * <p>The pool is static and only started from the static initializer, so it is not restarted
     * afterwards.
     */
    public void destroy() throws Exception {
        ThreadPoolUtils.shutdown(statisticsCheckerPool.getExecutor());
    }

    /**
     * Parses the global per-second limit. Blank input, negative values and unparsable values leave the
     * current limit unchanged; parse failures are logged.
     *
     * @param globalLimitConfig raw configuration value, may be {@code null}
     */
    private static void parseGlobalLimitConfig(String globalLimitConfig) {
        if (StringUtils.isBlank(globalLimitConfig)) {
            return;
        }
        try {
            long parsed = Long.parseLong(globalLimitConfig.trim());
            if (parsed >= 0) {
                globalLimit = parsed;
            }
        } catch (Throwable t) {
            logger.error("error while parsing global limit configuration:" + globalLimitConfig, t);
        }
    }

    /**
     * Parses the per-method, per-app limit configuration, a JSON object of the form
     * {@code {"service#method": {"app1": limit1, "app2": limit2}}}.
     *
     * <p>The whole value is rejected (and the previous limits kept) if any part of it fails to parse.
     * Blank input also leaves the previous limits in place. On success the new map replaces the old one
     * with a single volatile write.
     *
     * @param methodAppLimitConfig raw configuration value, may be {@code null}
     */
    @SuppressWarnings("unchecked")
    private static void parseMethodAppLimitConfig(String methodAppLimitConfig) {
        if (StringUtils.isBlank(methodAppLimitConfig)) {
            return;
        }
        try {
            Map<String, Object> parsed = (Map<String, Object>) jacksonSerializer.toObject(HashMap.class,
                    methodAppLimitConfig);
            Map<String, Map<String, Long>> newLimits = new ConcurrentHashMap<String, Map<String, Long>>();
            for (Map.Entry<String, Object> methodEntry : parsed.entrySet()) {
                Map<String, Object> rawAppLimits = (Map<String, Object>) methodEntry.getValue();
                Map<String, Long> appLimits = new ConcurrentHashMap<String, Long>();
                for (Map.Entry<String, Object> appEntry : rawAppLimits.entrySet()) {
                    // JSON numbers may arrive as Integer/Long/String; normalise them all to Long.
                    appLimits.put(appEntry.getKey(), Long.parseLong(String.valueOf(appEntry.getValue())));
                }
                newLimits.put(methodEntry.getKey(), appLimits);
            }
            // Publish by swapping the reference; never clear the live map, because readers in invoke()
            // may be using it concurrently.
            methodAppLimitMap = newLimits;
        } catch (Throwable t) {
            logger.error("error while parsing method app limit configuration:" + methodAppLimitConfig, t);
        }
    }

    /**
     * Parses the per-app limit configuration of the form {@code app1:limit1,app2:limit2}.
     *
     * <p>Whitespace around app names and limits is ignored. Entries that do not contain exactly one
     * {@code ':'} are skipped. A non-numeric limit rejects the whole value, and the previous limits are
     * kept. Blank input also leaves the previous limits in place.
     *
     * @param appLimitConfig raw configuration value, may be {@code null}
     */
    private static void parseAppLimitConfig(String appLimitConfig) {
        if (StringUtils.isBlank(appLimitConfig)) {
            return;
        }
        Map<String, Long> newLimits = new ConcurrentHashMap<String, Long>();
        try {
            for (String entry : appLimitConfig.split(",")) {
                if (StringUtils.isBlank(entry)) {
                    continue;
                }
                String[] pair = entry.split(":");
                if (pair.length == 2) {
                    newLimits.put(pair[0].trim(), Long.valueOf(pair[1].trim()));
                }
            }
            // Publish by swapping the reference; never clear the live map (see parseMethodAppLimitConfig).
            appLimitMap = newLimits;
        } catch (RuntimeException e) {
            logger.error("error while parsing app limit configuration:" + appLimitConfig, e);
        }
    }

    /**
     * Records the request in the provider statistics, applies the enabled limits to service messages,
     * and then delegates to {@code handler}.
     *
     * @throws RejectedException if a method-threads or per-second limit is exceeded
     */
    @Override
    public InvocationResponse invoke(ServiceInvocationHandler handler, ProviderContext invocationContext)
            throws Throwable {
        invocationContext.getTimeline().add(new TimePoint(TimePhase.G));
        InvocationRequest request = invocationContext.getRequest();
        final String requestMethod = request.getServiceName() + "#" + request.getMethodName();
        // Tracks whether this call took a slot in the method-threads counters. The finally block then
        // releases exactly what was acquired, even if the switch is flipped while the request runs or
        // something throws before the slot is taken.
        boolean slotAcquired = false;
        try {
            ProviderStatisticsHolder.flowIn(request);
            if (Constants.MESSAGE_TYPE_SERVICE == request.getMessageType()) {
                if (enableMethodThreadsLimit) {
                    // Set before the call: incrementRequest bumps the counters before it may throw.
                    slotAcquired = true;
                    incrementRequest(requestMethod);
                }
                checkRateLimits(request, requestMethod);
            }
            return handler.handle(invocationContext);
        } finally {
            if (slotAcquired) {
                decrementRequest(requestMethod);
            }
            if (!(Constants.REPLY_MANUAL || invocationContext.isAsync())) {
                ProviderStatisticsHolder.flowOut(request);
            }
        }
    }

    /**
     * Applies the enabled per-second limits in this order: method+app, then app, then global.
     *
     * <p>A limit is exceeded when admitting this request would push the current second's count above it.
     *
     * @throws RejectedException if any enabled limit is exceeded
     */
    private static void checkRateLimits(InvocationRequest request, String requestMethod) {
        String fromApp = request.getApp();
        boolean hasApp = StringUtils.isNotBlank(fromApp);

        if (enableMethodAppLimit && hasApp) {
            // Read the volatile map exactly once: a reload swaps the reference between reads.
            Map<String, Long> appLimits = methodAppLimitMap.get(requestMethod);
            Long limit = appLimits == null ? null : appLimits.get(fromApp);
            if (limit != null && limit >= 0) {
                long requests = ProviderStatisticsHolder.getMethodAppCapacityBucket(request)
                        .getRequestsInCurrentSecond();
                if (requests + 1 > limit) {
                    throw new RejectedException(
                            String.format("Max requests limit %s reached for request %s from app:%s", limit,
                                    requestMethod, fromApp));
                }
            }
        }

        if (enableAppLimit && hasApp) {
            Long limit = appLimitMap.get(fromApp);
            if (limit != null && limit >= 0) {
                long requests = ProviderStatisticsHolder.getCapacityBucket(request).getRequestsInCurrentSecond();
                if (requests + 1 > limit) {
                    throw new RejectedException(String
                            .format("Max requests limit %s reached for request from app:%s", limit, fromApp));
                }
            }
        }

        if (enableGlobalLimit) {
            long limit = globalLimit;
            if (limit >= 0) {
                long requests = ProviderStatisticsHolder.getGlobalCapacityBucket().getRequestsInCurrentSecond();
                if (requests + 1 > limit) {
                    throw new RejectedException(
                            String.format("Max requests limit %s reached for global limitation", limit));
                }
            }
        }
    }

    /**
     * Renders the per-method in-flight counters as a string fragment.
     *
     * @return an empty string when no method is registered, otherwise
     *         {@code ",[method actives=[[svc#m=n]...]"} (format preserved as-is)
     */
    public static String getStatistics() {
        StringBuilder stats = new StringBuilder();
        if (!CollectionUtils.isEmpty(methodActives)) {
            stats.append(",[method actives=[");
            // Iterate entries so a concurrent removal cannot produce a "key=null" pair.
            for (Map.Entry<String, AtomicInteger> entry : methodActives.entrySet()) {
                stats.append("[").append(entry.getKey()).append("=").append(entry.getValue()).append("]");
            }
            stats.append("]");
        }
        return stats.toString();
    }

    /**
     * Pre-check for the method-threads limit that does not modify any counter.
     *
     * <p>Only applies to service messages while the limit is enabled, and only to methods that have a
     * registered counter.
     *
     * @param request incoming request
     * @throws RejectedException if the method's in-flight count already exceeds its current limit
     */
    public static void checkRequest(final InvocationRequest request) {
        if (Constants.MESSAGE_TYPE_SERVICE == request.getMessageType() && enableMethodThreadsLimit) {
            final String requestMethod = request.getServiceName() + "#" + request.getMethodName();
            AtomicInteger count = methodActives.get(requestMethod);
            if (count != null) {
                // Snapshot once so the comparison and the message report the same value.
                int current = count.get();
                int limit = getMaxThreadsForMethod(requestMethod, current);
                if (current > limit) {
                    throw new RejectedException(
                            String.format("Reached the maximum limit %s for method: %s, current: %s", limit,
                                    requestMethod, current));
                }
            }
        }
    }

    /**
     * Computes how many concurrent requests a method may hold.
     *
     * <p>The method may use the free headroom ({@code MAX_THREADS} minus all counted in-flight requests)
     * plus what it already holds. The result is capped at {@code MAX_THREADS - 20} (presumably to keep
     * some capacity for other methods — intent not stated in this file).
     *
     * @param requestMethod             "service#method" key (currently unused)
     * @param requestMethodThreadCount  in-flight requests already counted for this method
     * @return the maximum allowed in-flight count for this method
     */
    private static int getMaxThreadsForMethod(String requestMethod, int requestMethodThreadCount) {
        int totalThreads = total.get();
        int limit = MAX_THREADS > totalThreads ? MAX_THREADS - totalThreads + requestMethodThreadCount
                : requestMethodThreadCount;
        if (limit > MAX_THREADS - 20) {
            limit = MAX_THREADS - 20;
        }
        return limit;
    }

    /**
     * Takes a slot in the global and per-method counters.
     *
     * <p>The counters are incremented before the check, so the caller must always pair this call with
     * {@link #decrementRequest(String)}, even when this method throws.
     *
     * @throws RejectedException if the method's in-flight count exceeds its limit after incrementing
     */
    private static void incrementRequest(String requestMethod) {
        total.incrementAndGet();
        AtomicInteger count = methodActives.get(requestMethod);
        if (count != null) {
            int limit = getMaxThreadsForMethod(requestMethod, count.get());
            int current = count.incrementAndGet();
            if (current > limit) {
                throw new RejectedException(String.format("Reached the maximum limit %s for method: %s, current: %s",
                        limit, requestMethod, current));
            }
        }
    }

    /** Releases a slot previously taken by {@link #incrementRequest(String)}. */
    private static void decrementRequest(String requestMethod) {
        total.decrementAndGet();
        AtomicInteger count = methodActives.get(requestMethod);
        if (count != null) {
            count.decrementAndGet();
        }
    }

    /**
     * Reacts to configuration changes of the limit values and switches. Keys are matched by suffix;
     * removals are ignored and leave the last known value in effect.
     */
    private static class InnerConfigChangeListener implements ConfigChangeListener {

        @Override
        public void onKeyUpdated(String key, String value) {
            try {
                if (key.endsWith(KEY_APPLIMIT)) {
                    parseAppLimitConfig(value);
                } else if (key.endsWith(KEY_METHODAPPLIMIT)) {
                    parseMethodAppLimitConfig(value);
                } else if (key.endsWith(KEY_APPLIMIT_ENABLE)) {
                    enableAppLimit = Boolean.parseBoolean(value);
                } else if (key.endsWith(KEY_METHODAPPLIMIT_ENABLE)) {
                    enableMethodAppLimit = Boolean.parseBoolean(value);
                } else if (key.endsWith(KEY_METHODTHREADSLIMIT_ENABLE)) {
                    enableMethodThreadsLimit = Boolean.parseBoolean(value);
                } else if (key.endsWith(KEY_GLOBALLIMIT_ENABLE)) {
                    enableGlobalLimit = Boolean.parseBoolean(value);
                } else if (key.endsWith(KEY_GLOBALLIMIT)) {
                    parseGlobalLimitConfig(value);
                }
            } catch (Throwable t) {
                logger.warn("invalid value for key " + key, t);
            }
        }

        /** A key that first appears after startup is handled exactly like an update. */
        @Override
        public void onKeyAdded(String key, String value) {
            onKeyUpdated(key, value);
        }

        @Override
        public void onKeyRemoved(String key) {
        }
    }

    /** Maintains a {@link #methodActives} counter for every method of each added service. */
    private static class InnerServiceChangeListener implements ServiceChangeListener {

        @Override
        public void notifyServicePublished(ProviderConfig<?> providerConfig) {
        }

        @Override
        public void notifyServiceUnpublished(ProviderConfig<?> providerConfig) {
        }

        @Override
        public void notifyServiceOnline(ProviderConfig<?> providerConfig) {
        }

        @Override
        public void notifyServiceOffline(ProviderConfig<?> providerConfig) {
        }

        @Override
        public void notifyServiceAdded(ProviderConfig<?> providerConfig) {
            String url = providerConfig.getUrl();
            ServiceMethodCache methodCache = ServiceMethodFactory.getServiceMethodCache(url);
            Set<String> methodNames = methodCache.getMethodMap().keySet();
            for (String method : methodNames) {
                // putIfAbsent: re-adding a service must not reset a counter that in-flight requests will
                // later decrement, which would drive it negative.
                methodActives.putIfAbsent(url + "#" + method, new AtomicInteger());
            }
        }

        @Override
        public void notifyServiceRemoved(ProviderConfig<?> providerConfig) {
            String url = providerConfig.getUrl();
            ServiceMethodCache methodCache = ServiceMethodFactory.getServiceMethodCache(url);
            Set<String> methodNames = methodCache.getMethodMap().keySet();
            for (String method : methodNames) {
                methodActives.remove(url + "#" + method);
            }
        }
    }
}