package org.jboss.pitbull.servlet;
import javax.servlet.HttpMethodConstraintElement;
import javax.servlet.MultipartConfigElement;
import javax.servlet.Servlet;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletRegistration;
import javax.servlet.ServletSecurityElement;
import javax.servlet.SingleThreadModel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class DeploymentServletRegistration extends DeploymentRegistration implements ServletRegistration.Dynamic
{
protected int loadLevel;
protected ServletSecurityElement constraint;
protected MultipartConfigElement multipartConfig;
protected String runAsRole;
protected Servlet servlet;
protected String className;
protected Class<? extends Servlet> servletClass;
protected List<String> urlPatterns = new ArrayList<String>();
protected List<Pattern> patterns = new ArrayList<Pattern>();
protected InstanceManager im;
protected ServletConfig config;
protected boolean initialized;
protected boolean perRequest;
public DeploymentServletRegistration(String name, String className, DeploymentServletContext ctx)
{
this.name = name;
this.className = className;
this.servletContext = ctx;
}
public DeploymentServletRegistration(String name, Class<? extends Servlet> servletClass, DeploymentServletContext ctx)
{
this.name = name;
this.servletClass = servletClass;
this.servletContext = ctx;
}
public DeploymentServletRegistration(String name, Servlet servlet, DeploymentServletContext ctx)
{
this.name = name;
this.servlet = servlet;
this.servletContext = ctx;
}
public boolean matchesPattern(String pattern)
{
pattern = pattern.replace("*", "{WILDCARD}");
for (Pattern p : patterns)
{
if (p.matcher(pattern).matches()) return true;
}
return false;
}
public void initialize(InstanceManager im, ClassLoader loader) throws Exception
{
this.im = im;
this.config = new ServletConfig()
{
@Override
public String getServletName()
{
return DeploymentServletRegistration.this.name;
}
@Override
public ServletContext getServletContext()
{
return DeploymentServletRegistration.this.servletContext;
}
@Override
public String getInitParameter(String name)
{
return DeploymentServletRegistration.this.getInitParameter(name);
}
@Override
public Enumeration<String> getInitParameterNames()
{
final Iterator<String> it = DeploymentServletRegistration.this.initParameters.keySet().iterator();
return new Enumeration<String>()
{
@Override
public boolean hasMoreElements()
{
return it.hasNext();
}
@Override
public String nextElement()
{
return it.next();
}
};
}
};
if (servlet == null && servletClass == null)
{
servletClass = (Class<? extends Servlet>) loader.loadClass(className);
}
if (servlet != null)
{
servletClass = servlet.getClass();
}
if (SingleThreadModel.class.isAssignableFrom(servletClass))
{
perRequest = true;
}
if (loadLevel >= 0 && !perRequest)
{
initializeServlet();
}
}
protected void initializeServlet() throws Exception
{
if (servlet != null)
{
this.im.inject(servlet);
}
else
{
servlet = (Servlet) this.im.newInstance(servletClass);
}
servlet.init(config);
initialized = true;
}
public Servlet startRequest() throws Exception
{
if (initialized) return servlet;
initializeServlet();
return servlet;
}
public void endRequest() throws Exception
{
if (perRequest)
{
this.im.destroyInstance(servlet);
servlet = null;
initialized = false;
}
}
public int getLoadLevel()
{
return loadLevel;
}
public ServletSecurityElement getConstraint()
{
return constraint;
}
public MultipartConfigElement getMultipartConfig()
{
return multipartConfig;
}
public boolean isAsyncSupported()
{
return asyncSupported;
}
public List<Pattern> getPatterns()
{
return patterns;
}
@Override
public void setLoadOnStartup(int loadOnStartup)
{
loadLevel = loadOnStartup;
}
@Override
public Set<String> setServletSecurity(ServletSecurityElement constraint)
{
checkNullParameter(constraint);
Set<String> already = new HashSet<String>();
for (String urlPattern : urlPatterns)
{
Set<String> methodNames = new HashSet<String>();
methodNames.addAll(constraint.getMethodNames());
for (HttpMethodConstraintElement methodConstraint : constraint.getHttpMethodConstraints())
{
methodNames.add(methodConstraint.getMethodName());
}
already.addAll(servletContext.matchSecurityConstraint(urlPattern, methodNames));
}
if (already.size() > 0) return already;
this.constraint = constraint;
return already;
}
@Override
public void setMultipartConfig(MultipartConfigElement multipartConfig)
{
checkNullParameter(multipartConfig);
this.multipartConfig = multipartConfig;
}
@Override
public void setRunAsRole(String roleName)
{
checkNullParameter(roleName);
this.runAsRole = roleName;
}
@Override
public String getRunAsRole()
{
return runAsRole;
}
@Override
public String getClassName()
{
if (className != null) return className;
if (servlet != null) return servlet.getClass().getName();
if (servletClass != null) return servletClass.getName();
return null;
}
@Override
public Set<String> addMapping(String... urlPatterns)
{
Set<String> matches = new HashSet<String>();
for (String urlPattern : urlPatterns)
{
if (servletContext.matchesServletUrlPattern(urlPattern))
{
matches.add(urlPattern);
}
}
if (matches.isEmpty())
{
for (String urlPattern : urlPatterns)
{
this.urlPatterns.add(urlPattern);
urlPattern = urlPattern.replace("*", ".*");
this.patterns.add(Pattern.compile(urlPattern));
}
}
return matches;
}
@Override
public Collection<String> getMappings()
{
return urlPatterns;
}
}