/* * Copyright (c) 2002-2012 Alibaba Group Holding Limited. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.alibaba.citrus.turbine.util; import static com.alibaba.citrus.test.TestUtil.*; import static org.easymock.EasyMock.*; import static org.hamcrest.Matchers.*; import static org.hamcrest.Matchers.not; import static org.junit.Assert.*; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; import com.alibaba.citrus.service.requestcontext.util.RequestContextUtil; import com.alibaba.citrus.turbine.util.CsrfToken.DefaultGenerator; import com.alibaba.citrus.util.StringUtil; import com.meterware.httpunit.WebRequest; import org.junit.Before; import org.junit.Test; public class CsrfTokenTests extends AbstractPullToolTests<CsrfToken> { private HttpServletRequest request; private HttpSession session; @Before public void initMock() { request = createMock(HttpServletRequest.class); session = createMock(HttpSession.class); expect(request.getSession()).andReturn(session).anyTimes(); expect(session.getId()).andReturn("aaa").anyTimes(); expect(session.getCreationTime()).andReturn(1234L).anyTimes(); } @Override protected String toolName() { return "csrfToken"; } @Test public void checkScope() throws Exception { assertSame(tool, getTool()); // global scope } @Test public void getConfiguration() throws Exception { tool = new CsrfToken(newRequest); // default values assertEquals(CsrfToken.DEFAULT_TOKEN_KEY, CsrfToken.getKey()); assertEquals(CsrfToken.DEFAULT_MAX_TOKENS, CsrfToken.getMaxTokens()); // thread context key CsrfToken.setContextTokenConfiguration(" testKey ", -1); assertEquals("testKey", CsrfToken.getKey()); assertEquals(CsrfToken.DEFAULT_MAX_TOKENS, CsrfToken.getMaxTokens()); CsrfToken.setContextTokenConfiguration(" testKey ", 2); assertEquals("testKey", CsrfToken.getKey()); assertEquals(2, CsrfToken.getMaxTokens()); // reset CsrfToken.resetContextTokenConfiguration(); assertEquals(CsrfToken.DEFAULT_TOKEN_KEY, CsrfToken.getKey()); assertEquals(CsrfToken.DEFAULT_MAX_TOKENS, CsrfToken.getMaxTokens()); } @Test public void generateLongLiveToken() throws InterruptedException { replay(session); DefaultGenerator g1 = new DefaultGenerator(); DefaultGenerator g2 = new DefaultGenerator(); String token1 = g1.generateLongLiveToken(session); String token2 = g2.generateLongLiveToken(session); assertEquals(token1, token2); assertNotNull(token1); } @Test public void getLongLiveTokenInSession() { replay(session); try { CsrfToken.getLongLiveTokenInSession(null); fail(); } catch (IllegalArgumentException e) { assertThat(e, exception("session")); } String token = CsrfToken.getLongLiveTokenInSession(session); assertNotNull(token); assertTrue(StringUtil.containsOnly(token, "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz".toCharArray())); } @Test @Deprecated public void getLongLiveToken() throws Exception { // ----------------------- // 请求1,取得token String token = tool.getLongLiveToken(); assertNotNull(token); assertThat(tool.getHiddenField(true).toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getLongLiveHiddenField().toString(), containsString(token)); assertEquals("_csrf_token", CsrfToken.getKey()); // 同一个请求,再次取得token assertEquals(token, tool.getLongLiveToken()); assertEquals(null, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); // ----------------------- // 请求2,再次取得token getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); assertEquals(token, tool.getLongLiveToken()); assertThat(tool.getHiddenField(true).toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getLongLiveHiddenField().toString(), containsString(token)); assertEquals("_csrf_token", CsrfToken.getKey()); // 和unique token混用 String token2 = tool.getUniqueToken(); assertNotNull(token2); assertThat(tool.getHiddenField(false).toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getUniqueHiddenField().toString(), containsString(token2)); assertEquals("_csrf_token", CsrfToken.getKey()); assertEquals(token2, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); // ----------------------- // 请求3,取得token getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); assertEquals(token, tool.getLongLiveToken()); assertThat(tool.getHiddenField(true).toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getLongLiveHiddenField().toString(), containsString(token)); assertEquals("_csrf_token", CsrfToken.getKey()); assertEquals(token2, newRequest.getSession().getAttribute("_csrf_token")); // 让session过期,但保持sessionId final String sessionId = newRequest.getSession().getId(); newRequest.getSession().invalidate(); commitRequestContext(); // ----------------------- // 请求4,再次取得token assertEquals("", client.getCookieValue("JSESSIONID")); assertTrue(client.getCookieDetails("JSESSIONID").isExpired()); Thread.sleep(10); // 确保创建时间改变 getInvocationContext("http://localhost/app1/1.html", new WebRequestCallback() { public void process(WebRequest wr) { wr.setHeaderField("Cookie", "JSESSIONID=" + sessionId); } }); initRequestContext(); assertTrue(newRequest.getSession().isNew()); // 新session assertEquals(sessionId, newRequest.getSession().getId()); // id不变 String token3 = tool.getLongLiveToken(); assertFalse(token.equals(token3)); // token由id和时间共同生成,因此即使id不变,token也改变 assertNotNull(token3); assertThat(tool.getHiddenField(true).toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getLongLiveHiddenField().toString(), containsString(token3)); assertEquals("_csrf_token", CsrfToken.getKey()); commitRequestContext(); } @Test public void getUniqueToken() throws Exception { // ----------------------- // 请求1,取得token String token = tool.getUniqueToken(); assertNotNull(token); assertThat(tool.getUniqueHiddenField().toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getUniqueHiddenField().toString(), containsString(token)); assertEquals("_csrf_token", CsrfToken.getKey()); // 同一个请求,再次取得token assertEquals(token, tool.getUniqueToken()); assertEquals(token, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); // ----------------------- // 请求2,再次取得token getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); String token2 = tool.getUniqueToken(); assertNotNull(token2); assertThat(tool.getUniqueHiddenField().toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getUniqueHiddenField().toString(), containsString(token2)); assertEquals("_csrf_token", CsrfToken.getKey()); // 同一个请求,再次取得token assertEquals(token2, tool.getUniqueToken()); assertEquals(token + "/" + token2, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); // ----------------------- // 请求3-8,取得token String tokens = token + "/" + token2; for (int i = 2; i < 8; i++) { getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); String token_i = tool.getUniqueToken(); assertNotNull(token_i); assertThat(tool.getUniqueHiddenField().toString(), containsString("<input name='_csrf_token' type='hidden' value='")); assertThat(tool.getUniqueHiddenField().toString(), containsString(token_i)); assertEquals("_csrf_token", CsrfToken.getKey()); // 同一个请求,再次取得token assertEquals(token_i, tool.getUniqueToken()); assertEquals(tokens += "/" + token_i, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); } // ----------------------- // 请求9,取得token,抛弃第一个token getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); String token_9 = tool.getUniqueToken(); assertEquals(token_9, tool.getUniqueToken()); tokens += "/" + token_9; tokens = tokens.substring(tokens.indexOf("/") + 1); assertEquals(tokens, newRequest.getSession().getAttribute("_csrf_token")); commitRequestContext(); // ----------------------- // 请求10,取得token,设置maxTokens=3 getInvocationContext("http://localhost/app1/1.html"); initRequestContext(); CsrfToken.setContextTokenConfiguration(null, 3); String token_10 = tool.getUniqueToken(); assertEquals(token_10, tool.getUniqueToken()); tokens += "/" + token_10; tokens = tokens.substring(indexOf(tokens, "/", 6) + 1); assertEquals(tokens, newRequest.getSession().getAttribute("_csrf_token")); CsrfToken.resetContextTokenConfiguration(); commitRequestContext(); } private int indexOf(String str, String strToFind, int count) { int index = -1; for (int i = 0; i < count; i++) { index = str.indexOf(strToFind, index + 1); } return index; } @Test public void check_defaultKey_succ() { expect(request.getParameter("_csrf_token")).andReturn("any"); replay(request); assertTrue(CsrfToken.check(request)); verify(request); } @Test public void check_defaultKey_failed() { expect(request.getParameter("_csrf_token")).andReturn(null); replay(request); assertFalse(CsrfToken.check(request)); verify(request); } @Test public void check_contextKey_succ() { CsrfToken.setContextTokenConfiguration("contextKey", -1); expect(request.getParameter("contextKey")).andReturn("any"); replay(request); assertTrue(CsrfToken.check(request)); CsrfToken.resetContextTokenConfiguration(); verify(request); } @Test public void check_contextKey_failed() { CsrfToken.setContextTokenConfiguration("contextKey", -1); expect(request.getParameter("contextKey")).andReturn(null); replay(request); assertFalse(CsrfToken.check(request)); CsrfToken.resetContextTokenConfiguration(); verify(request); } @Test public void toString_() { // in request String token = tool.getUniqueToken(); assertNotNull(token); assertThat(tool.toString(), not(equalTo("<No thread-bound request>"))); // not in request requestContexts.commitRequestContext(RequestContextUtil.getRequestContext(newRequest)); assertEquals("<No thread-bound request>", tool.toString()); } }