package org.sigmah.server.servlet.base;
/*
* #%L
* Sigmah
* %%
* Copyright (C) 2010 - 2016 URD
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-3.0.html>.
* #L%
*/
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.persistence.EntityManager;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.ConstraintViolationException;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.sigmah.client.page.RequestParameter;
import org.sigmah.client.util.ClientUtils;
import org.sigmah.server.conf.Properties;
import org.sigmah.server.domain.User;
import org.sigmah.server.domain.util.DomainFilters;
import org.sigmah.server.inject.ServletModule;
import org.sigmah.server.mapper.Mapper;
import org.sigmah.server.security.SecureSessionValidator;
import org.sigmah.server.security.SecureSessionValidator.Access;
import org.sigmah.server.servlet.util.Servlets;
import org.sigmah.shared.conf.PropertyKey;
import org.sigmah.shared.security.InvalidSessionException;
import org.sigmah.shared.security.UnauthorizedAccessException;
import org.sigmah.shared.servlet.ServletConstants;
import org.sigmah.shared.servlet.ServletConstants.Servlet;
import org.sigmah.shared.servlet.ServletConstants.ServletMethod;
import org.sigmah.shared.util.FileType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.gwt.http.client.Response;
import com.google.inject.Inject;
import com.google.inject.Provider;
/**
 * <p>
 * Abstract additional servlet which secures access to its methods.
 * All additional <em>secured</em> servlets should inherit from this abstract layer.
 * </p>
 * <p>
 * The concrete method to execute is read from the {@link ServletConstants#SERVLET_METHOD} request parameter and
 * looked up by reflection among the methods <b>declared</b> by the concrete servlet class (inherited methods are not
 * found). The declared child servlet methods must have this signature:
 * </p>
 *
 * <pre>
 * [method_name] (HttpServletRequest request, HttpServletResponse response, ServletExecutionContext context) throws [any exception(s)];
 * </pre>
 *
 * @author Denis Colliot (dcolliot@ideia.fr)
 */
public abstract class AbstractServlet extends HttpServlet {

	/**
	 * Serial id.
	 */
	private static final long serialVersionUID = 301456647415093255L;

	/**
	 * Logger.
	 */
	private static final Logger LOG = LoggerFactory.getLogger(AbstractServlet.class);

	/**
	 * HTML servlet error page filename.
	 */
	private static final String ERROR_PAGE_NAME = "servlet-error.html";

	/**
	 * Injected secure session validator.
	 */
	@Inject
	private SecureSessionValidator secureSessionValidator;

	/**
	 * Injected application properties service.
	 */
	@Inject
	private Properties properties;

	/**
	 * Injected {@link EntityManager} provider.
	 */
	@Inject
	private Provider<EntityManager> entityManagerProvider;

	/**
	 * Injected {@link Mapper}.
	 */
	@Inject
	private Mapper mapper;

	/**
	 * HTML error page template (application name tag already replaced).
	 */
	private String template;

	/**
	 * Initializes the servlet: stores the container {@code config} and loads the HTML error page template, replacing
	 * its {@code <!-- ${AppName} -->} tag with the {@link PropertyKey#APP_NAME} property value.
	 *
	 * @param config
	 *          The servlet configuration provided by the container.
	 * @throws ServletException
	 *           If the HTML error page template cannot be found or read.
	 */
	@Override
	public final void init(final ServletConfig config) throws ServletException {
		// Servlet contract: GenericServlet keeps the config here; without this call getServletConfig() returns null
		// and getServletContext() fails for every subclass.
		super.init(config);

		if (LOG.isDebugEnabled()) {
			LOG.debug("Reading HTML error page template.");
		}

		// NOTE(review): the resource is resolved relative to the concrete subclass (getClass()), not AbstractServlet.
		try (final InputStream is = getClass().getResourceAsStream(ERROR_PAGE_NAME)) {
			if (is == null) {
				throw new ServletException("Cannot find the HTML page template '" + ERROR_PAGE_NAME + "' for servlet '" + getClass().getName() + "'.");
			}
			// Literal (non-regex) replacement of the application name tag.
			template = Servlets.readAll(is).replace("<!-- ${AppName} -->", properties.getProperty(PropertyKey.APP_NAME));
		} catch (final IOException e) {
			throw new ServletException("Cannot read the HTML page template.", e);
		}
	}

