/******************************************************************************* * * Copyright (c) 2010-2011 Sonatype, Inc. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * * * * *******************************************************************************/ package org.hudsonci.servlets.internal; import org.hudsonci.servlets.ServletRegistration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.Enumeration; import java.util.Iterator; import static com.google.common.base.Preconditions.checkNotNull; /** * Wraps a {@link Servlet} as a {@link Filter} for installation via * {@link hudson.util.PluginServletFilter}. * * @author <a href="mailto:jason@planet57.com">Jason Dillon</a> * @since 2.1.0 */ public class ServletRegistrationFilterAdapter implements Filter { private static final Logger log = LoggerFactory.getLogger(ServletRegistrationFilterAdapter.class); private final ServletRegistration registration; private final Servlet servlet; private final String uriPrefix; private boolean enabled; public ServletRegistrationFilterAdapter(final ServletRegistration registration) throws Exception { this.registration = checkNotNull(registration); this.servlet = createServlet(); if (registration.getName() == null) { registration.setName(servlet.getClass().getName()); } uriPrefix = registration.getUriPrefix(); if (uriPrefix == null) { throw new IllegalArgumentException("Registration missing uriPrefix"); } } public boolean isEnabled() { return enabled; } public void setEnabled(final boolean enabled) { this.enabled = enabled; } private Servlet createServlet() throws Exception { Servlet servlet = registration.getServlet(); if (servlet != null) { return servlet; } Class<? extends Servlet> type = registration.getServletType(); if (type != null) { return type.newInstance(); } throw new IllegalArgumentException("Registration missing servlet or servlet type"); } public void init(final FilterConfig config) throws ServletException { checkNotNull(config); servlet.init(new ServletConfig() { public String getServletName() { return registration.getName(); } public ServletContext getServletContext() { return config.getServletContext(); } public Enumeration getInitParameterNames() { final Iterator<String> iter = registration.getParameters().keySet().iterator(); return new Enumeration() { public boolean hasMoreElements() { return iter.hasNext(); } public Object nextElement() { return iter.next(); } }; } public String getInitParameter(final String name) { return registration.getParameters().get(name); } }); } public void destroy() { servlet.destroy(); } public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException { assert chain != null; if (isEnabled() && request instanceof HttpServletRequest && response instanceof HttpServletResponse) { doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); } else { chain.doFilter(request, response); } } private void doFilter(final HttpServletRequest request, final HttpServletResponse response, final FilterChain chain) throws IOException, ServletException { assert request != null; assert response != null; assert chain != null; String contextPath = request.getContextPath(); if (!contextPath.endsWith("/") && !uriPrefix.startsWith("/")) { contextPath = contextPath + '/'; } if (request.getRequestURI().startsWith(contextPath + uriPrefix)) { // Wrap the request to augment the servlet uriPrefix HttpServletRequestWrapper req = new HttpServletRequestWrapper(request) { @Override public String getServletPath() { return String.format("/%s", uriPrefix); } }; servlet.service(req, response); } else { chain.doFilter(request, response); } } }