/** * The contents of this file are subject to the license and copyright * detailed in the LICENSE file at the root of the source * tree and available online at * * https://github.com/keeps/roda */ package org.roda.wui.filter; import java.io.IOException; import java.util.ArrayList; import java.util.List; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import org.apache.commons.lang3.StringUtils; import org.jasig.cas.client.util.AbstractCasFilter; import org.jasig.cas.client.validation.Assertion; import org.roda.core.data.exceptions.AuthenticationDeniedException; import org.roda.core.data.exceptions.GenericException; import org.roda.core.data.v2.common.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * CAS authentication filter for API requests. * * @author Rui Castro <rui.castro@gmail.com> */ public class CasApiAuthFilter implements Filter { /** Logger. */ private static final Logger LOGGER = LoggerFactory.getLogger(CasApiAuthFilter.class); /** List of excluded URLs. */ private final List<String> exclusions = new ArrayList<>(); /** CAS client. */ private CasClient casClient; @Override public void init(final FilterConfig filterConfig) throws ServletException { this.casClient = new CasClient(filterConfig.getInitParameter("casServerUrlPrefix")); final String exclusionsParam = filterConfig.getInitParameter("exclusions"); if (StringUtils.isNotBlank(exclusionsParam)) { final String[] listOfExclusions = exclusionsParam.split(","); for (String exclusion : listOfExclusions) { this.exclusions.add(exclusion.trim()); } } } @Override public void doFilter(final ServletRequest servletRequest, final ServletResponse servletResponse, final FilterChain filterChain) throws IOException, ServletException { final HttpServletRequest request = (HttpServletRequest) servletRequest; final HttpServletResponse response = (HttpServletResponse) servletResponse; if (isRequestUrlExcluded(request)) { LOGGER.debug("Request is ignored."); filterChain.doFilter(request, response); return; } final HttpSession session = request.getSession(false); final Assertion assertion = session != null ? (Assertion) session.getAttribute(AbstractCasFilter.CONST_CAS_ASSERTION) : null; if (assertion != null) { filterChain.doFilter(request, response); return; } try { final String tgt = request.getHeader("TGT"); final Pair<String, String> credentials = new BasicAuthRequestWrapper(request).getCredentials(); if (StringUtils.isNotBlank(tgt)) { doFilterWithTGT(request, response, filterChain, tgt); } else if (credentials != null) { // TGT is blank. Try to use username and password doFilterWithCredentials(request, response, filterChain, credentials.getFirst(), credentials.getSecond()); } else { LOGGER.debug("No username and password"); response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "No credentials"); } } catch (final AuthenticationDeniedException e) { LOGGER.debug(e.getMessage(), e); response.sendError(HttpServletResponse.SC_UNAUTHORIZED, e.getMessage()); } catch (final GenericException e) { throw new ServletException(e.getMessage(), e); } } @Override public void destroy() { // do nothing } private void doFilterWithTGT(final HttpServletRequest request, final HttpServletResponse response, final FilterChain filterChain, final String tgt) throws GenericException, IOException, ServletException { final String serviceUrl = constructServiceUrl(request); final String st = this.casClient.getServiceTicket(tgt, serviceUrl); filterChain.doFilter(new ServiceTicketRequestWrapper(request, st), response); } private void doFilterWithCredentials(final HttpServletRequest request, final HttpServletResponse response, final FilterChain filterChain, final String username, final String password) throws GenericException, IOException, ServletException, AuthenticationDeniedException { final String tgt = this.casClient.getTicketGrantingTicket(username, password); doFilterWithTGT(request, response, filterChain, tgt); } private String constructServiceUrl(final HttpServletRequest request) { return String.format("%s?%s", request.getRequestURL(), request.getQueryString()); } /** * Is the requested path in the list of exclusions? * * @param request * the request. * * @return <code>true</code> if it is excluded and <code>false</code> * otherwise. */ private boolean isRequestUrlExcluded(final HttpServletRequest request) { for (String exclusion : this.exclusions) { if (request.getPathInfo().matches(exclusion)) { return true; } } return false; } /** * A {@link HttpServletRequestWrapper} that adds a <code>ticket</code> query * string parameter. * * @author Rui Castro <rui.castro@gmai.com> */ private class ServiceTicketRequestWrapper extends HttpServletRequestWrapper { /** * CAS service ticket. */ private final String serviceTicket; /** * The query string. */ private String queryString = null; /** * Constructor. * * @param request * the HTTP request. * @param serviceTicket * the CAS service ticket. */ ServiceTicketRequestWrapper(final HttpServletRequest request, final String serviceTicket) { super(request); setAttribute("ticket", serviceTicket); this.serviceTicket = serviceTicket; } @Override public String getQueryString() { if (this.queryString == null) { String qs = super.getQueryString(); if (StringUtils.isBlank(qs)) { qs = ""; } else { qs = qs + "&"; } this.queryString = String.format("%sticket=%s", qs, this.serviceTicket); } return this.queryString; } @Override public String getParameter(final String name) { if ("ticket".equals(name)) { return this.serviceTicket; } else { return super.getParameter(name); } } } }