	/**
	 * Logs {@code msg} at DEBUG level through SLF4J (instead of the container log).
	 *
	 * @param msg
	 *          The message to log.
	 */
	@Override
	public final void log(final String msg) {
		this.log(msg, null);
	}

	/**
	 * Logs through SLF4J: at ERROR level with the throwable when {@code t} is not {@code null}, otherwise at DEBUG level.
	 *
	 * @param message
	 *          The message to log.
	 * @param t
	 *          The throwable to log, may be {@code null}.
	 */
	@Override
	public final void log(final String message, final Throwable t) {
		if (t != null) {
			if (LOG.isErrorEnabled()) {
				LOG.error(message, t);
			}
		} else if (LOG.isDebugEnabled()) {
			LOG.debug(message);
		}
	}

	/**
	 * Secures the given {@code servletMethod} execution: validates the authentication token and the access rights,
	 * applies the user domain filters and then invokes the servlet method by reflection.
	 *
	 * @param request
	 *          The HTTP request.
	 * @param response
	 *          The HTTP response.
	 * @param servletMethod
	 *          Java servlet method to execute once user session has been secured.
	 * @throws ServletException
	 *           If the session is invalid, the access is unauthorized, or the servlet execution fails. A
	 *           {@link ServletException} thrown by the servlet method itself is rethrown as is.
	 * @throws IllegalArgumentException
	 *           If {@code servletMethod} is {@code null}.
	 */
	private void secureServlet(final HttpServletRequest request, final HttpServletResponse response, final Method servletMethod) throws ServletException {

		if (servletMethod == null) {
			if (LOG.isErrorEnabled()) {
				LOG.error("The given servlet method {} is null.", servletMethod);
			}
			throw new IllegalArgumentException("Servlet method is required.");
		}

		User user = null;

		try {

			// Validates the user session and user access.
			final String authenticationToken = request.getParameter(ServletConstants.AUTHENTICATION_TOKEN);
			final String originPageToken = request.getParameter(ServletConstants.ORIGIN_PAGE_TOKEN);
			final String servletPath = request.getRequestURI().replaceFirst(ServletModule.ENDPOINT, "");
			final Servlet servletEnum = Servlet.fromPathName(servletPath);
			final ServletMethod servletMethodEnum = ServletMethod.fromMethodName(servletMethod.getName());

			final Access access = secureSessionValidator.validate(authenticationToken, servletEnum, servletMethodEnum, originPageToken);
			user = access.getUser();

			switch (access.getAccessType()) {

				case INVALID_SESSION:
					if (LOG.isDebugEnabled()) {
						LOG.debug("SERVLET METHOD EXECUTION FAILED - Servlet method: '{}' ; User: '{}' ; Error: Invalid auth token '{}'.", servletMethod,
							Servlets.logUser(user), authenticationToken);
					}
					throw new InvalidSessionException("Your session is no longer valid.");

				case UNAUTHORIZED_ACCESS:
					if (LOG.isDebugEnabled()) {
						LOG.debug("SERVLET METHOD EXECUTION FAILED - Servlet method: '{}' ; User: '{}' ; Error: Unauthorized process.", servletMethod,
							Servlets.logUser(user));
					}
					throw new UnauthorizedAccessException("You are not authorized to execute this process.");

				default:
					// Access granted, executes servlet method.
					if (LOG.isDebugEnabled()) {
						LOG.debug("SERVLET METHOD EXECUTION GRANTED - Servlet method: '{}' ; User: '{}'.", servletMethod, Servlets.logUser(user));
					}

					// Activate filters into hibernate session.
					DomainFilters.applyUserFilter(user, entityManagerProvider.get());

					final StopWatch chrono = new StopWatch();
					chrono.start();

					servletMethod.setAccessible(true);
					servletMethod.invoke(this, request, response, new ServletExecutionContext(access.getUser(), request, originPageToken));

					if (LOG.isDebugEnabled()) {
						LOG.debug("SERVLET METHOD '{}' EXECUTED IN {} MS.", servletMethod, chrono.getTime());
					}
			}

		} catch (final InvocationTargetException e) {
			// The servlet method itself failed. No need to log here: every exception thrown from this method reaches
			// handleException (via secureServletMethod), which logs it.
			final Throwable target = e.getTargetException();

			if (target instanceof ServletException) {
				// Servlet exception (including StatusServletException): keep it untouched so its status is preserved.
				throw (ServletException) target;
			}

			if (target instanceof ConstraintViolationException) {
				// Bean validation failed.
				final ConstraintViolationException cve = (ConstraintViolationException) target;
				if (LOG.isErrorEnabled()) {
					LOG.error("SERVLET METHOD EXECUTION FAILED - Servlet method: '"
						+ servletMethod
						+ "' ; User: '"
						+ Servlets.logUser(user)
						+ "' ; Error: A bean validation failed during servlet method execution. Consider performing the validation on client-side.\n"
						+ Servlets.logConstraints(cve.getConstraintViolations()));
				}
				throw new ServletException(cve.getMessage(), cve);
			}

			throw new ServletException(target.getMessage(), target);

		} catch (final Throwable e) {
			// Server unknown error (includes the session / access exceptions thrown above).
			throw new ServletException(e.getMessage(), e);
		}
	}

