/* * Copyright 2014-2016 Red Hat, Inc, and individual contributors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.projectodd.wunderboss.as.web; import org.jboss.logging.Logger; import org.projectodd.wunderboss.CompletableFuture; import org.projectodd.wunderboss.Options; import org.projectodd.wunderboss.as.ActionConduit; import org.projectodd.wunderboss.web.Web; import javax.servlet.Filter; import javax.servlet.FilterRegistration; import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletRegistration; import java.util.Collections; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import static org.projectodd.wunderboss.web.Web.RegisterOption.PATH; import static org.projectodd.wunderboss.web.Web.RegisterOption.SERVLET_NAME; public class ServletWeb implements Web<Servlet> { public ServletWeb(String name, ServletContext servletContext, ActionConduit actionConduit, AtomicLong sharedTimeout) { this.name = name; this.servletContext = servletContext; this.actionConduit = actionConduit; this.sharedTimeout = sharedTimeout; } @Override public boolean registerHandler(Servlet handler, Map<RegisterOption, Object> opts) { return registerServlet(handler, opts); } @Override public boolean registerServlet(final Servlet servlet, Map<RegisterOption, Object> opts) { final Options<RegisterOption> options = new Options<>(opts); final String context = options.getString(PATH); final String servletName = options.getString(SERVLET_NAME, context); // TODO: Take mapping instead of path for servlets? final String mapping = context.endsWith("/") ? context + "*" : context + "/*"; final CompletableFuture<Void> servletFuture = new CompletableFuture<>(); final Runnable action = new Runnable() { @Override public void run() { try { ServletRegistration.Dynamic servletRegistration = servletContext.addServlet(servletName, servlet); servletRegistration.addMapping(mapping); servletRegistration.setLoadOnStartup(1); servletRegistration.setAsyncSupported(true); servletRegistration.setInitParameter(ORIGINAL_CONTEXT, context); Map<String, Filter> filterMap = (Map<String, Filter>) options.get(RegisterOption.FILTER_MAP); if (filterMap != null) { for (Map.Entry<String, Filter> entry : filterMap.entrySet()) { FilterRegistration.Dynamic filter = servletContext.addFilter(entry.getKey() + servletName, entry.getValue()); filter.setAsyncSupported(true); filter.addMappingForUrlPatterns(null, false, mapping); } } servletFuture.complete(null); } catch (Exception e) { servletFuture.completeExceptionally(e); } } }; if (!this.actionConduit.add(action)) { throw new IllegalStateException("Can't add servlet after servlet init has completed"); } try { // this shares a timeout with the ServletListener so we're chipping // away at the same pool of time. It is responsible for noticing that // the full timeout has expired final long now = System.currentTimeMillis(); servletFuture.get(this.sharedTimeout.get(), TimeUnit.MILLISECONDS); this.sharedTimeout.addAndGet(System.currentTimeMillis() - now); } catch (InterruptedException | TimeoutException | ExecutionException e) { log.error("Registering servlet failed", e); throw new RuntimeException(e instanceof ExecutionException ? e.getCause() : e); } return false; } @Override public boolean unregister(Map<RegisterOption, Object> opts) { log.warn("Removing a servlet is a no-op in container."); return false; } @Override public Set<String> registeredContexts() { Set<String> contexts = new HashSet<>(); for(ServletRegistration each : this.servletContext.getServletRegistrations().values()) { contexts.add(each.getInitParameter(ORIGINAL_CONTEXT)); } return Collections.unmodifiableSet(contexts); } @Override public void start() { // no-op on WildFly } @Override public void stop() { // no-op on WildFly } @Override public boolean isRunning() { return true; } @Override public String name() { return this.name; } private final String name; private final ServletContext servletContext; private final ActionConduit actionConduit; private final AtomicLong sharedTimeout; private static final String ORIGINAL_CONTEXT = "original-context"; private static final Logger log = Logger.getLogger(ServletWeb.class); }