/* * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER. * * Copyright (c) 2015-2017 Oracle and/or its affiliates. All rights reserved. * * The contents of this file are subject to the terms of either the GNU * General Public License Version 2 only ("GPL") or the Common Development * and Distribution License("CDDL") (collectively, the "License"). You * may not use this file except in compliance with the License. You can * obtain a copy of the License at * http://glassfish.java.net/public/CDDL+GPL_1_1.html * or packager/legal/LICENSE.txt. See the License for the specific * language governing permissions and limitations under the License. * * When distributing the software, include this License Header Notice in each * file and include the License file at packager/legal/LICENSE.txt. * * GPL Classpath Exception: * Oracle designates this particular file as subject to the "Classpath" * exception as provided by Oracle in the GPL Version 2 section of the License * file that accompanied this code. * * Modifications: * If applicable, add the following below the License Header, with the fields * enclosed by brackets [] replaced by your own identifying information: * "Portions Copyright [year] [name of copyright owner]" * * Contributor(s): * If you wish your version of this file to be governed by only the CDDL or * only the GPL Version 2, indicate your decision by adding "[Contributor] * elects to include this software in this distribution under the [CDDL or GPL * Version 2] license." If you don't indicate a single choice of license, a * recipient has the option to distribute your version of this file under * either the CDDL, the GPL Version 2 or to extend the choice of license to * its licensees as provided above. However, if you add GPL Version 2 code * and therefore, elected the GPL Version 2 license, then the option applies * only if the new code is made subject to such option by the copyright * holder. */ package org.glassfish.jersey.tests.integration.servlet_request_wrapper_binding2; import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; import java.lang.annotation.Annotation; import java.lang.reflect.Type; import java.security.Principal; import java.util.Collection; import java.util.Enumeration; import java.util.HashSet; import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.function.Supplier; import javax.ws.rs.core.GenericType; import javax.inject.Inject; import javax.inject.Provider; import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; import javax.servlet.RequestDispatcher; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.ServletInputStream; import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpSession; import javax.servlet.http.Part; import org.glassfish.jersey.inject.hk2.DelayedHk2InjectionManager; import org.glassfish.jersey.inject.hk2.ImmediateHk2InjectionManager; import org.glassfish.jersey.internal.inject.AbstractBinder; import org.glassfish.jersey.internal.inject.InjectionManager; import org.glassfish.jersey.internal.inject.ReferencingFactory; import org.glassfish.jersey.internal.util.collection.Ref; import org.glassfish.jersey.process.internal.RequestScoped; import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.spi.ComponentProvider; import org.glassfish.jersey.server.spi.RequestScopedInitializer; import org.glassfish.jersey.servlet.internal.spi.NoOpServletContainerProvider; import org.glassfish.jersey.servlet.internal.spi.RequestContextProvider; import org.glassfish.jersey.servlet.internal.spi.RequestScopedInitializerProvider; import org.glassfish.hk2.api.DescriptorType; import org.glassfish.hk2.api.DescriptorVisibility; import org.glassfish.hk2.api.PerLookup; import org.glassfish.hk2.api.ServiceHandle; import org.glassfish.hk2.api.ServiceLocator; import org.glassfish.hk2.api.TypeLiteral; import org.glassfish.hk2.utilities.AbstractActiveDescriptor; import org.glassfish.hk2.utilities.ServiceLocatorUtilities; import org.jvnet.hk2.internal.ServiceHandleImpl; /** * Servlet container provider that wraps the original Servlet request/response. * The request wrapper contains a direct reference to the underlying container request * in case it gets injected into a request scoped component. * * @author Jakub Podlesak (jakub.podlesak at oracle.com) */ public class RequestResponseWrapperProvider extends NoOpServletContainerProvider { private final Type REQUEST_TYPE = (new TypeLiteral<Ref<HttpServletRequestWrapper>>() { }).getType(); private final Type RESPONSE_TYPE = (new TypeLiteral<Ref<HttpServletResponseWrapper>>() { }).getType(); public static class DescriptorProvider implements ComponentProvider { @Override public void initialize(InjectionManager injectionManager) { ServiceLocator locator = getServiceLocator(injectionManager); ServiceLocatorUtilities.addOneDescriptor(locator, new HttpServletRequestDescriptor(locator)); } @Override public boolean bind(Class<?> component, Set<Class<?>> providerContracts) { return false; } @Override public void done() { // nop } } /** * Subclass standard wrapper so that we make 100 % sure we are getting the right type. * It is also final, i.e. not proxiable, which we workaround by using custom http servlet request impl. */ public static final class RequestWrapper extends HttpServletRequestWrapper { public RequestWrapper(HttpServletRequest request) { super(request); } } /** * Subclass standard wrapper so that we make 100 % sure we are getting the right type. * It is also final, i.e. not proxiable, which we workaround by using custom http servlet response impl. */ public static final class ResponseWrapper extends HttpServletResponseWrapper { public ResponseWrapper(HttpServletResponse response) { super(response); } } @Override public boolean bindsServletRequestResponse() { return true; } @Override public RequestScopedInitializerProvider getRequestScopedInitializerProvider() { return new RequestScopedInitializerProvider() { @Override public RequestScopedInitializer get(final RequestContextProvider context) { return new RequestScopedInitializer() { @Override public void initialize(InjectionManager injectionManager) { ServiceLocator locator = getServiceLocator(injectionManager); locator.<Ref<HttpServletRequest>>getService(REQUEST_TYPE) .set(finalWrap(context.getHttpServletRequest())); locator.<Ref<HttpServletResponse>>getService(RESPONSE_TYPE) .set(finalWrap(context.getHttpServletResponse())); } }; } }; } private final class Binder extends AbstractBinder { @Override protected void configure() { bindFactory(HttpServletRequestReferencingFactory.class) .to(HttpServletRequestWrapper.class).in(RequestScoped.class); bindFactory(ReferencingFactory.<HttpServletRequestWrapper>referenceFactory()) .to(new GenericType<Ref<HttpServletRequestWrapper>>() { }).in(RequestScoped.class); bindFactory(HttpServletResponseFactory.class).to(HttpServletResponse.class); bindFactory(HttpServletResponseReferencingFactory.class) .to(HttpServletResponseWrapper.class).in(RequestScoped.class); bindFactory(ReferencingFactory.<HttpServletResponseWrapper>referenceFactory()) .to(new GenericType<Ref<HttpServletResponseWrapper>>() { }).in(RequestScoped.class); } } private static class HttpServletRequestDescriptor extends AbstractActiveDescriptor<HttpServletRequest> { static Set<Type> advertisedContracts = new HashSet<Type>() { { add(HttpServletRequest.class); } }; final ServiceLocator locator; volatile javax.inject.Provider<Ref<HttpServletRequestWrapper>> request; public HttpServletRequestDescriptor(final ServiceLocator locator) { super(advertisedContracts, PerLookup.class, null, new HashSet<Annotation>(), DescriptorType.CLASS, DescriptorVisibility.LOCAL, 0, null, null, null, null); this.locator = locator; } @Override public Class<?> getImplementationClass() { return HttpServletRequest.class; } @Override public Type getImplementationType() { return getImplementationClass(); } @Override public synchronized String getImplementation() { return HttpServletRequest.class.getName(); } @Override public HttpServletRequest create(ServiceHandle<?> serviceHandle) { if (request == null) { request = locator.getService(new TypeLiteral<Provider<Ref<HttpServletRequestWrapper>>>() { }.getType()); } boolean direct = false; if (serviceHandle instanceof ServiceHandleImpl) { final ServiceHandleImpl serviceHandleImpl = (ServiceHandleImpl) serviceHandle; final Class<? extends Annotation> scopeAnnotation = serviceHandleImpl.getOriginalRequest().getInjecteeDescriptor().getScopeAnnotation(); if (scopeAnnotation == RequestScoped.class || scopeAnnotation == null) { direct = true; } } return !direct ? new HttpServletRequestWrapper(new MyHttpServletRequestImpl() { @Override HttpServletRequest getHttpServletRequest() { return request.get().get(); } }) { @Override public ServletRequest getRequest() { return request.get().get(); } } : new HttpServletRequestWrapper(request.get().get()); } } private static class HttpServletResponseFactory implements Supplier<HttpServletResponse> { private final javax.inject.Provider<Ref<HttpServletResponseWrapper>> response; @Inject public HttpServletResponseFactory(javax.inject.Provider<Ref<HttpServletResponseWrapper>> response) { this.response = response; } @Override @PerLookup public HttpServletResponse get() { return new HttpServletResponseWrapper(new HttpServletResponse() { private HttpServletResponse getHttpServletResponse() { return response.get().get(); } @Override public void addCookie(Cookie cookie) { getHttpServletResponse().addCookie(cookie); } @Override public boolean containsHeader(String s) { return getHttpServletResponse().containsHeader(s); } @Override public String encodeURL(String s) { return getHttpServletResponse().encodeURL(s); } @Override public String encodeRedirectURL(String s) { return getHttpServletResponse().encodeRedirectURL(s); } @Override public String encodeUrl(String s) { return getHttpServletResponse().encodeUrl(s); } @Override public String encodeRedirectUrl(String s) { return getHttpServletResponse().encodeRedirectUrl(s); } @Override public void sendError(int i, String s) throws IOException { getHttpServletResponse().sendError(i, s); } @Override public void sendError(int i) throws IOException { getHttpServletResponse().sendError(i); } @Override public void sendRedirect(String s) throws IOException { getHttpServletResponse().sendRedirect(s); } @Override public void setDateHeader(String s, long l) { getHttpServletResponse().setDateHeader(s, l); } @Override public void addDateHeader(String s, long l) { getHttpServletResponse().addDateHeader(s, l); } @Override public void setHeader(String h, String v) { getHttpServletResponse().setHeader(h, v); } public Collection<String> getHeaderNames() { return getHttpServletResponse().getHeaderNames(); } public Collection<String> getHeaders(String s) { return getHttpServletResponse().getHeaders(s); } public String getHeader(String s) { return getHttpServletResponse().getHeader(s); } @Override public void addHeader(String h, String v) { getHttpServletResponse().addHeader(h, v); } @Override public void setIntHeader(String s, int i) { getHttpServletResponse().setIntHeader(s, i); } @Override public void addIntHeader(String s, int i) { getHttpServletResponse().addIntHeader(s, i); } @Override public void setStatus(int i) { getHttpServletResponse().setStatus(i); } @Override public int getStatus() { return getHttpServletResponse().getStatus(); } @Override public void setStatus(int i, String s) { getHttpServletResponse().setStatus(i, s); } @Override public String getCharacterEncoding() { return getHttpServletResponse().getCharacterEncoding(); } @Override public String getContentType() { return getHttpServletResponse().getContentType(); } @Override public ServletOutputStream getOutputStream() throws IOException { return getHttpServletResponse().getOutputStream(); } @Override public PrintWriter getWriter() throws IOException { return getHttpServletResponse().getWriter(); } @Override public void setCharacterEncoding(String s) { getHttpServletResponse().setCharacterEncoding(s); } @Override public void setContentLength(int i) { getHttpServletResponse().setContentLength(i); } @Override public void setContentType(String s) { getHttpServletResponse().setContentType(s); } @Override public void setBufferSize(int i) { getHttpServletResponse().setBufferSize(i); } @Override public int getBufferSize() { return getHttpServletResponse().getBufferSize(); } @Override public void flushBuffer() throws IOException { getHttpServletResponse().flushBuffer(); } @Override public void resetBuffer() { getHttpServletResponse().resetBuffer(); } @Override public boolean isCommitted() { return getHttpServletResponse().isCommitted(); } @Override public void reset() { getHttpServletResponse().reset(); } @Override public void setLocale(Locale locale) { getHttpServletResponse().setLocale(locale); } @Override public Locale getLocale() { return getHttpServletResponse().getLocale(); } } ) { @Override public ServletResponse getResponse() { return response.get().get(); } }; } } @SuppressWarnings("JavaDoc") private static class HttpServletRequestReferencingFactory extends ReferencingFactory<HttpServletRequestWrapper> { @Inject public HttpServletRequestReferencingFactory( final javax.inject.Provider<Ref<HttpServletRequestWrapper>> referenceFactory) { super(referenceFactory); } } @SuppressWarnings("JavaDoc") private static class HttpServletResponseReferencingFactory extends ReferencingFactory<HttpServletResponseWrapper> { @Inject public HttpServletResponseReferencingFactory( final javax.inject.Provider<Ref<HttpServletResponseWrapper>> referenceFactory) { super(referenceFactory); } } @Override public void configure(final ResourceConfig resourceConfig) throws ServletException { resourceConfig.register(new Binder()); } private HttpServletRequest finalWrap(final HttpServletRequest request) { return new RequestWrapper(request); } private HttpServletResponse finalWrap(final HttpServletResponse response) { return new ResponseWrapper(response); } private abstract static class MyHttpServletRequestImpl implements HttpServletRequest { @Override public String getAuthType() { return getHttpServletRequest().getAuthType(); } @Override public boolean authenticate(HttpServletResponse response) throws IOException, ServletException { return getHttpServletRequest().authenticate(response); } @Override public boolean isAsyncSupported() { return getHttpServletRequest().isAsyncSupported(); } @Override public boolean isAsyncStarted() { return getHttpServletRequest().isAsyncStarted(); } @Override public AsyncContext startAsync() throws IllegalStateException { return getHttpServletRequest().startAsync(); } @Override public AsyncContext startAsync(ServletRequest request, ServletResponse response) throws IllegalStateException { return getHttpServletRequest().startAsync(request, response); } abstract HttpServletRequest getHttpServletRequest(); @Override public Cookie[] getCookies() { return getHttpServletRequest().getCookies(); } @Override public long getDateHeader(String s) { return getHttpServletRequest().getDateHeader(s); } @Override public Part getPart(String s) throws ServletException, IOException { return getHttpServletRequest().getPart(s); } @Override public Collection<Part> getParts() throws ServletException, IOException { return getHttpServletRequest().getParts(); } @Override public String getHeader(String s) { return getHttpServletRequest().getHeader(s); } @Override public Enumeration getHeaders(String s) { return getHttpServletRequest().getHeaders(s); } @Override public Enumeration getHeaderNames() { return getHttpServletRequest().getHeaderNames(); } @Override public int getIntHeader(String s) { return getHttpServletRequest().getIntHeader(s); } @Override public String getMethod() { return getHttpServletRequest().getMethod(); } @Override public String getPathInfo() { return getHttpServletRequest().getPathInfo(); } @Override public String getPathTranslated() { return getHttpServletRequest().getPathTranslated(); } @Override public String getContextPath() { return getHttpServletRequest().getContextPath(); } @Override public String getQueryString() { return getHttpServletRequest().getQueryString(); } @Override public String getRemoteUser() { return getHttpServletRequest().getRemoteUser(); } @Override public boolean isUserInRole(String s) { return getHttpServletRequest().isUserInRole(s); } @Override public Principal getUserPrincipal() { return getHttpServletRequest().getUserPrincipal(); } @Override public String getRequestedSessionId() { return getHttpServletRequest().getRequestedSessionId(); } @Override public String getRequestURI() { return getHttpServletRequest().getRequestURI(); } @Override public StringBuffer getRequestURL() { return getHttpServletRequest().getRequestURL(); } @Override public String getServletPath() { return getHttpServletRequest().getServletPath(); } @Override public HttpSession getSession(boolean b) { return getHttpServletRequest().getSession(b); } @Override public HttpSession getSession() { return getHttpServletRequest().getSession(); } @Override public boolean isRequestedSessionIdValid() { return getHttpServletRequest().isRequestedSessionIdValid(); } @Override public boolean isRequestedSessionIdFromCookie() { return getHttpServletRequest().isRequestedSessionIdFromCookie(); } @Override public boolean isRequestedSessionIdFromURL() { return getHttpServletRequest().isRequestedSessionIdFromURL(); } @Override public boolean isRequestedSessionIdFromUrl() { return getHttpServletRequest().isRequestedSessionIdFromUrl(); } @Override public Object getAttribute(String s) { return getHttpServletRequest().getAttribute(s); } @Override public Enumeration getAttributeNames() { return getHttpServletRequest().getAttributeNames(); } @Override public String getCharacterEncoding() { return getHttpServletRequest().getCharacterEncoding(); } @Override public void setCharacterEncoding(String s) throws UnsupportedEncodingException { getHttpServletRequest().setCharacterEncoding(s); } @Override public int getContentLength() { return getHttpServletRequest().getContentLength(); } @Override public String getContentType() { return getHttpServletRequest().getContentType(); } @Override public ServletInputStream getInputStream() throws IOException { return getHttpServletRequest().getInputStream(); } @Override public String getParameter(String s) { return getHttpServletRequest().getParameter(s); } @Override public Enumeration getParameterNames() { return getHttpServletRequest().getParameterNames(); } @Override public String[] getParameterValues(String s) { return getHttpServletRequest().getParameterValues(s); } @Override public Map getParameterMap() { return getHttpServletRequest().getParameterMap(); } @Override public String getProtocol() { return getHttpServletRequest().getProtocol(); } @Override public String getScheme() { return getHttpServletRequest().getScheme(); } @Override public String getServerName() { return getHttpServletRequest().getServerName(); } @Override public int getServerPort() { return getHttpServletRequest().getServerPort(); } @Override public BufferedReader getReader() throws IOException { return getHttpServletRequest().getReader(); } @Override public String getRemoteAddr() { return getHttpServletRequest().getRemoteAddr(); } @Override public String getRemoteHost() { return getHttpServletRequest().getRemoteHost(); } @Override public void setAttribute(String s, Object o) { getHttpServletRequest().setAttribute(s, o); } @Override public void removeAttribute(String s) { getHttpServletRequest().removeAttribute(s); } @Override public Locale getLocale() { return getHttpServletRequest().getLocale(); } @Override public Enumeration getLocales() { return getHttpServletRequest().getLocales(); } @Override public boolean isSecure() { return getHttpServletRequest().isSecure(); } @Override public RequestDispatcher getRequestDispatcher(String s) { return getHttpServletRequest().getRequestDispatcher(s); } @Override public String getRealPath(String s) { return getHttpServletRequest().getRealPath(s); } @Override public int getRemotePort() { return getHttpServletRequest().getRemotePort(); } @Override public String getLocalName() { return getHttpServletRequest().getLocalName(); } @Override public String getLocalAddr() { return getHttpServletRequest().getLocalAddr(); } @Override public int getLocalPort() { return getHttpServletRequest().getLocalPort(); } @Override public DispatcherType getDispatcherType() { return getHttpServletRequest().getDispatcherType(); } @Override public AsyncContext getAsyncContext() { return getHttpServletRequest().getAsyncContext(); } @Override public ServletContext getServletContext() { return getHttpServletRequest().getServletContext(); } @Override public void logout() throws ServletException { getHttpServletRequest().logout(); } @Override public void login(String u, String p) throws ServletException { getHttpServletRequest().login(u, p); } } private static ServiceLocator getServiceLocator(InjectionManager injectionManager) { if (injectionManager instanceof ImmediateHk2InjectionManager) { return ((ImmediateHk2InjectionManager) injectionManager).getServiceLocator(); } else if (injectionManager instanceof DelayedHk2InjectionManager) { return ((DelayedHk2InjectionManager) injectionManager).getServiceLocator(); } else { throw new RuntimeException("Invalid InjectionManager"); } } }