	/**
	 * Retrieves the {@code java} method to execute from the {@code request} and calls
	 * {@link #secureServlet(HttpServletRequest, HttpServletResponse, Method)}. Any failure is converted into an error
	 * response by {@link #handleException(HttpServletRequest, HttpServletResponse, String, boolean, Throwable, int)}.
	 *
	 * @param servletMethodName
	 *          The real servlet method name ({@code doGet}, {@code doPost}, etc.).
	 * @param request
	 *          The HTTP request.
	 * @param response
	 *          The HTTP response.
	 * @throws ServletException
	 *           Declared for the servlet API; failures are handled and written to the response.
	 */
	private void secureServletMethod(final String servletMethodName, final HttpServletRequest request, final HttpServletResponse response)
			throws ServletException {

		if (LOG.isDebugEnabled()) {
			LOG.debug("Executing specific '{}' servlet method.", servletMethodName);
		}

		// Retrieving method name from request.
		final String methodName = request.getParameter(ServletConstants.SERVLET_METHOD);
		if (StringUtils.isBlank(methodName)) {
			// NOTE(review): nothing is written to the response when no servlet method is requested.
			return;
		}

		boolean popupDestination = false;

		try {

			if (LOG.isDebugEnabled()) {
				LOG.debug("Retrieving by reflection the given servlet method '{}'.", methodName);
			}

			// Only methods declared by the concrete servlet class are looked up (not inherited ones).
			final Method servletMethod = getClass().getDeclaredMethod(methodName, HttpServletRequest.class, HttpServletResponse.class, ServletExecutionContext.class);

			final ServletMethod servletMethodEnum = ServletMethod.fromMethodName(servletMethod.getName());
			popupDestination = servletMethodEnum != null && servletMethodEnum.isPopup();

			// Secure servlet method.
			secureServlet(request, response, servletMethod);

		} catch (final StatusServletException e) {
			handleException(request, response, servletMethodName, popupDestination, e, e.getStatusCode());

		} catch (final Throwable caught) {
			handleException(request, response, servletMethodName, popupDestination, caught, Response.SC_INTERNAL_SERVER_ERROR);
		}
	}

	// ---------------------------------------------------------------------------------------
	//
	// SECURED DEFAULT SERVLET METHODS.
	//
	// ---------------------------------------------------------------------------------------

	/**
	 * Servlet {@code GET} method name.
	 */
	private static final String DO_GET_METHOD_NAME = "doGet";

	@Override
	protected final void doPost(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doPost", request, response);
	}

	@Override
	protected final void doGet(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod(DO_GET_METHOD_NAME, request, response);
	}

	@Override
	protected final void doDelete(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doDelete", request, response);
	}

