/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.turbine.util;
import static com.alibaba.citrus.util.Assert.*;
import static com.alibaba.citrus.util.CollectionUtil.*;
import static com.alibaba.citrus.util.StringUtil.*;

import java.security.SecureRandom;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;

import com.alibaba.citrus.service.pull.ToolFactory;
import com.alibaba.citrus.springext.support.parser.AbstractSingleBeanDefinitionParser;
import com.alibaba.citrus.util.ClassLoaderUtil;
import com.alibaba.citrus.util.ServiceNotFoundException;
import com.alibaba.citrus.util.StringUtil;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.ecs.html.Input;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
/**
 * Generates unique tokens used to defend against CSRF (Cross Site Request Forgery) attacks. The same
 * tokens can also be used to prevent a form from being submitted more than once.
 * <p>
 * This class can be used as a pull tool. Because it works through a singleton request proxy, it may
 * be registered as a pull tool with global scope.
 * </p>
 * <p>
 * The CSRF token key is determined as follows:
 * </p>
 * <ol>
 * <li>if a key was explicitly set for the current thread via
 * <code>setContextTokenConfiguration()</code>, that key is used;</li>
 * <li>otherwise the default value "<code>_csrf_token</code>" is used.</li>
 * </ol>
 *
 * @author Michael Zhou
 */
public class CsrfToken {
    /** Default name of the request parameter / session attribute holding the token. */
    public static final String DEFAULT_TOKEN_KEY = "_csrf_token";

    /** Default maximum number of unique tokens kept in a session. */
    public static final int DEFAULT_MAX_TOKENS = 8;

    /** Separator used to join the unique tokens stored in a single session attribute. */
    public static final String CSRF_TOKEN_SEPARATOR = "/";

    // Declared before generatorOverride on purpose: static initializers run in textual order, and
    // getGeneratorOverride() may need the logger while the class is being initialized.
    private static final Logger log = LoggerFactory.getLogger(CsrfToken.class);

    private static final AtomicInteger counter = new AtomicInteger();

    private static final ThreadLocal<Configuration> contextTokenConfigurationHolder = new ThreadLocal<Configuration>();

    private final HttpServletRequest request;

    /**
     * @param request the request (typically a request proxy) used to access the session; must not be
     *                <code>null</code>
     */
    public CsrfToken(HttpServletRequest request) {
        this.request = assertNotNull(request, "request");
    }

    /**
     * Returns the token key for the current thread: the key set via
     * {@link #setContextTokenConfiguration(String, int)} if any, otherwise {@link #DEFAULT_TOKEN_KEY}.
     */
    public static String getKey() {
        String key = null;
        Configuration conf = contextTokenConfigurationHolder.get();

        if (conf != null) {
            key = conf.getTokenKey();
        }

        if (key == null) {
            key = DEFAULT_TOKEN_KEY;
        }

        return key;
    }

    /**
     * Returns the maximum number of unique tokens kept in the session for the current thread: the
     * positive value set via {@link #setContextTokenConfiguration(String, int)} if any, otherwise
     * {@link #DEFAULT_MAX_TOKENS}.
     */
    public static int getMaxTokens() {
        int maxTokens = -1;
        Configuration conf = contextTokenConfigurationHolder.get();

        if (conf != null) {
            maxTokens = conf.getMaxTokens();
        }

        if (maxTokens <= 0) {
            maxTokens = DEFAULT_MAX_TOKENS;
        }

        return maxTokens;
    }

    /**
     * Binds a token key and max-token count to the current thread. Must be paired with
     * {@link #resetContextTokenConfiguration()}.
     *
     * @param tokenKey  token key; blank/<code>null</code> falls back to the default key
     * @param maxTokens max unique tokens per session; non-positive falls back to the default
     */
    public static void setContextTokenConfiguration(String tokenKey, int maxTokens) {
        contextTokenConfigurationHolder.set(new Configuration(tokenKey, maxTokens));
    }

    /** Removes the configuration bound to the current thread. */
    public static void resetContextTokenConfiguration() {
        contextTokenConfigurationHolder.remove();
    }

