/*******************************************************************************
 * Copyright (c) 2012-2016 Codenvy, S.A.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *   Codenvy, S.A. - initial API and implementation
 *******************************************************************************/
package org.everrest.core.servlet;

import org.everrest.core.ExtHttpHeaders;
import org.everrest.core.impl.ContainerRequest;
import org.everrest.core.impl.InputHeadersMap;
import org.everrest.core.impl.MultivaluedMapImpl;
import org.slf4j.LoggerFactory;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.SecurityContext;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.Principal;
import java.util.Enumeration;

/**
 * {@link ContainerRequest} built from an {@link HttpServletRequest}.
 * <p>
 * The base and request URIs are assembled from the servlet request. When the request carries a valid
 * forwarded-host header ({@link ExtHttpHeaders#FORWARDED_HOST}), the host and port of both URIs are taken
 * from that header instead of from {@link HttpServletRequest#getServerName()} and
 * {@link HttpServletRequest#getServerPort()}.
 *
 * @author andrew00x
 */
public class ServletContainerRequest extends ContainerRequest {
    private static final org.slf4j.Logger LOG = LoggerFactory.getLogger(ServletContainerRequest.class);

    /** Largest value accepted for a forwarded port. */
    private static final int MAX_PORT = 65535;

    /**
     * Creates a container request from a servlet request.
     *
     * @param req
     *         source servlet request
     * @return new ServletContainerRequest whose base URI is {@code scheme://host[:port]} followed by the
     *         context path and servlet path, and whose request URI is the same prefix followed by the
     *         request URI and, if present, the query string
     */
    public static ServletContainerRequest create(final HttpServletRequest req) {
        // The common URI prefix for both baseUri and requestUri
        final String commonUriPrefix = buildCommonUriPrefix(req);

        // The Base URI - up to the servlet path
        final URI baseUri = URI.create(commonUriPrefix + req.getContextPath() + req.getServletPath());

        // The RequestURI - everything in the URL
        final StringBuilder requestUriBuilder = new StringBuilder(commonUriPrefix);
        requestUriBuilder.append(req.getRequestURI());
        final String queryString = req.getQueryString();
        if (queryString != null) {
            requestUriBuilder.append('?').append(queryString);
        }
        final URI requestUri = URI.create(requestUriBuilder.toString());

        return new ServletContainerRequest(getMethod(req),
                                           requestUri,
                                           baseUri,
                                           getEntityStream(req),
                                           getHeaders(req),
                                           getSecurityContext(req));
    }

    private ServletContainerRequest(String method,
                                    URI requestUri,
                                    URI baseUri,
                                    InputStream entityStream,
                                    MultivaluedMap<String, String> httpHeaders,
                                    SecurityContext securityContext) {
        super(method, requestUri, baseUri, entityStream, httpHeaders, securityContext);
    }

    /**
     * Builds the {@code scheme://host[:port]} prefix shared by the base URI and the request URI.
     * Host and port come from the forwarded-host header when it is present and valid, otherwise from the
     * servlet request itself. The port is omitted when it is negative or is the default one for the scheme.
     */
    private static String buildCommonUriPrefix(HttpServletRequest req) {
        // If the URL is forwarded, obtain the forwarding information
        final URL forwardedUrl = getForwardedUrl(req);
        final String host;
        int port;
        if (forwardedUrl == null) {
            host = req.getServerName();
            port = req.getServerPort();
        } else {
            host = forwardedUrl.getHost();
            port = forwardedUrl.getPort();
            if (port < 0) {
                // No explicit port in the forwarded header: fall back to the default port of the protocol.
                port = forwardedUrl.getDefaultPort();
            }
            LOG.debug("Assuming forwarded URL: {}", forwardedUrl);
        }

        final String scheme = getScheme(req);
        final StringBuilder prefix = new StringBuilder();
        prefix.append(scheme).append("://").append(host);
        if (port >= 0 && !isDefaultPort(scheme, port)) {
            prefix.append(':').append(port);
        }
        return prefix.toString();
    }

    /** Tells whether {@code port} is the well-known port of {@code scheme} (80 for http, 443 for https). */
    private static boolean isDefaultPort(String scheme, int port) {
        return (port == 80 && "http".equals(scheme)) || (port == 443 && "https".equals(scheme));
    }

    /**
     * Extract HTTP method name from servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return HTTP method name
     * @see HttpServletRequest#getMethod()
     */
    private static String getMethod(HttpServletRequest servletRequest) {
        return servletRequest.getMethod();
    }

    /**
     * Extract scheme name from servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return scheme name
     * @see HttpServletRequest#getScheme()
     */
    private static String getScheme(HttpServletRequest servletRequest) {
        return servletRequest.getScheme();
    }

