/**
* Dianping.com Inc.
* Copyright (c) 2003-${year} All Rights Reserved.
*/
package com.dianping.pigeon.remoting.provider.process.filter;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang.StringUtils;
import com.dianping.pigeon.config.ConfigChangeListener;
import com.dianping.pigeon.config.ConfigManager;
import com.dianping.pigeon.config.ConfigManagerLoader;
import com.dianping.pigeon.log.Logger;
import com.dianping.pigeon.log.LoggerLoader;
import com.dianping.pigeon.remoting.common.domain.InvocationContext.TimePhase;
import com.dianping.pigeon.remoting.common.domain.InvocationContext.TimePoint;
import com.dianping.pigeon.remoting.common.domain.InvocationRequest;
import com.dianping.pigeon.remoting.common.domain.InvocationResponse;
import com.dianping.pigeon.remoting.common.domain.generic.UnifiedRequest;
import com.dianping.pigeon.remoting.common.exception.SecurityException;
import com.dianping.pigeon.remoting.common.process.ServiceInvocationFilter;
import com.dianping.pigeon.remoting.common.process.ServiceInvocationHandler;
import com.dianping.pigeon.remoting.common.util.Constants;
import com.dianping.pigeon.remoting.common.util.ContextUtils;
import com.dianping.pigeon.remoting.common.util.SecurityUtils;
import com.dianping.pigeon.remoting.provider.domain.ProviderContext;
/**
 * Provider-side {@link ServiceInvocationFilter} that authorizes a request before it is handled.
 *
 * <p>Three checks run in order:
 * <ol>
 * <li>IP access control (when enabled): the remote address is prefix-matched against the IP
 * whitelist, then the IP blacklist, falling back to a configurable default verdict.</li>
 * <li>App access control (when enabled): the caller app name is matched exactly against the app
 * whitelist, then the app blacklist, falling back to a configurable default verdict.</li>
 * <li>Token authentication (when enabled for the target service/method): the request token must
 * equal {@code SecurityUtils.encrypt(serviceName + "#" + methodName + "#" + timestamp, secret)}
 * for the caller app's configured secret, and the timestamp must lie within the allowed skew.</li>
 * </ol>
 *
 * <p>All settings are static. They are loaded when an instance is constructed and refreshed by a
 * {@link ConfigChangeListener} when the corresponding keys are updated. Collections are replaced
 * by reference swap (never cleared in place) so concurrent readers always see a complete snapshot.
 *
 * @author xiangwu
 */
public class SecurityFilter implements ServiceInvocationFilter<ProviderContext> {

    private static final Logger logger = LoggerLoader.getLogger(SecurityFilter.class);
    private static final ConfigManager configManager = ConfigManagerLoader.getConfigManager();

    /** Charset used to turn tokens into bytes for the constant-time comparison. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    private static final String KEY_APP_SECRETS = "pigeon.provider.token.app.secrets";
    private static final String KEY_TOKEN_ENABLE = "pigeon.provider.token.enable";
    private static final String KEY_TOKEN_SWITCHES = "pigeon.provider.token.switches";
    private static final String KEY_ACCESS_IP_ENABLE = "pigeon.provider.access.ip.enable";
    private static final String KEY_TOKEN_PROTOCOL_DEFAULT_ENABLE = "pigeon.provider.token.protocol.default.enable";
    private static final String KEY_TOKEN_TIMESTAMP_DIFF = "pigeon.provider.token.timestamp.diff";

    /** App name -> shared secret (each secret at least 16 characters). */
    private static volatile ConcurrentHashMap<String, String> appSecrets = new ConcurrentHashMap<String, String>();
    /** "service" or "service#method" -> whether token validation applies. */
    private static volatile ConcurrentHashMap<String, Boolean> tokenSwitches = new ConcurrentHashMap<String, Boolean>();
    /** IP prefixes that are always rejected (checked after the whitelist). */
    private static volatile Set<String> ipBlackSet = Collections
            .newSetFromMap(new ConcurrentHashMap<String, Boolean>());
    /** IP prefixes that are always accepted. */
    private static volatile Set<String> ipWhiteSet = Collections
            .newSetFromMap(new ConcurrentHashMap<String, Boolean>());

    private static final String KEY_BLACKLIST = "pigeon.provider.access.ip.blacklist";
    private static final String KEY_WHITELIST = "pigeon.provider.access.ip.whitelist";
    private static final String DEFAULT_VALUE_WHITELIST = "127.0.0.1,";
    private static final String KEY_ACCESS_DEFAULT = "pigeon.provider.access.ip.default";