    /** Creates a hidden field containing a CSRF token. The token stays valid until the session expires. */
    public Input getHiddenField() {
        return getLongLiveHiddenField();
    }

    /**
     * Creates a hidden field containing a CSRF token.
     *
     * @param longLiveToken if <code>true</code>, the token stays valid until the session expires;
     *                      otherwise it can be used only once.
     * @deprecated use getUniqueHiddenField() or getLongLiveHiddenField()
     *             instead
     */
    @Deprecated
    public Input getHiddenField(boolean longLiveToken) {
        return longLiveToken ? getLongLiveHiddenField() : getUniqueHiddenField();
    }

    /** Creates a hidden field containing a single-use token (see {@link #getUniqueToken()}). */
    public Input getUniqueHiddenField() {
        return new Input("hidden", getKey(), getUniqueToken());
    }

    /** Creates a hidden field containing the session-long token (see {@link #getLongLiveToken()}). */
    public Input getLongLiveHiddenField() {
        return new Input("hidden", getKey(), getLongLiveToken());
    }

    /**
     * Creates a CSRF token that can be used only once. Within one request the same token is returned
     * on every call (it is cached as a request attribute).
     */
    public String getUniqueToken() {
        HttpSession session = request.getSession();
        String key = getKey();
        String tokenOfRequest = (String) request.getAttribute(key);
        int maxTokens = getMaxTokens();

        if (tokenOfRequest == null) {
            // Create a new token and append it to the tokens already stored in the session.
            // If the number of tokens exceeds the maximum, the oldest ones are discarded.
            LinkedList<String> tokens = getTokensInSession(session, key);

            tokenOfRequest = getGenerator().generateUniqueToken();

            request.setAttribute(key, tokenOfRequest);

            tokens.addLast(tokenOfRequest);

            while (tokens.size() > maxTokens) {
                tokens.removeFirst();
            }

            setTokensInSession(session, key, tokens);
        }

        return tokenOfRequest;
    }

    /**
     * Returns the long-live token. Unlike the unique token, the long-live token lives as long as the
     * session does.
     */
    public String getLongLiveToken() {
        return getLongLiveTokenInSession(request.getSession());
    }

    /**
     * Reads the unique tokens stored in the session under <code>tokenKey</code>. The tokens are stored
     * as a single string joined by {@link #CSRF_TOKEN_SEPARATOR}.
     */
    public static LinkedList<String> getTokensInSession(HttpSession session, String tokenKey) {
        return createLinkedList(StringUtil.split((String) session.getAttribute(tokenKey), CSRF_TOKEN_SEPARATOR));
    }

    /**
     * Stores the unique tokens in the session under <code>tokenKey</code>, joined by
     * {@link #CSRF_TOKEN_SEPARATOR}. An empty list removes the attribute.
     */
    public static void setTokensInSession(HttpSession session, String tokenKey, List<String> tokens) {
        if (tokens.isEmpty()) {
            session.removeAttribute(tokenKey);
        } else {
            session.setAttribute(tokenKey, StringUtil.join(tokens, CSRF_TOKEN_SEPARATOR));
        }
    }

    /** Computes the long-live token for the given session using the active {@link Generator}. */
    public static String getLongLiveTokenInSession(HttpSession session) {
        return getGenerator().generateLongLiveToken(session);
    }

    /**
     * Returns a unique token, or a placeholder if no request is bound to the current thread (the
     * request proxy throws <code>IllegalStateException</code> in that case).
     */
    @Override
    public String toString() {
        try {
            return getUniqueToken();
        } catch (IllegalStateException e) {
            return "<No thread-bound request>";
        }
    }

