/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.turbine.pipeline.valve;
import static com.alibaba.citrus.springext.util.SpringExtUtil.*;
import static com.alibaba.citrus.turbine.util.TurbineUtil.*;
import static com.alibaba.citrus.util.ObjectUtil.*;
import static com.alibaba.citrus.util.StringUtil.*;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import com.alibaba.citrus.logconfig.support.SecurityLogger;
import com.alibaba.citrus.service.pipeline.PipelineContext;
import com.alibaba.citrus.service.pipeline.support.AbstractValve;
import com.alibaba.citrus.service.pipeline.support.AbstractValveDefinitionParser;
import com.alibaba.citrus.turbine.TurbineRunData;
import com.alibaba.citrus.turbine.util.CsrfToken;
import com.alibaba.citrus.turbine.util.CsrfTokenCheckException;
import com.alibaba.citrus.util.StringUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.xml.ParserContext;
import org.w3c.dom.Element;
/**
 * Valve that checks the <code>CsrfToken</code> submitted with a request, in order to prevent CSRF
 * attacks and repeated submission of the same form.
 * <p>
 * The token is read from the request parameter named by {@link #getTokenKey()}. If the request does
 * not carry that parameter (or it is blank), this valve performs no check. Otherwise the token is
 * accepted when it equals the session's long-live token, or when it is one of the tokens returned by
 * <code>CsrfToken.getTokensInSession</code>. In the latter case the token is removed from that list
 * and written back, so the same token cannot be accepted twice. On a mismatch, the request is either
 * redirected to {@link #getExpiredPage()} or, when no expired page is configured, rejected with a
 * {@link CsrfTokenCheckException}.
 * </p>
 *
 * @author Michael Zhou
 */
public class CheckCsrfTokenValve extends AbstractValve {
    /** Logger used to report token mismatches; its name is configurable via {@link #setLogName(String)}. */
    private final SecurityLogger log = new SecurityLogger();

    @Autowired
    private HttpServletRequest request;

    /** Request parameter name of the token; defaults to <code>CsrfToken.DEFAULT_TOKEN_KEY</code> in {@link #init()}. */
    private String tokenKey;

    /** Passed as-is to <code>CsrfToken.setContextTokenConfiguration</code>; its meaning is defined by CsrfToken. */
    private int maxTokens;

    /** Redirect target used on a token mismatch; <code>null</code> means "throw an exception instead". */
    private String expiredPage;

    public String getTokenKey() {
        return tokenKey;
    }

    /** Sets the token parameter name; blank values are normalized to <code>null</code> (i.e. use the default). */
    public void setTokenKey(String tokenKey) {
        this.tokenKey = trimToNull(tokenKey);
    }

    public int getMaxTokens() {
        return maxTokens;
    }

    public void setMaxTokens(int maxTokens) {
        this.maxTokens = maxTokens;
    }

    public String getExpiredPage() {
        return expiredPage;
    }

    /**
     * Sets the page to redirect to when the token does not match.
     * <p>
     * Blank values are normalized to <code>null</code>, consistent with {@link #setTokenKey(String)}.
     * Otherwise an empty/whitespace page would count as "configured" and a blank redirect target
     * would be set instead of failing with {@link CsrfTokenCheckException}.
     * </p>
     */
    public void setExpiredPage(String expiredPage) {
        this.expiredPage = trimToNull(expiredPage);
    }

    public String getLogName() {
        return this.log.getLogger().getName();
    }

    public void setLogName(String logName) {
        this.log.setLogName(logName);
    }

    @Override
    protected void init() {
        tokenKey = defaultIfNull(tokenKey, CsrfToken.DEFAULT_TOKEN_KEY);
    }

    /**
     * Checks the request's CSRF token, then invokes the rest of the pipeline with the token
     * configuration (<code>tokenKey</code>, <code>maxTokens</code>) stored in the thread context.
     * The configuration is always reset afterwards, even if a later valve throws.
     * <p>
     * On a mismatch with an expired page configured, this method only sets the redirect target and
     * still calls <code>invokeNext()</code>. NOTE(review): it presumably relies on downstream valves
     * honoring the redirect; confirm against the pipeline configuration.
     * </p>
     *
     * @throws CsrfTokenCheckException if the token does not match and no expired page is configured
     */
    public void invoke(PipelineContext pipelineContext) throws Exception {
        TurbineRunData rundata = getTurbineRunData(request);

        checkRequestToken(rundata);

        try {
            // Store the current token configuration in the thread context so that every other
            // CsrfToken check during this request uses the same key.
            CsrfToken.setContextTokenConfiguration(tokenKey, maxTokens);
            pipelineContext.invokeNext();
        } finally {
            CsrfToken.resetContextTokenConfiguration();
        }
    }

    /** Validates the token carried by the request, if any; requests without a token are not checked. */
    private void checkRequestToken(TurbineRunData rundata) {
        String tokenFromRequest = StringUtil.trimToNull(rundata.getParameters().getString(tokenKey));

        if (tokenFromRequest == null) {
            return;
        }

        HttpSession session = rundata.getRequest().getSession();

        // A match against the long-live token makes the unique-token check unnecessary.
        if (tokenFromRequest.equals(CsrfToken.getLongLiveTokenInSession(session))) {
            return;
        }

        List<String> tokensInSession = CsrfToken.getTokensInSession(session, tokenKey);

        if (tokensInSession.contains(tokenFromRequest)) {
            // Matched: remove the token from the session so it cannot be used again.
            tokensInSession.remove(tokenFromRequest);
            CsrfToken.setTokensInSession(session, tokenKey, tokensInSession);
        } else {
            // Mismatch: abort the request.
            requestExpired(rundata, tokenFromRequest, tokensInSession);
        }
    }

    /**
     * Handles a token mismatch: logs a warning, then either redirects to the expired page or throws.
     *
     * @throws CsrfTokenCheckException when no expired page is configured
     */
    private void requestExpired(TurbineRunData rundata, String tokenFromRequest, List<String> tokensInSession) {
        log.getLogger().warn("CsrfToken \"{}\" does not match: requested token is {}, but the session tokens are {}.",
                             new Object[] { tokenKey, tokenFromRequest, tokensInSession });

        // Two ways to handle a mismatch: 1. show the expired page; 2. throw an exception.
        if (expiredPage != null) {
            rundata.setRedirectTarget(expiredPage);
        } else {
            throw new CsrfTokenCheckException(rundata.getRequest().getRequestURL().toString());
        }
    }

    /** Maps the valve's XML attributes onto the corresponding bean properties. */
    public static class DefinitionParser extends AbstractValveDefinitionParser<CheckCsrfTokenValve> {
        @Override
        protected void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder) {
            attributesToProperties(element, builder, "tokenKey", "maxTokens", "expiredPage", "logName");
        }
    }
}