/* * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER. * * Copyright (c) 2011-2013 Oracle and/or its affiliates. All rights reserved. * * The contents of this file are subject to the terms of either the GNU * General Public License Version 2 only ("GPL") or the Common Development * and Distribution License("CDDL") (collectively, the "License"). You * may not use this file except in compliance with the License. You can * obtain a copy of the License at * https://glassfish.dev.java.net/public/CDDL+GPL_1_1.html * or packager/legal/LICENSE.txt. See the License for the specific * language governing permissions and limitations under the License. * * When distributing the software, include this License Header Notice in each * file and include the License file at packager/legal/LICENSE.txt. * * GPL Classpath Exception: * Oracle designates this particular file as subject to the "Classpath" * exception as provided by Oracle in the GPL Version 2 section of the License * file that accompanied this code. * * Modifications: * If applicable, add the following below the License Header, with the fields * enclosed by brackets [] replaced by your own identifying information: * "Portions Copyright [year] [name of copyright owner]" * * Contributor(s): * If you wish your version of this file to be governed by only the CDDL or * only the GPL Version 2, indicate your decision by adding "[Contributor] * elects to include this software in this distribution under the [CDDL or GPL * Version 2] license." If you don't indicate a single choice of license, a * recipient has the option to distribute your version of this file under * either the CDDL, the GPL Version 2 or to extend the choice of license to * its licensees as provided above. However, if you add GPL Version 2 code * and therefore, elected the GPL Version 2 license, then the option applies * only if the new code is made subject to such option by the copyright * holder. * * * This file incorporates work covered by the following copyright and * permission notice: * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.catalina.filters; import org.apache.catalina.core.StandardServer; import org.glassfish.logging.annotation.LogMessageInfo; import java.io.IOException; import java.io.Serializable; import java.security.SecureRandom; import java.text.MessageFormat; import java.util.*; import java.util.logging.Logger; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpSession; /** * Provides basic CSRF protection for a web application. The filter assumes * that: * <ul> * <li>The filter is mapped to /*</li> * <li>{@link HttpServletResponse#encodeRedirectURL(String)} and * {@link HttpServletResponse#encodeURL(String)} are used to encode all URLs * returned to the client * </ul> */ public class CsrfPreventionFilter extends FilterBase { protected static final Logger log = StandardServer.log; @LogMessageInfo( message = "Unable to create Random source using class [{0}]", level = "WARNING" ) public static final String UNABLE_CREATE_RANDOM_SOURCE_EXCEPTION = "AS-WEB-CORE-00286"; private String randomClass = SecureRandom.class.getName(); private Random randomSource; private final Set<String> entryPoints = new HashSet<String>(); private int nonceCacheSize = 5; @Override protected Logger getLogger() { return log; } /** * Entry points are URLs that will not be tested for the presence of a valid * nonce. They are used to provide a way to navigate back to a protected * application after navigating away from it. Entry points will be limited * to HTTP GET requests and should not trigger any security sensitive * actions. * * @param entryPoints Comma separated list of URLs to be configured as * entry points. */ public void setEntryPoints(String entryPoints) { String values[] = entryPoints.split(","); for (String value : values) { this.entryPoints.add(value.trim()); } } /** * Sets the number of previously issued nonces that will be cached on a LRU * basis to support parallel requests, limited use of the refresh and back * in the browser and similar behaviors that may result in the submission * of a previous nonce rather than the current one. If not set, the default * value of 5 will be used. * * @param nonceCacheSize The number of nonces to cache */ public void setNonceCacheSize(int nonceCacheSize) { this.nonceCacheSize = nonceCacheSize; } /** * Specify the class to use to generate the nonces. Must be in instance of * {@link Random}. * * @param randomClass The name of the class to use */ public void setRandomClass(String randomClass) { this.randomClass = randomClass; } @Override public void init(FilterConfig filterConfig) throws ServletException { // Set the parameters super.init(filterConfig); String msg = MessageFormat.format(rb.getString(UNABLE_CREATE_RANDOM_SOURCE_EXCEPTION), randomClass); try { Class<?> clazz = Class.forName(randomClass); randomSource = (Random) clazz.newInstance(); } catch (ClassNotFoundException e) { ServletException se = new ServletException(msg, e); throw se; } catch (InstantiationException e) { ServletException se = new ServletException(msg, e); throw se; } catch (IllegalAccessException e) { ServletException se = new ServletException(msg, e); throw se; } } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { ServletResponse wResponse = null; if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) { HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse res = (HttpServletResponse) response; boolean skipNonceCheck = false; if (Constants.METHOD_GET.equals(req.getMethod())) { String path = req.getServletPath(); if (req.getPathInfo() != null) { path = path + req.getPathInfo(); } if (entryPoints.contains(path)) { skipNonceCheck = true; } } HttpSession session = req.getSession(false); @SuppressWarnings("unchecked") LruCache<String> nonceCache = (session == null) ? null : (LruCache<String>) session.getAttribute( Constants.CSRF_NONCE_SESSION_ATTR_NAME); if (!skipNonceCheck) { String previousNonce = req.getParameter(Constants.CSRF_NONCE_REQUEST_PARAM); if (nonceCache == null || previousNonce == null || !nonceCache.contains(previousNonce)) { res.sendError(HttpServletResponse.SC_FORBIDDEN); return; } } if (nonceCache == null) { nonceCache = new LruCache<String>(nonceCacheSize); if (session == null) { session = req.getSession(true); } session.setAttribute( Constants.CSRF_NONCE_SESSION_ATTR_NAME, nonceCache); } String newNonce = generateNonce(); nonceCache.add(newNonce); wResponse = new CsrfResponseWrapper(res, newNonce); } else { wResponse = response; } chain.doFilter(request, wResponse); } @Override protected boolean isConfigProblemFatal() { return true; } /** * Generate a once time token (nonce) for authenticating subsequent * requests. This will also add the token to the session. The nonce * generation is a simplified version of ManagerBase.generateSessionId(). * */ protected String generateNonce() { byte random[] = new byte[16]; // Render the result as a String of hexadecimal digits StringBuilder buffer = new StringBuilder(); randomSource.nextBytes(random); for (int j = 0; j < random.length; j++) { byte b1 = (byte) ((random[j] & 0xf0) >> 4); byte b2 = (byte) (random[j] & 0x0f); if (b1 < 10) { buffer.append((char) ('0' + b1)); } else { buffer.append((char) ('A' + (b1 - 10))); } if (b2 < 10) { buffer.append((char) ('0' + b2)); } else { buffer.append((char) ('A' + (b2 - 10))); } } return buffer.toString(); } protected static class CsrfResponseWrapper extends HttpServletResponseWrapper { private final String nonce; public CsrfResponseWrapper(HttpServletResponse response, String nonce) { super(response); this.nonce = nonce; } @Override @Deprecated public String encodeRedirectUrl(String url) { return encodeRedirectURL(url); } @Override public String encodeRedirectURL(String url) { return addNonce(super.encodeRedirectURL(url)); } @Override @Deprecated public String encodeUrl(String url) { return encodeURL(url); } @Override public String encodeURL(String url) { return addNonce(super.encodeURL(url)); } /** * Return the specified URL with the nonce added to the query string. * * @param url URL to be modified */ private String addNonce(String url) { if ((url == null) || (nonce == null)) { return (url); } String path = url; String query = ""; String anchor = ""; int pound = path.indexOf('#'); if (pound >= 0) { anchor = path.substring(pound); path = path.substring(0, pound); } int question = path.indexOf('?'); if (question >= 0) { query = path.substring(question); path = path.substring(0, question); } StringBuilder sb = new StringBuilder(path); if (query.length() >0) { sb.append(query); sb.append('&'); } else { sb.append('?'); } sb.append(Constants.CSRF_NONCE_REQUEST_PARAM); sb.append('='); sb.append(nonce); sb.append(anchor); return (sb.toString()); } } protected static class LruCache<T> implements Serializable { private static final long serialVersionUID = 1L; // Although the internal implementation uses a Map, this cache // implementation is only concerned with the keys. private final Map<T,T> cache; public LruCache(final int cacheSize) { cache = new FixSizeLinkedHashMap<T,T>(cacheSize); } public void add(T key) { synchronized (cache) { cache.put(key, null); } } public boolean contains(T key) { synchronized (cache) { return cache.containsKey(key); } } } private static class FixSizeLinkedHashMap<K, V> extends LinkedHashMap<K, V> { private static final long serialVersionUID = 1L; private int cacheSize; private FixSizeLinkedHashMap(int cs) { super(); cacheSize = cs; } @Override protected boolean removeEldestEntry(Map.Entry<K,V> eldest) { if (size() > cacheSize) { return true; } return false; } } }