	@Override
	protected final void doOptions(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doOptions", request, response);
	}

	@Override
	protected final void doHead(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doHead", request, response);
	}

	@Override
	protected final void doPut(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doPut", request, response);
	}

	@Override
	protected final void doTrace(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
		secureServletMethod("doTrace", request, response);
	}

	// ---------------------------------------------------------------------------------------
	//
	// UTILITY METHODS.
	//
	// ---------------------------------------------------------------------------------------

	/**
	 * Returns the application {@link Properties} service.
	 *
	 * @return the application {@link Properties} service, never {@code null}.
	 */
	protected final Properties prop() {
		return properties;
	}

	/**
	 * Returns the application {@link Mapper} service.
	 *
	 * @return the application {@link Mapper} service, never {@code null}.
	 */
	protected final Mapper mapper() {
		return mapper;
	}

	/**
	 * Returns the given {@code paramKey} corresponding value from the {@code request}. Blank values and the literal
	 * {@code "null"} are treated as missing.
	 *
	 * @param request
	 *          The HTTP request.
	 * @param paramKey
	 *          The {@link RequestParameter} key.
	 * @param acceptNull
	 *          {@code true} to accept a {@code null} value result, {@code false} to throw a
	 *          {@link StatusServletException} if the value is {@code null}.
	 * @return The given {@code paramKey} corresponding value from the {@code request}.
	 * @throws StatusServletException
	 *           ({@code 400}) If the parameter value is missing <b>and</b> {@code acceptNull} is set to {@code false}.
	 */
	protected static final String getParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull)
			throws StatusServletException {

		final String value = ClientUtils.deletePreTags(request.getParameter(paramKey.getRequestName()));

		if (StringUtils.isBlank(value) || "null".equals(value)) {
			if (acceptNull) {
				return null;
			}
			if (LOG.isWarnEnabled()) {
				LOG.warn("No value for parameter key '{}'.", paramKey);
			}
			throw new StatusServletException(Response.SC_BAD_REQUEST);
		}

		return value;
	}

	/**
	 * Returns the given {@code paramKey} corresponding {@link Integer} value from the {@code request}.
	 *
	 * @param request
	 *          The HTTP request.
	 * @param paramKey
	 *          The {@link RequestParameter} key.
	 * @param acceptNull
	 *          {@code true} to accept a {@code null} value result, {@code false} to throw a
	 *          {@link StatusServletException} if the value is {@code null}.
	 * @return The given {@code paramKey} corresponding {@link Integer} value from the {@code request}.
	 * @throws StatusServletException
	 *           ({@code 400}) If the parameter value is missing <b>and</b> {@code acceptNull} is set to {@code false},
	 *           or if the value is not a valid integer.
	 */
	protected static final Integer getIntegerParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull)
			throws StatusServletException {

		final String intValue = getParameter(request, paramKey, acceptNull);

		if (StringUtils.isBlank(intValue)) {
			if (acceptNull) {
				return null;
			}
			throw new StatusServletException(Response.SC_BAD_REQUEST);
		}

		try {
			return Integer.parseInt(intValue);
		} catch (final NumberFormatException e) {
			LOG.error("Error while parsing the integer parameter '{}'.", intValue, e);
			throw new StatusServletException(Response.SC_BAD_REQUEST);
		}
	}

	/**
	 * Returns the given {@code paramKey} corresponding {@link Boolean} value from the {@code request}. Any present
	 * value other than {@code "true"} (case-insensitive) yields {@link Boolean#FALSE} (see
	 * {@link Boolean#parseBoolean(String)}).
	 *
	 * @param request
	 *          The HTTP request.
	 * @param paramKey
	 *          The {@link RequestParameter} key.
	 * @param acceptNull
	 *          {@code true} to accept a {@code null} value result, {@code false} to throw a
	 *          {@link StatusServletException} if the value is {@code null}.
	 * @return The given {@code paramKey} corresponding {@link Boolean} value from the {@code request}.
	 * @throws StatusServletException
	 *           ({@code 400}) If the parameter value is missing <b>and</b> {@code acceptNull} is set to {@code false}.
	 */
	protected static final Boolean getBooleanParameter(final HttpServletRequest request, final RequestParameter paramKey, boolean acceptNull)
			throws StatusServletException {

		final String booleanValue = getParameter(request, paramKey, acceptNull);

		if (StringUtils.isBlank(booleanValue)) {
			if (acceptNull) {
				return null;
			}
			throw new StatusServletException(Response.SC_BAD_REQUEST);
		}

		// Boolean.parseBoolean never throws, so no exception handling is needed here.
		return Boolean.parseBoolean(booleanValue);
	}

	/**
	 * <p>
	 * Handles the {@code caught} exception.
	 * </p>
	 * <ul>
	 * <li>If {@code GET} access (direct access and not ajax call), writes into the {@code response} the HTML error page
	 * content, with the (HTML-escaped) exception class name and message.</li>
	 * <li>Else, writes into the {@code response} the given {@code errorCode} as status and
	 * {@link ServletConstants#buildErrorResponse(int)} as content.</li>
	 * </ul>
	 *
	 * @param request
	 *          The HTTP request.
	 * @param response
	 *          The HTTP response.
	 * @param servletMethodName
	 *          The real servlet method name ({@code doGet}, {@code doPost}, etc.).
	 * @param popupDestination
	 *          Is the servlet process destined to be displayed into a pop-up window?
	 * @param caught
	 *          The throwable.
	 * @param errorCode
	 *          The error code set on the {@code response}.
	 */
	private void handleException(final HttpServletRequest request, final HttpServletResponse response, final String servletMethodName,
			final boolean popupDestination, final Throwable caught, final int errorCode) {

		if (LOG.isErrorEnabled()) {
			LOG.error("Exception while executing '{}#{}' servlet method.", getClass().getName(), servletMethodName, caught);
		}

		try {

			response.setContentType(FileType.HTML.getContentType());

			final boolean ajaxCall = ClientUtils.isTrue(request.getParameter(ServletConstants.AJAX));

			if (DO_GET_METHOD_NAME.equals(servletMethodName) && !ajaxCall) {
				// If the servlet method is executed using HTTP {@code GET} method.
				// The message must be escaped: it may echo request data (e.g. the requested method name in a
				// NoSuchMethodException), which would otherwise allow reflected XSS.
				final String htmlMessage = escapeHtml(caught.getClass().getSimpleName() + " : " + caught.getMessage());

				final String html = template
					.replace("<!-- ${MessageContent} -->", htmlMessage)
					.replace("<!-- ${ButtonDisplay} -->", Servlets.cssDisplay(popupDestination));

				response.setCharacterEncoding(Servlets.UTF8_CHARSET);
				response.getWriter().write(html);

			} else {
				// Other method.
				response.setStatus(errorCode);
				response.getWriter().write(ServletConstants.buildErrorResponse(errorCode));
			}

		} catch (final IOException | IllegalStateException e) {
			// Nothing more to do: 'getWriter()' failed, or the servlet method already used 'getOutputStream()'.
			if (LOG.isErrorEnabled()) {
				LOG.error("'getWriter()' method has raised an exception.", e);
			}
		}
	}

	/**
	 * Escapes the HTML special characters ({@code & < > " '}) of the given text.
	 *
	 * @param text
	 *          The text to escape, may be {@code null}.
	 * @return The escaped text, or {@code null} if {@code text} is {@code null}.
	 */
	private static String escapeHtml(final String text) {
		if (text == null) {
			return null;
		}
		final StringBuilder escaped = new StringBuilder(text.length());
		for (int i = 0; i < text.length(); i++) {
			final char c = text.charAt(i);
			switch (c) {
				case '&':
					escaped.append("&amp;");
					break;
				case '<':
					escaped.append("&lt;");
					break;
				case '>':
					escaped.append("&gt;");
					break;
				case '"':
					escaped.append("&quot;");
					break;
				case '\'':
					escaped.append("&#39;");
					break;
				default:
					escaped.append(c);
			}
		}
		return escaped.toString();
	}
}