    /**
     * Checks the token submitted with the request against the tokens of the current session.
     * <p>
     * The check succeeds only if the request parameter named by {@link #getKey()} equals either the
     * session's long-live token or one of the unique tokens stored in the session. A matching unique
     * token is removed from the session so the same form cannot be submitted twice.
     * </p>
     *
     * @param request the current request
     * @return <code>true</code> if the submitted token is valid for the current session
     */
    public static boolean check(HttpServletRequest request) {
        String key = getKey();
        String fromRequest = trimToNull(request.getParameter(key));

        if (fromRequest == null) {
            return false;
        }

        // Do not create a session merely to validate: without a session no token can be valid.
        HttpSession session = request.getSession(false);

        if (session == null) {
            return false;
        }

        if (constantTimeEquals(fromRequest, getLongLiveTokenInSession(session))) {
            return true;
        }

        LinkedList<String> tokens = getTokensInSession(session, key);

        for (Iterator<String> i = tokens.iterator(); i.hasNext(); ) {
            if (constantTimeEquals(fromRequest, i.next())) {
                // Unique tokens are single-use: consume it.
                i.remove();
                setTokensInSession(session, key, tokens);
                return true;
            }
        }

        return false;
    }

    /** Compares two strings in time independent of where they first differ; nulls are only equal to nulls. */
    private static boolean constantTimeEquals(String a, String b) {
        if (a == null || b == null) {
            return a == b;
        }

        if (a.length() != b.length()) {
            return false;
        }

        int diff = 0;

        for (int i = 0; i < a.length(); i++) {
            diff |= a.charAt(i) ^ b.charAt(i);
        }

        return diff == 0;
    }

    /** Per-thread token configuration; a blank key is normalized to <code>null</code>. */
    private static class Configuration {
        private final String tokenKey;
        private final int maxTokens;

        public Configuration(String tokenKey, int maxTokens) {
            this.tokenKey = trimToNull(tokenKey);
            this.maxTokens = maxTokens;
        }

        public String getTokenKey() {
            return tokenKey;
        }

        public int getMaxTokens() {
            return maxTokens;
        }
    }

    /** Bean definition parser for {@link Factory}. */
    public static class DefinitionParser extends AbstractSingleBeanDefinitionParser<Factory> {
    }

    /** Pull tool factory. */
    public static class Factory implements ToolFactory {
        private HttpServletRequest request;

        @Autowired
        public void setRequest(HttpServletRequest request) {
            this.request = request;
        }

        public boolean isSingleton() {
            return true;
        }

        public Object createTool() throws Exception {
            return new CsrfToken(request);
        }
    }

    private static final Generator generator = new DefaultGenerator();
    private static final Generator generatorOverride = getGeneratorOverride();

    /**
     * Loads an optional {@link Generator} registered under the service name
     * <code>csrfTokenGeneratorOverride</code>. Returns <code>null</code> if none is registered or it
     * cannot be created (the latter is logged).
     */
    private static Generator getGeneratorOverride() {
        try {
            return Generator.class.cast(ClassLoaderUtil.newServiceInstance("csrfTokenGeneratorOverride",
                                                                           CsrfToken.class));
        } catch (ServiceNotFoundException e) {
            // No override registered: use the default generator.
        } catch (Exception e) {
            log.warn("Failure in CsrfToken.getGeneratorOverride()", e);
        }

        return null;
    }

    private static Generator getGenerator() {
        return generatorOverride != null ? generatorOverride : generator;
    }

    /** Allows other modules to override the token-generation algorithm. */
    public interface Generator {
        /** Generates a new single-use token. */
        String generateUniqueToken();

        /** Generates the long-live token for the given session. */
        String generateLongLiveToken(HttpSession session);
    }

    /** Default {@link Generator}. */
    static class DefaultGenerator implements Generator {
        // A CSPRNG makes tokens unguessable; a timestamp plus one java.util.Random seed would let a
        // client who has seen one token predict tokens issued to other sessions.
        private final SecureRandom random = new SecureRandom();

        public String generateUniqueToken() {
            // The counter keeps tokens distinct; the random part makes them unpredictable.
            return longToString(counter.getAndIncrement()) + longToString(random.nextLong());
        }

        public String generateLongLiveToken(HttpSession session) {
            String sessionId = assertNotNull(session, "session").getId();
            byte[] digest = DigestUtils.md5(session.getCreationTime() + sessionId);

            return StringUtil.bytesToString(digest);
        }
    }
}