    private static volatile boolean isTokenEnable = configManager.getBooleanValue(KEY_TOKEN_ENABLE, false);
    private static volatile boolean isTokenEnableForDefaultProtocol = configManager
            .getBooleanValue(KEY_TOKEN_PROTOCOL_DEFAULT_ENABLE, false);
    private static volatile boolean isAccessIpEnable = configManager.getBooleanValue(KEY_ACCESS_IP_ENABLE, false);
    private static volatile boolean isAccessDefault = configManager.getBooleanValue(KEY_ACCESS_DEFAULT, true);
    /** Maximum allowed difference, in seconds, between the request timestamp and local time. */
    private static volatile int tokenTimestampDiff = configManager.getIntValue(KEY_TOKEN_TIMESTAMP_DIFF, 120);

    private static final String KEY_ACCESS_APP_ENABLE = "pigeon.provider.access.app.enable";
    private static final String KEY_APP_ACCESS_DEFAULT = "pigeon.provider.access.app.default";
    private static final String KEY_APP_BLACKLIST = "pigeon.provider.access.app.blacklist";
    private static final String KEY_APP_WHITELIST = "pigeon.provider.access.app.whitelist";

    private static volatile boolean isAccessAppEnable = configManager.getBooleanValue(KEY_ACCESS_APP_ENABLE, false);
    private static volatile boolean isAppAccessDefault = configManager.getBooleanValue(KEY_APP_ACCESS_DEFAULT, true);
    private static volatile Set<String> appBlackSet = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
    private static volatile Set<String> appWhiteSet = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());

    /** Ensures only one config listener is registered for the shared static state. */
    private static final AtomicBoolean listenerRegistered = new AtomicBoolean(false);

    /**
     * Loads the current security configuration into the shared static state and, on first
     * construction only, registers a listener that applies later configuration updates.
     */
    public SecurityFilter() {
        parseBlackList(configManager.getStringValue(KEY_BLACKLIST, ""));
        parseWhiteList(configManager.getStringValue(KEY_WHITELIST, DEFAULT_VALUE_WHITELIST));
        parseAppBlackList(configManager.getStringValue(KEY_APP_BLACKLIST, ""));
        parseAppWhiteList(configManager.getStringValue(KEY_APP_WHITELIST, ""));
        parseAppSecrets(configManager.getStringValue(KEY_APP_SECRETS, ""));
        parseTokenSwitchesConfig(configManager.getStringValue(KEY_TOKEN_SWITCHES, ""));
        // The state is static, so one listener serves every instance. Registering one per
        // instance would accumulate listeners and re-apply each update once per filter created.
        if (listenerRegistered.compareAndSet(false, true)) {
            ConfigManagerLoader.getConfigManager().registerConfigChangeListener(new InnerConfigChangeListener());
        }
    }

    /**
     * Splits a comma-separated configuration value into a thread-safe set of trimmed, non-blank
     * entries.
     *
     * @param config the raw configuration value
     * @return a new set holding the parsed entries (empty if {@code config} has none)
     * @throws IllegalArgumentException if {@code config} is null
     */
    private static Set<String> parseCommaSeparatedSet(String config) {
        if (config == null) {
            throw new IllegalArgumentException("config value is null");
        }
        Set<String> entries = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
        for (String item : config.split(",")) {
            if (StringUtils.isNotBlank(item)) {
                entries.add(item.trim());
            }
        }
        return entries;
    }

    /** Replaces the app blacklist; on a parse failure the previous list is kept. */
    private static void parseAppBlackList(String config) {
        try {
            appBlackSet = parseCommaSeparatedSet(config);
        } catch (RuntimeException t) {
            logger.warn("parse app blacklist error, remain unchanged", t);
        }
    }

    /** Replaces the app whitelist; on a parse failure the previous list is kept. */
    private static void parseAppWhiteList(String config) {
        try {
            appWhiteSet = parseCommaSeparatedSet(config);
        } catch (RuntimeException t) {
            logger.warn("parse app whitelist error, remain unchanged", t);
        }
    }

    /** Replaces the IP-prefix blacklist; on a parse failure the previous list is kept. */
    private static void parseBlackList(String config) {
        try {
            ipBlackSet = parseCommaSeparatedSet(config);
        } catch (RuntimeException t) {
            logger.warn("parse ip blacklist error, remain unchanged", t);
        }
    }

    /** Replaces the IP-prefix whitelist; on a parse failure the previous list is kept. */
    private static void parseWhiteList(String config) {
        try {
            ipWhiteSet = parseCommaSeparatedSet(config);
        } catch (RuntimeException t) {
            logger.warn("parse ip whitelist error, remain unchanged", t);
        }
    }

    /**
     * Decides whether a remote IP may call this provider.
     *
     * <p>Entries are matched as string prefixes (so {@code "10.1."} covers {@code 10.1.x.x}).
     * Whitelist hits win over blacklist hits; otherwise the configured default applies.
     *
     * @param ip the remote address; a null address is denied while IP control is enabled
     * @return true if access is allowed
     */
    private static boolean canAccess(String ip) {
        if (!isAccessIpEnable) {
            return true;
        }
        if (ip == null) {
            // No address means no list entry can be matched; deny instead of failing with an NPE.
            return false;
        }
        for (String prefix : ipWhiteSet) {
            if (ip.startsWith(prefix)) {
                return true;
            }
        }
        for (String prefix : ipBlackSet) {
            if (ip.startsWith(prefix)) {
                return false;
            }
        }
        return isAccessDefault;
    }

    /**
     * Replaces the app-secret table from a value of the form {@code app1:secret1,app2:secret2}.
     *
     * <p>A blank value clears all secrets. If any secret is shorter than 16 characters the whole
     * value is rejected and the previous table is kept. Malformed pairs (not exactly one ':') are
     * skipped.
     */
    private static void parseAppSecrets(String config) {
        if (StringUtils.isBlank(config)) {
            appSecrets = new ConcurrentHashMap<String, String>();
            return;
        }
        ConcurrentHashMap<String, String> parsed = new ConcurrentHashMap<String, String>();
        try {
            for (String str : config.split(",")) {
                if (StringUtils.isBlank(str)) {
                    continue;
                }
                String[] pair = str.split(":");
                if (pair.length == 2) {
                    String app = pair[0].trim();
                    String secret = pair[1].trim();
                    if (secret.length() < 16) {
                        throw new IllegalArgumentException("Secret length must not be less than 16");
                    }
                    parsed.put(app, secret);
                }
            }
            // Swap the reference only: clearing the live map first would briefly leave
            // concurrent requests with no secrets and make them fail authentication.
            appSecrets = parsed;
        } catch (RuntimeException e) {
            // The raw value contains the secrets themselves, so it must not be logged.
            logger.error("error while parsing app secret configuration, remain unchanged", e);
        }
    }

    /** Applies updates to the security-related configuration keys to the static state. */
    private static class InnerConfigChangeListener implements ConfigChangeListener {
        @Override
        public void onKeyUpdated(String key, String value) {
            if (key.endsWith(KEY_APP_BLACKLIST)) {
                parseAppBlackList(value);
            } else if (key.endsWith(KEY_APP_WHITELIST)) {
                parseAppWhiteList(value);
            } else if (key.endsWith(KEY_APP_SECRETS)) {
                parseAppSecrets(value);
            } else if (key.endsWith(KEY_BLACKLIST)) {
                parseBlackList(value);
            } else if (key.endsWith(KEY_WHITELIST)) {
                parseWhiteList(value);
            } else if (key.endsWith(KEY_TOKEN_SWITCHES)) {
                parseTokenSwitchesConfig(value);
            } else if (key.endsWith(KEY_TOKEN_ENABLE)) {
                // Boolean.parseBoolean never throws; anything other than "true" yields false.
                isTokenEnable = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_TOKEN_PROTOCOL_DEFAULT_ENABLE)) {
                isTokenEnableForDefaultProtocol = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_ACCESS_IP_ENABLE)) {
                isAccessIpEnable = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_ACCESS_APP_ENABLE)) {
                isAccessAppEnable = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_ACCESS_DEFAULT)) {
                isAccessDefault = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_APP_ACCESS_DEFAULT)) {
                isAppAccessDefault = Boolean.parseBoolean(value);
            } else if (key.endsWith(KEY_TOKEN_TIMESTAMP_DIFF)) {
                try {
                    tokenTimestampDiff = Integer.parseInt(value);
                } catch (NumberFormatException e) {
                    logger.warn("invalid value for key " + key, e);
                }
            }
        }

        /** Key additions are not applied by this listener; only updates change the cached state. */
        @Override
        public void onKeyAdded(String key, String value) {
        }

        /** Key removals are not applied by this listener; the last known values stay in effect. */
        @Override
        public void onKeyRemoved(String key) {
        }
    }

    /** Returns the current wall-clock time in whole seconds since the epoch. */
    private static long getCurrentTime() {
        return System.currentTimeMillis() / 1000;
    }

    /**
     * Rejects the request if its remote IP is not allowed by the IP access rules.
     *
     * @param remoteAddress the caller's address
     * @throws SecurityException if access is denied
     */
    public static void authenticateRequestIp(String remoteAddress) {
        if (!canAccess(remoteAddress)) {
            throw new SecurityException("Request ip:" + remoteAddress + " is not allowed");
        }
    }

    /** Rejects the request if the caller app is not allowed by the app access rules. */
    private void authenticateRequestApp(String requestApp) {
        if (!canAccessApp(requestApp)) {
            throw new SecurityException("Request App: " + requestApp + " is not allowed");
        }
    }

    /**
     * Decides whether an app may call this provider using exact name matches. Whitelist hits win
     * over blacklist hits; otherwise (including a null app name) the configured default applies.
     */
    private boolean canAccessApp(String requestApp) {
        if (!isAccessAppEnable) {
            return true;
        }
        // The sets are ConcurrentHashMap-backed and reject null lookups, hence the guard.
        if (requestApp != null) {
            if (appWhiteSet.contains(requestApp)) {
                return true;
            }
            if (appBlackSet.contains(requestApp)) {
                return false;
            }
        }
        return isAppAccessDefault;
    }

    /**
     * Validates a request token if token checking applies to the given service/method.
     *
     * @throws SecurityException if validation applies and fails
     */
    public static void authenticateRequestToken(String app, String remoteAddress, String timestamp, String version,
            String token, String serviceName, String methodName) {
        if (needValidateToken(serviceName, methodName)) {
            doAuthenticateRequestToken(app, remoteAddress, timestamp, version, token, serviceName, methodName);
        }
    }

    /**
     * Validates the caller app, timestamp window and token unconditionally.
     *
     * <p>{@code version} is currently not used in the check.
     *
     * @throws SecurityException if the app is blank or has no secret, the token is missing, the
     *     timestamp is missing/invalid/outside the allowed skew, or the token does not match
     */
    private static void doAuthenticateRequestToken(String app, String remoteAddress, String timestamp, String version,
            String token, String serviceName, String methodName) {
        if (StringUtils.isBlank(app)) {
            throw new SecurityException("Request app is required, from:" + remoteAddress);
        }
        String secret = appSecrets.get(app);
        if (StringUtils.isBlank(secret)) {
            throw new SecurityException("Secret not found for app:" + app);
        }
        if (StringUtils.isBlank(token)) {
            throw new SecurityException("Request token is required, from:" + remoteAddress + "@" + app);
        }
        long time = parseTimestamp(timestamp);
        if (time <= 0) {
            throw new SecurityException(
                    "Request timestamp is invalid:" + timestamp + ", from:" + remoteAddress + "@" + app);
        }
        long timediff = getCurrentTime() - time;
        if (Math.abs(timediff) > tokenTimestampDiff) {
            throw new SecurityException(
                    "The request has expired:" + timestamp + ", from:" + remoteAddress + "@" + app);
        }
        String data = serviceName + "#" + methodName + "#" + time;
        String expectToken = SecurityUtils.encrypt(data, secret);
        // Constant-time comparison so response timing does not reveal how much of a guess matched.
        if (expectToken == null
                || !MessageDigest.isEqual(expectToken.getBytes(UTF_8), token.getBytes(UTF_8))) {
            throw new SecurityException("Invalid request token:" + token + ", from:" + remoteAddress + "@" + app);
        }
    }

    /** Parses epoch seconds, returning 0 for a missing or non-numeric value. */
    private static long parseTimestamp(String timestamp) {
        if (timestamp == null) {
            return 0;
        }
        try {
            return Long.parseLong(timestamp);
        } catch (NumberFormatException ignored) {
            // An unparsable timestamp is reported by the caller as "invalid".
            return 0;
        }
    }

    /**
     * Replaces the token switches from a value of the form {@code key1=true,key2=false}, where a
     * key is either {@code service} or {@code service#method}. Malformed entries are skipped and a
     * null value clears all switches.
     */
    private static void parseTokenSwitchesConfig(String config) {
        ConcurrentHashMap<String, Boolean> parsed = new ConcurrentHashMap<String, Boolean>();
        if (config != null) {
            for (String str : config.split(",")) {
                if (StringUtils.isBlank(str)) {
                    continue;
                }
                String[] pair = str.split("=");
                if (pair.length == 2) {
                    String key = pair[0].trim();
                    String value = pair[1].trim();
                    if (StringUtils.isNotBlank(key) && StringUtils.isNotBlank(value)) {
                        parsed.put(key, Boolean.valueOf(value));
                    }
                }
            }
        }
        // Swap the reference only; clearing the old map could expose an empty view to a reader
        // that is still holding it.
        tokenSwitches = parsed;
    }

    /**
     * Returns whether token validation applies to a call. It applies only when token checking is
     * globally enabled; a {@code service#method} switch then takes precedence over a
     * {@code service} switch, and with no matching switch validation applies.
     */
    private static boolean needValidateToken(String serviceName, String methodName) {
        if (!isTokenEnable) {
            return false;
        }
        ConcurrentHashMap<String, Boolean> switches = tokenSwitches;
        if (switches.isEmpty()) {
            return true;
        }
        Boolean enable = switches.get(serviceName + "#" + methodName);
        if (enable == null && serviceName != null) {
            enable = switches.get(serviceName);
        }
        return enable == null || enable.booleanValue();
    }

    /**
     * Runs the IP, app and (when applicable) token checks, then delegates to {@code handler}.
     *
     * @throws SecurityException if any check fails
     */
    @Override
    public InvocationResponse invoke(ServiceInvocationHandler handler, ProviderContext invocationContext)
            throws Throwable {
        String remoteAddress = invocationContext.getChannel().getRemoteAddress();
        authenticateRequestIp(remoteAddress);

        String requestApp = (String) ContextUtils.getLocalContext("RequestApp");
        if (StringUtils.isBlank(requestApp)) {
            requestApp = (String) ContextUtils.getLocalContext(Constants.CONTEXT_KEY_CLIENT_APP);
        }
        authenticateRequestApp(requestApp);

        InvocationRequest request = invocationContext.getRequest();
        if (needValidateToken(request.getServiceName(), request.getMethodName())) {
            invocationContext.getTimeline().add(new TimePoint(TimePhase.A));
            if (request.getMessageType() == Constants.MESSAGE_TYPE_SERVICE) {
                // The token is checked only when no "RequestIp" local-context entry is present,
                // and never on the default protocol unless token checking is enabled for it.
                boolean isAuth = ContextUtils.getLocalContext("RequestIp") == null;
                if (!isTokenEnableForDefaultProtocol
                        && Constants.PROTOCOL_DEFAULT.equals(invocationContext.getChannel().getProtocol())) {
                    isAuth = false;
                }
                if (isAuth) {
                    authenticateRequestToken(request, invocationContext);
                }
            }
        }
        return handler.handle(invocationContext);
    }

    /**
     * Extracts token, timestamp and version from the request, using the local context for a
     * {@link UnifiedRequest} and the request values otherwise, and validates them.
     */
    private void authenticateRequestToken(InvocationRequest request, ProviderContext invocationContext) {
        String remoteAddress = invocationContext.getChannel().getRemoteAddress();
        String token = null;
        String timestamp = null;
        String version = null;
        if (request instanceof UnifiedRequest) {
            Map<String, String> localContext = ((UnifiedRequest) request).getLocalContext();
            if (localContext != null) {
                token = localContext.get(Constants.REQUEST_KEY_TOKEN);
                timestamp = localContext.get(Constants.REQUEST_KEY_TIMESTAMP);
                version = localContext.get(Constants.REQUEST_KEY_VERSION);
            }
        } else {
            Map<String, Serializable> requestValues = request.getRequestValues();
            if (requestValues != null) {
                // Values come from the caller; stringify instead of casting/dereferencing so a
                // null or non-String entry is rejected by validation rather than throwing.
                token = asString(requestValues.get(Constants.REQUEST_KEY_TOKEN));
                timestamp = asString(requestValues.get(Constants.REQUEST_KEY_TIMESTAMP));
                version = asString(requestValues.get(Constants.REQUEST_KEY_VERSION));
            }
        }
        doAuthenticateRequestToken(request.getApp(), remoteAddress, timestamp, version, token, request.getServiceName(),
                request.getMethodName());
    }

    /** Returns {@code value.toString()}, or null when {@code value} is null. */
    private static String asString(Object value) {
        return value == null ? null : value.toString();
    }
}