package org.sigmah.server.servlet.base; /* * #%L * Sigmah * %% * Copyright (C) 2010 - 2016 URD * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-3.0.html>. * #L% */ import java.io.IOException; import java.io.InputStream; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.persistence.EntityManager; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.validation.ConstraintViolationException; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.time.StopWatch; import org.sigmah.client.page.RequestParameter; import org.sigmah.client.util.ClientUtils; import org.sigmah.server.conf.Properties; import org.sigmah.server.domain.User; import org.sigmah.server.domain.util.DomainFilters; import org.sigmah.server.inject.ServletModule; import org.sigmah.server.mapper.Mapper; import org.sigmah.server.security.SecureSessionValidator; import org.sigmah.server.security.SecureSessionValidator.Access; import org.sigmah.server.servlet.util.Servlets; import org.sigmah.shared.conf.PropertyKey; import org.sigmah.shared.security.InvalidSessionException; import org.sigmah.shared.security.UnauthorizedAccessException; import org.sigmah.shared.servlet.ServletConstants; import org.sigmah.shared.servlet.ServletConstants.Servlet; import org.sigmah.shared.servlet.ServletConstants.ServletMethod; import org.sigmah.shared.util.FileType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.gwt.http.client.Response; import com.google.inject.Inject; import com.google.inject.Provider; /** * <p> * Abstract additional servlet which secures access to its methods. * All additional <em>secured</em> servlet should inherit this abstract layer. * </p> * The declared child servlet methods must have this signature: * * <pre> * [method_name] (HttpServletRequest request, HttpServletResponse response, ServletExecutionContext context) throws [any exception(s)]; * </pre> * * @author Denis Colliot (dcolliot@ideia.fr) */ public abstract class AbstractServlet extends HttpServlet { /** * Serial id. */ private static final long serialVersionUID = 301456647415093255L; /** * Logger. */ private static final Logger LOG = LoggerFactory.getLogger(AbstractServlet.class); /** * HTML servlet error page filename. */ private static final String ERROR_PAGE_NAME = "servlet-error.html"; /** * Injected secure session validator. */ @Inject private SecureSessionValidator secureSessionValidator; /** * Injected application properties service. */ @Inject private Properties properties; /** * Injected {@link EntityManager} provider. */ @Inject private Provider<EntityManager> entityManagerProvider; /** * Injected {@link Mapper}. */ @Inject private Mapper mapper; /** * HTML error page template. */ private String template; /** * {@inheritDoc} */ @Override public final void init(final ServletConfig config) throws ServletException { if (LOG.isDebugEnabled()) { LOG.debug("Reading HTML error page template."); } try (final InputStream is = getClass().getResourceAsStream(ERROR_PAGE_NAME)) { template = Servlets.readAll(is); // Replaces tags. template = template.replaceAll(Pattern.quote("<!-- ${AppName} -->"), Matcher.quoteReplacement(properties.getProperty(PropertyKey.APP_NAME))); } catch (final IOException e) { throw new ServletException("Cannot read the HTML page template.", e); } } /** * {@inheritDoc} */ @Override public final void log(final String msg) { this.log(msg, null); } /** * {@inheritDoc} */ @Override public final void log(final String message, final Throwable t) { if (t != null) { if (LOG.isErrorEnabled()) { LOG.error(message, t); } } else { if (LOG.isDebugEnabled()) { LOG.debug(message); } } } /** * Secures the given {@code servletMethod} execution. * * @param request * The HTTP request. * @param response * The HTTP response. * @param servletMethod * Java servlet method to execute once user session has been secured. * @throws ServletException * If the servlet execution fails. */ private void secureServlet(final HttpServletRequest request, final HttpServletResponse response, final Method servletMethod) throws ServletException { if (servletMethod == null) { if (LOG.isErrorEnabled()) { LOG.error("The given servlet method {} is null.", servletMethod); } throw new IllegalArgumentException("Servlet method is required."); } User user = null; try { // Validates the user session and user access. final String authenticationToken = request.getParameter(ServletConstants.AUTHENTICATION_TOKEN); final String originPageToken = request.getParameter(ServletConstants.ORIGIN_PAGE_TOKEN); final String servletPath = request.getRequestURI().replaceFirst(ServletModule.ENDPOINT, ""); final Servlet servletEnum = Servlet.fromPathName(servletPath); final ServletMethod servletMethodEnum = ServletMethod.fromMethodName(servletMethod.getName()); final Access access = secureSessionValidator.validate(authenticationToken, servletEnum, servletMethodEnum, originPageToken); user = access.getUser(); switch (access.getAccessType()) { case INVALID_SESSION: if (LOG.isDebugEnabled()) { LOG.debug("SERVLET METHOD EXECUTION FAILED - Servlet method: '{}' ; User: '{}' ; Error: Invalid auth token '{}'.", servletMethod, Servlets.logUser(user), authenticationToken); } throw new InvalidSessionException("Your session is no longer valid."); case UNAUTHORIZED_ACCESS: if (LOG.isDebugEnabled()) { LOG.debug("SERVLET METHOD EXECUTION FAILED - Servlet method: '{}' ; User: '{}' ; Error: Unauthorized process.", servletMethod, Servlets.logUser(user)); } throw new UnauthorizedAccessException("You are not authorized to execute this process."); default: // Access granted, executes servlet method. if (LOG.isDebugEnabled()) { LOG.debug("SERVLET METHOD EXECUTION GRANTED - Servlet method: '{}' ; User: '{}'.", servletMethod, Servlets.logUser(user)); } // Activate filters into hibernate session. DomainFilters.applyUserFilter(user, entityManagerProvider.get()); final StopWatch chrono = new StopWatch(); chrono.start(); servletMethod.setAccessible(true); servletMethod.invoke(this, request, response, new ServletExecutionContext(access.getUser(), request, originPageToken)); if (LOG.isDebugEnabled()) { LOG.debug("SERVLET METHOD '{}' EXECUTED IN {} MS.", servletMethod, chrono.getTime()); } } } catch (final InvocationTargetException e) { // NO NEED TO LOG EXCEPTION HERE. if (e.getTargetException() instanceof ServletException) { // Servlet exception. throw (ServletException) e.getTargetException(); } else if (e.getTargetException() instanceof ConstraintViolationException) { // Bean validation failed. final ConstraintViolationException cve = (ConstraintViolationException) e.getTargetException(); if (LOG.isErrorEnabled()) { LOG.error("SERVLET METHOD EXECUTION FAILED - Servlet method: '" + servletMethod + "' ; User: '" + Servlets.logUser(user) + "' ; Error: A bean validation failed during servlet method execution. Consider performing the validation on client-side.\n" + Servlets.logConstraints(cve.getConstraintViolations())); } throw new ServletException(e.getCause().getMessage(), cve); } else { throw new ServletException(e.getCause().getMessage(), e.getTargetException()); } } catch (final Throwable e) { // Server unknown error. throw new ServletException(e.getMessage(), e); } } /** * Retrieves {@code java} method to execute from {@code request} and calls * {@link #secureServlet(HttpServletRequest, HttpServletResponse, Method)}. * * @param servletMethodName * The real servlet method name ({@code doGet}, {@code doPost}, etc.). * @param request * The HTTP request. * @param response * The HTTP response. * @throws ServletException * If an error occurs while executing servlet process. */ private void secureServletMethod(final String servletMethodName, final HttpServletRequest request, final HttpServletResponse response) throws ServletException { if (LOG.isDebugEnabled()) { LOG.debug("Executing specific '{}' servlet method.", servletMethodName); } // Retrieving method name from request. final String methodName = request.getParameter(ServletConstants.SERVLET_METHOD); boolean popupDestination = false; try { if (LOG.isDebugEnabled()) { LOG.debug("Retrieving by reflection the given servlet method '{}'.", methodName); } if (StringUtils.isBlank(methodName)) { return; } // Retrieving servlet method. final Method servletMethod = getClass().getDeclaredMethod(methodName, HttpServletRequest.class, HttpServletResponse.class, ServletExecutionContext.class); final ServletMethod servletMethodEnum = ServletMethod.fromMethodName(servletMethod.getName()); popupDestination = servletMethodEnum != null && servletMethodEnum.isPopup(); // Secure servlet method. secureServlet(request, response, servletMethod); } catch (final StatusServletException e) { handleException(request, response, servletMethodName, popupDestination, e, e.getStatusCode()); } catch (final Throwable caught) { handleException(request, response, servletMethodName, popupDestination, caught, Response.SC_INTERNAL_SERVER_ERROR); } } // --------------------------------------------------------------------------------------- // // SECURED DEFAULT SERVLET METHODS. // // --------------------------------------------------------------------------------------- /** * Servlet {@code GET} method name. */ private static final String DO_GET_METHOD_NAME = "doGet"; @Override final protected void doPost(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doPost", request, response); } @Override final protected void doGet(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod(DO_GET_METHOD_NAME, request, response); } @Override final protected void doDelete(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doDelete", request, response); } @Override final protected void doOptions(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doOptions", request, response); } @Override final protected void doHead(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doHead", request, response); } @Override final protected void doPut(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doPut", request, response); } @Override final protected void doTrace(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { secureServletMethod("doTrace", request, response); } // --------------------------------------------------------------------------------------- // // UTILITY METHODS. // // --------------------------------------------------------------------------------------- /** * Returns the application {@link Properties} service. * * @return the application {@link Properties} service, never {@code null}. */ protected final Properties prop() { return properties; } /** * Returns the application {@link Mapper} service. * * @return the application {@link Mapper} service, never {@code null}. */ protected final Mapper mapper() { return mapper; } /** * Returns the given {@code paramKey} corresponding value from the {@code request}. * * @param request * The HTTP request. * @param paramKey * The {@link RequestParameter} key. * @param acceptNull * {@code true} to accept a {@code null} value result, {@code false} to throw a * {@link StatusServletException} if the value is {@code null}. * @return The given {@code paramKey} corresponding value from the {@code request}. * @throws StatusServletException * If the parameter value is {@code null} <b>and</b> {@code acceptNull} is set to {@code false}. */ protected static final String getParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull) throws StatusServletException { final String value = ClientUtils.deletePreTags(request.getParameter(paramKey.getRequestName())); if (StringUtils.isBlank(value) || "null".equals(value)) { if (acceptNull) { return null; } if (LOG.isWarnEnabled()) { LOG.warn("No value for parameter key '{}'.", paramKey); } throw new StatusServletException(Response.SC_BAD_REQUEST); } return value; } /** * Returns the given {@code paramKey} corresponding {@link Integer} value from the {@code request}. * * @param request * The HTTP request. * @param paramKey * The {@link RequestParameter} key. * @param acceptNull * {@code true} to accept a {@code null} value result, {@code false} to throw a * {@link StatusServletException} if the value is {@code null}. * @return The given {@code paramKey} corresponding {@link Integer} value from the {@code request}. * @throws StatusServletException * If the parameter value is {@code null} <b>and</b> {@code acceptNull} is set to {@code false}. */ protected static final Integer getIntegerParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull) throws StatusServletException { final String intValue = getParameter(request, paramKey, acceptNull); if (StringUtils.isBlank(intValue)) { if (acceptNull) { return null; } throw new StatusServletException(Response.SC_BAD_REQUEST); } try { return Integer.parseInt(intValue); } catch (final NumberFormatException e) { LOG.error("Error while parsing the integer parameter '" + intValue + "'.", e); throw new StatusServletException(Response.SC_BAD_REQUEST); } } /** * Returns the given {@code paramKey} corresponding {@link Boolean} value from the {@code request}. * * @param request * The HTTP request. * @param paramKey * The {@link RequestParameter} key. * @param acceptNull * {@code true} to accept a {@code null} value result, {@code false} to throw a * {@link StatusServletException} if the value is {@code null}. * @return The given {@code paramKey} corresponding {@link Boolean} value from the {@code request}. * @throws StatusServletException * If the parameter value is {@code null} <b>and</b> {@code acceptNull} is set to {@code false}. */ protected static final Boolean getBooleanParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull) throws StatusServletException { final String booleanValue = getParameter(request, paramKey, acceptNull); if (StringUtils.isBlank(booleanValue)) { if (acceptNull) { return null; } throw new StatusServletException(Response.SC_BAD_REQUEST); } try { return Boolean.parseBoolean(booleanValue); } catch (final NumberFormatException e) { throw new StatusServletException(Response.SC_BAD_REQUEST); } } /** * <p> * Handles the {@code caught} exception. * </p> * <p> * <ul> * <li>If {@code GET} access (direct access and not ajax call), writes into the {@code response} the HTML error page * content.</li> * <li>Else, writes into the {@code response} the given {@code errorCode} as header and * {@link ServletConstants#ERROR_RESPONSE_CONTENT} as content.</li> * </ul> * </p> * * @param request * The HTTP request. * @param response * The HTTP response. * @param servletMethodName * The {@link ServletMethod} value. * @param popupDestination * Is the servlet process destined to be displayed into a pop-up window? * @param caught * The throwable. * @param errorCode * The error code set on the {@code response}. */ private void handleException(final HttpServletRequest request, final HttpServletResponse response, final String servletMethodName, final boolean popupDestination, final Throwable caught, final int errorCode) { if (LOG.isErrorEnabled()) { LOG.error("Exception while executing '" + getClass().getName() + '#' + servletMethodName + "' servlet method.", caught); } try { response.setContentType(FileType.HTML.getContentType()); final String htmlMessage = caught.getClass().getSimpleName() + " : " + caught.getMessage(); final boolean ajaxCall = ClientUtils.isTrue(request.getParameter(ServletConstants.AJAX)); if (DO_GET_METHOD_NAME.equals(servletMethodName) && !ajaxCall) { // If the servlet method is executed using HTTP {@code GET} method. String html = template; html = html.replaceAll(Pattern.quote("<!-- ${MessageContent} -->"), Matcher.quoteReplacement(htmlMessage)); html = html.replaceAll(Pattern.quote("<!-- ${ButtonDisplay} -->"), Servlets.cssDisplay(popupDestination)); response.setCharacterEncoding(Servlets.UTF8_CHARSET); response.getWriter().write(html); } else { // Other method. response.setStatus(errorCode); response.getWriter().write(ServletConstants.buildErrorResponse(errorCode)); } } catch (final IOException ioe) { // Nothing to do ; 'getWriter()' has just failed. if (LOG.isErrorEnabled()) { LOG.error("'getWriter()' method has raised an exception.", ioe); } } } }