/** * Copyright (c) Codice Foundation * <p> * This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser * General Public License as published by the Free Software Foundation, either version 3 of the * License, or any later version. * <p> * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. A copy of the GNU Lesser General Public License * is distributed along with this program and can be found at * <http://www.gnu.org/licenses/lgpl.html>. */ package org.codice.ui.admin.docs; import java.io.IOException; import java.io.NotSerializableException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import spark.globalstate.ServletFlag; import spark.http.matching.MatcherFilter; import spark.route.ServletRoutes; import spark.servlet.SparkApplication; import spark.staticfiles.StaticFilesConfiguration; import spark.utils.CollectionUtils; /** * Servlet that can be configured through a web.xml to serve {@code SparkApplication}s. * Needs to be initialized with the {@code applicationName} init parameter (or through a direct * call to {@link #setSparkApplications(List)}) to the list of application classes defining routes. * An optional {@code wrapperSupplierName} parameter can be provided (or directly set through the * {@link #setRequestSupplier(BiFunction)} method) to provide added path flexibility. */ public class SparkServlet extends HttpServlet { private static final Logger LOGGER = LoggerFactory.getLogger(SparkServlet.class); private static final String SLASH_WILDCARD = "/*"; private static final String SLASH = "/"; private static final String APPLICATION_CLASS_PARAM = "applicationName"; private static final String WRAPPER_SUPPLIER_PARAM_NAME = "wrapperSupplierName"; private static final String FILTER_MAPPING_PARAM = "filterMappingUrlPattern"; private static final BiFunction<HttpServletRequest, String, HttpServletRequestWrapper> DEFAULT_REQ_FUNC = new BiFunction<HttpServletRequest, String, HttpServletRequestWrapper>() { @Override public HttpServletRequestWrapper apply(HttpServletRequest req, String relativePath) { return new HttpServletRequestWrapper(req) { @Override public String getPathInfo() { return relativePath; } @Override public String getRequestURI() { return relativePath; } }; } }; private BiFunction<HttpServletRequest, String, HttpServletRequestWrapper> requestSupplier; private final List<SparkApplication> sparkApplications = Collections.synchronizedList(new ArrayList<>()); private String filterMappingPattern = null; private String filterPath; private MatcherFilter matcherFilter; public synchronized void setRequestSupplier( BiFunction<HttpServletRequest, String, HttpServletRequestWrapper> requestSupplier) { this.requestSupplier = requestSupplier; } public void setSparkApplications(List<SparkApplication> sparkApplications) { this.sparkApplications.addAll(sparkApplications); } public void setFilterMappingPattern(String filterMappingPattern) { this.filterMappingPattern = filterMappingPattern; } @Override public void init(ServletConfig config) throws ServletException { super.init(config); ServletFlag.runFromServlet(); populateWrapperSupplier(config); populateSparkApplications(config); sparkApplications.stream() .sequential() .forEach(SparkApplication::init); filterPath = getConfigPath(filterMappingPattern, config); matcherFilter = new MatcherFilter(ServletRoutes.get(), StaticFilesConfiguration.servletInstance, false, false); } @Override public void destroy() { sparkApplications.stream() .filter(Objects::nonNull) .forEach(SparkApplication::destroy); } @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { final String relativePath = getRelativePath(req, filterPath); HttpServletRequestWrapper requestWrapper; synchronized (this) { requestWrapper = requestSupplier.apply(req, relativePath); } // handle static resources boolean consumed = StaticFilesConfiguration.servletInstance.consume(req, resp); if (consumed) { return; } matcherFilter.doFilter(requestWrapper, resp, null); } private static String getConfigPath(String filterMappingPattern, ServletConfig config) { String result = Optional.ofNullable(filterMappingPattern) .orElse(config.getInitParameter(FILTER_MAPPING_PARAM)); if (result == null || result.equals(SLASH_WILDCARD)) { return ""; } else if (!result.startsWith(SLASH) || !result.endsWith(SLASH_WILDCARD)) { throw new RuntimeException(String.format( "The %s must start with '/' and end with '/*'. Instead it is: %s", FILTER_MAPPING_PARAM, result)); } return result.substring(1, result.length() - 1); } private static String getRelativePath(HttpServletRequest request, String filterPath) { String path = request.getRequestURI() .substring(request.getContextPath() .length()); if (path.length() > 0) { path = path.substring(1); } if (filterPath.equals(path + SLASH)) { path += SLASH; } if (path.startsWith(filterPath)) { path = path.substring(filterPath.length()); } if (!path.startsWith(SLASH)) { path = SLASH + path; } try { path = URLDecoder.decode(path, "UTF-8"); } catch (UnsupportedEncodingException ignore) { // this can't really ever happen } LOGGER.debug("Relative path = {}", path); return path; } private synchronized void populateWrapperSupplier(ServletConfig config) { // Do not override an injected supplier through initialization if (requestSupplier != null) { return; } String wrapperSupplierName = config.getInitParameter(WRAPPER_SUPPLIER_PARAM_NAME); if (StringUtils.isNotBlank(wrapperSupplierName)) { try { Class<?> wrapperClass = Class.forName(wrapperSupplierName); if (BiFunction.class.isAssignableFrom(wrapperClass)) { requestSupplier = (BiFunction<HttpServletRequest, String, HttpServletRequestWrapper>) wrapperClass.newInstance(); } } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { LOGGER.debug( "Error converting {} to BiFunction<HttpServletRequest, String, HttpServletRequestWrapper>; " + "falling back to default", wrapperSupplierName, e); } } if (requestSupplier == null) { requestSupplier = DEFAULT_REQ_FUNC; } } private void populateSparkApplications(ServletConfig config) { // Do not override injected spark applications through initialization if (!CollectionUtils.isEmpty(sparkApplications)) { return; } String applications = config.getInitParameter(APPLICATION_CLASS_PARAM); if (StringUtils.isNotBlank(applications)) { sparkApplications.addAll(Pattern.compile(",") .splitAsStream(applications) .map(String::trim) .map(this::getApplication) .filter(Objects::nonNull) .collect(Collectors.toList())); } } private SparkApplication getApplication(String applicationClassName) { try { Class<?> appClass = Class.forName(applicationClassName); if (SparkApplication.class.isAssignableFrom(appClass)) { return SparkApplication.class.cast(appClass.newInstance()); } } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { LOGGER.debug("Error converting {} to SparkApplication", applicationClassName, e); } return null; } private void writeObject(ObjectOutputStream stream) throws IOException { throw new NotSerializableException(getClass().getName()); } private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { throw new NotSerializableException(getClass().getName()); } }