    /**
     * Get the URL that is forwarded using the standard X-Forwarded-Host header.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return The URL of the forwarded host. If the header is missing or invalid, null is returned.
     */
    private static URL getForwardedUrl(HttpServletRequest servletRequest) {
        final String forwardedHostAndPort = servletRequest.getHeader(ExtHttpHeaders.FORWARDED_HOST);
        if (forwardedHostAndPort == null || forwardedHostAndPort.isEmpty()) {
            return null;
        }
        final URL url = parseForwardedHostHeader(forwardedHostAndPort, servletRequest);
        if (url == null) {
            LOG.warn("Ignoring invalid {}: {}", ExtHttpHeaders.FORWARDED_HOST, forwardedHostAndPort);
        }
        return url;
    }

    /**
     * Parse according to IETF standard for Host field: http://tools.ietf.org/html/rfc7230#section-5.4
     * <p>
     * Accepted forms are {@code host}, {@code host:port}, {@code [ipv6-literal]} and
     * {@code [ipv6-literal]:port}. An empty port (for example {@code host:}) is treated as absent.
     *
     * @param forwardedHostAndPort
     *         non-empty value of the forwarded-host header
     * @param servletRequest
     *         request that supplies the scheme of the resulting URL
     * @return URL made of the request scheme and the forwarded host and port, or null if the value is
     *         malformed (empty host, unbalanced brackets, non-numeric or out-of-range port) or a URL cannot
     *         be built from it
     */
    private static URL parseForwardedHostHeader(String forwardedHostAndPort, HttpServletRequest servletRequest) {
        final String fwdHost;
        final String portPart;
        if (forwardedHostAndPort.startsWith("[")) {
            // IP-literal: colons inside the brackets belong to the address, not to the port separator.
            final int closingBracket = forwardedHostAndPort.indexOf(']');
            if (closingBracket < 0) {
                return null;
            }
            fwdHost = forwardedHostAndPort.substring(0, closingBracket + 1);
            final String remainder = forwardedHostAndPort.substring(closingBracket + 1);
            if (remainder.isEmpty()) {
                portPart = "";
            } else if (remainder.charAt(0) == ':') {
                portPart = remainder.substring(1);
            } else {
                return null;
            }
        } else {
            final int colon = forwardedHostAndPort.indexOf(':');
            if (colon < 0) {
                fwdHost = forwardedHostAndPort;
                portPart = "";
            } else {
                fwdHost = forwardedHostAndPort.substring(0, colon);
                // Any further colon ends up in portPart and is rejected by the digit check below.
                portPart = forwardedHostAndPort.substring(colon + 1);
            }
        }
        if (fwdHost.isEmpty()) {
            return null;
        }

        int fwdPort = -1;
        if (!portPart.isEmpty()) {
            // port = *DIGIT: reject signs and anything else Integer.parseInt would otherwise tolerate.
            for (int i = 0; i < portPart.length(); i++) {
                final char c = portPart.charAt(i);
                if (c < '0' || c > '9') {
                    return null;
                }
            }
            try {
                fwdPort = Integer.parseInt(portPart);
            } catch (NumberFormatException e) {
                // Only digits remain at this point, so this means the number does not fit into an int.
                return null;
            }
            if (fwdPort > MAX_PORT) {
                return null;
            }
        }

        final String scheme = getScheme(servletRequest);
        try {
            return new URI(scheme, null, fwdHost, fwdPort, null, null, null).toURL();
        } catch (URISyntaxException | MalformedURLException e) {
            LOG.debug(e.getLocalizedMessage(), e);
        }
        return null;
    }

    /**
     * Get HTTP headers from {@link HttpServletRequest} .
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return request headers: every value of every header reported by the servlet request, wrapped in an
     *         {@link InputHeadersMap}
     */
    private static MultivaluedMap<String, String> getHeaders(HttpServletRequest servletRequest) {
        final MultivaluedMap<String, String> headers = new MultivaluedMapImpl();
        final Enumeration<String> headerNames = servletRequest.getHeaderNames();
        while (headerNames.hasMoreElements()) {
            final String name = headerNames.nextElement();
            final Enumeration<String> values = servletRequest.getHeaders(name);
            while (values.hasMoreElements()) {
                headers.add(name, values.nextElement());
            }
        }
        return new InputHeadersMap(headers);
    }

    /**
     * Get input stream from {@link HttpServletRequest} .
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return request stream as returned by {@link HttpServletRequest#getInputStream()}
     * @throws RuntimeException
     *         wrapping the {@link IOException} if the stream cannot be obtained
     */
    private static InputStream getEntityStream(HttpServletRequest servletRequest) {
        try {
            return servletRequest.getInputStream();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a {@link SecurityContext} whose every method delegates to the given servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return security context backed by the servlet request
     */
    private static SecurityContext getSecurityContext(final HttpServletRequest servletRequest) {
        return new SecurityContext() {
            @Override
            public Principal getUserPrincipal() {
                return servletRequest.getUserPrincipal();
            }

            @Override
            public boolean isUserInRole(String role) {
                return servletRequest.isUserInRole(role);
            }

            @Override
            public boolean isSecure() {
                return servletRequest.isSecure();
            }

            @Override
            public String getAuthenticationScheme() {
                return servletRequest.getAuthType();
            }
        };
    }
}