/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.server.security; import com.google.common.base.Splitter; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableMap; import com.google.common.net.HttpHeaders; import io.airlift.http.client.HttpStatus; import io.airlift.log.Logger; import javax.annotation.Nonnull; import javax.inject.Inject; import javax.naming.NamingEnumeration; import javax.naming.NamingException; import javax.naming.directory.DirContext; import javax.naming.directory.InitialDirContext; import javax.naming.directory.SearchControls; import javax.naming.directory.SearchResult; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.InputStream; import java.security.Principal; import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutionException; import static com.facebook.presto.server.security.util.jndi.JndiUtils.getInitialDirContext; import static com.google.common.base.CharMatcher.JAVA_ISO_CONTROL; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.propagateIfInstanceOf; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; import static com.google.common.net.HttpHeaders.AUTHORIZATION; import static io.airlift.http.client.HttpStatus.BAD_REQUEST; import static io.airlift.http.client.HttpStatus.INTERNAL_SERVER_ERROR; import static io.airlift.http.client.HttpStatus.UNAUTHORIZED; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static javax.naming.Context.INITIAL_CONTEXT_FACTORY; import static javax.naming.Context.PROVIDER_URL; import static javax.naming.Context.SECURITY_AUTHENTICATION; import static javax.naming.Context.SECURITY_CREDENTIALS; import static javax.naming.Context.SECURITY_PRINCIPAL; public class LdapFilter implements Filter { private static final Logger log = Logger.get(LdapFilter.class); private static final String BASIC_AUTHENTICATION_PREFIX = "Basic "; private static final String LDAP_CONTEXT_FACTORY = "com.sun.jndi.ldap.LdapCtxFactory"; private final String ldapUrl; private final String userBindSearchPattern; private final Optional<String> groupAuthorizationSearchPattern; private final Optional<String> userBaseDistinguishedName; private final Map<String, String> basicEnvironment; private final LoadingCache<Credentials, Principal> authenticationCache; @Inject public LdapFilter(LdapConfig serverConfig) { this.ldapUrl = requireNonNull(serverConfig.getLdapUrl(), "ldapUrl is null"); this.userBindSearchPattern = requireNonNull(serverConfig.getUserBindSearchPattern(), "userBindSearchPattern is null"); this.groupAuthorizationSearchPattern = Optional.ofNullable(serverConfig.getGroupAuthorizationSearchPattern()); this.userBaseDistinguishedName = Optional.ofNullable(serverConfig.getUserBaseDistinguishedName()); if (groupAuthorizationSearchPattern.isPresent()) { checkState(userBaseDistinguishedName.isPresent(), "Base distinguished name (DN) for user is null"); } Map<String, String> environment = ImmutableMap.<String, String>builder() .put(INITIAL_CONTEXT_FACTORY, LDAP_CONTEXT_FACTORY) .put(PROVIDER_URL, ldapUrl) .build(); checkEnvironment(environment); this.basicEnvironment = environment; this.authenticationCache = CacheBuilder.newBuilder() .expireAfterWrite(serverConfig.getLdapCacheTtl().toMillis(), MILLISECONDS) .build(new CacheLoader<Credentials, Principal>() { @Override public Principal load(@Nonnull Credentials key) throws AuthenticationException { return authenticate(key.getUser(), key.getPassword()); } }); } private static void checkEnvironment(Map<String, String> environment) { try { closeContext(createDirContext(environment)); } catch (NamingException e) { throw Throwables.propagate(e); } } private static InitialDirContext createDirContext(Map<String, String> environment) throws NamingException { return getInitialDirContext(environment); } private static void closeContext(InitialDirContext context) { if (context != null) { try { context.close(); } catch (NamingException ignore) { } } } @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain nextFilter) throws IOException, ServletException { // skip auth for http if (!servletRequest.isSecure()) { nextFilter.doFilter(servletRequest, servletResponse); return; } HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; try { String header = request.getHeader(AUTHORIZATION); Credentials credentials = getCredentials(header); Principal principal = getPrincipal(credentials); // ldap authentication ok, continue nextFilter.doFilter(new HttpServletRequestWrapper(request) { @Override public Principal getUserPrincipal() { return principal; } }, servletResponse); } catch (AuthenticationException e) { log.debug(e, "LDAP authentication failed"); processAuthenticationException(e, request, response); } } private Principal getPrincipal(Credentials credentials) throws AuthenticationException { try { return authenticationCache.get(credentials); } catch (ExecutionException e) { Throwable cause = e.getCause(); propagateIfInstanceOf(cause, AuthenticationException.class); throw Throwables.propagate(cause); } } private static void processAuthenticationException(AuthenticationException e, HttpServletRequest request, HttpServletResponse response) throws IOException { if (e.getStatus() == UNAUTHORIZED) { // If we send the challenge without consuming the body of the request, // the Jetty server will close the connection after sending the response. // The client interprets this as a failed request and does not resend // the request with the authentication header. // We can avoid this behavior in the Jetty client by reading and discarding // the entire body of the unauthenticated request before sending the response. skipRequestBody(request); response.setHeader(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"presto\""); } response.sendError(e.getStatus().code(), e.getMessage()); } private static void skipRequestBody(HttpServletRequest request) throws IOException { try (InputStream inputStream = request.getInputStream()) { copy(inputStream, nullOutputStream()); } } private static Credentials getCredentials(String header) throws AuthenticationException { if (header == null) { throw new AuthenticationException(UNAUTHORIZED, "Unauthorized"); } if (!header.startsWith(BASIC_AUTHENTICATION_PREFIX)) { throw new AuthenticationException(BAD_REQUEST, "Basic authentication is expected"); } String base64EncodedCredentials = header.substring(BASIC_AUTHENTICATION_PREFIX.length()); String credentials = decodeCredentials(base64EncodedCredentials); List<String> parts = Splitter.on(':').limit(2).splitToList(credentials); if (parts.size() != 2 || parts.stream().anyMatch(String::isEmpty)) { throw new AuthenticationException(BAD_REQUEST, "Malformed decoded credentials"); } return new Credentials(parts.get(0), parts.get(1)); } private static String decodeCredentials(String base64EncodedCredentials) throws AuthenticationException { byte[] bytes; try { bytes = Base64.getDecoder().decode(base64EncodedCredentials); } catch (IllegalArgumentException e) { throw new AuthenticationException(BAD_REQUEST, "Invalid base64 encoded credentials"); } return new String(bytes, UTF_8); } private Principal authenticate(String user, String password) throws AuthenticationException { Map<String, String> environment = createEnvironment(user, password); InitialDirContext context = null; try { context = createDirContext(environment); checkForGroupMembership(user, context); log.debug("Authentication successful for user %s", user); return new LdapPrincipal(user); } catch (javax.naming.AuthenticationException e) { String formattedAsciiMessage = format("Invalid credentials: %s", JAVA_ISO_CONTROL.removeFrom(e.getMessage())); log.debug("Authentication failed for user [%s]. %s", user, e.getMessage()); throw new AuthenticationException(UNAUTHORIZED, formattedAsciiMessage, e); } catch (NamingException e) { log.debug("Authentication failed", e.getMessage()); throw new AuthenticationException(INTERNAL_SERVER_ERROR, "Authentication failed", e); } finally { closeContext(context); } } private Map<String, String> createEnvironment(String user, String password) { return ImmutableMap.<String, String>builder() .putAll(basicEnvironment) .put(SECURITY_AUTHENTICATION, "simple") .put(SECURITY_PRINCIPAL, createPrincipal(user)) .put(SECURITY_CREDENTIALS, password) .build(); } private String createPrincipal(String user) { return replaceUser(userBindSearchPattern, user); } private String replaceUser(String pattern, String user) { return pattern.replaceAll("\\$\\{USER\\}", user); } private void checkForGroupMembership(String user, DirContext context) throws AuthenticationException { if (!groupAuthorizationSearchPattern.isPresent()) { return; } String searchFilter = replaceUser(groupAuthorizationSearchPattern.get(), user); SearchControls searchControls = new SearchControls(); searchControls.setSearchScope(SearchControls.SUBTREE_SCOPE); boolean authorized; NamingEnumeration<SearchResult> search = null; try { search = context.search(userBaseDistinguishedName.get(), searchFilter, searchControls); authorized = search.hasMoreElements(); } catch (NamingException e) { log.debug("Authentication failed", e.getMessage()); throw new AuthenticationException(INTERNAL_SERVER_ERROR, "Authentication failed", e); } finally { if (search != null) { try { search.close(); } catch (NamingException ignore) { } } } if (!authorized) { String message = format("Unauthorized user: User %s not a member of the authorized group", user); log.debug("Authorization failed for user. " + message); throw new AuthenticationException(UNAUTHORIZED, message); } log.debug("Authorization succeeded for user %s", user); } @Override public void init(FilterConfig filterConfig) {} @Override public void destroy() {} private static final class LdapPrincipal implements Principal { private final String name; private LdapPrincipal(String name) { this.name = requireNonNull(name, "name is null"); } @Override public String getName() { return name; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } LdapPrincipal that = (LdapPrincipal) o; return Objects.equals(name, that.name); } @Override public int hashCode() { return Objects.hash(name); } @Override public String toString() { return name; } } private static class Credentials { private final String user; private final String password; private Credentials(String user, String password) { this.user = requireNonNull(user); this.password = requireNonNull(password); } public String getUser() { return user; } public String getPassword() { return password; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } Credentials that = (Credentials) o; return Objects.equals(this.user, that.user) && Objects.equals(this.password, that.password); } @Override public int hashCode() { return Objects.hash(user, password); } @Override public String toString() { return toStringHelper(this) .add("user", user) .add("password", password) .toString(); } } private static class AuthenticationException extends Exception { private final HttpStatus status; private AuthenticationException(HttpStatus status, String message) { this(status, message, null); } private AuthenticationException(HttpStatus status, String message, Throwable cause) { super(message, cause); requireNonNull(message, "message is null"); this.status = requireNonNull(status, "status is null"); } public HttpStatus getStatus() { return status; } } }