/* * Copyright 2015, The OpenNMS Group * * Licensed under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. You may obtain * a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.opennms.newts.rest; import static com.google.common.base.Preconditions.checkNotNull; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.codec.binary.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Optional; import com.google.common.base.Throwables; public class HttpBasicAuthenticationFilter implements Filter { private final static Logger LOG = LoggerFactory.getLogger(HttpBasicAuthenticationFilter.class); private final static String REALM = "Newts"; private final NewtsConfig m_config; public HttpBasicAuthenticationFilter(NewtsConfig config) { m_config = checkNotNull(config, "config argument"); } @Override public void init(FilterConfig filterConfig) throws ServletException { LOG.info("HTTP Basic Auth servlet filter initialized"); } @Override public void destroy() { LOG.info("Shutting down HTTP Basic Auth servlet filter"); } @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException { LOG.trace("doFilter()"); HttpServletRequest request = (HttpServletRequest)servletRequest; if (enabled() && !isCorsPreflight(request)) { LOG.trace("Authentication is enabled"); HttpServletResponse response = (HttpServletResponse)servletResponse; Optional<String> authHeader = getAuthorizationHeader(request); if (authHeader.isPresent()) { Credentials credentials = Credentials.fromHeader(authHeader.get()); if (!isAuthorized(credentials)) { LOG.trace("Credentials do NOT match; Authorizationi failed"); sendUnauthorized(response); return; // Stop processing filters on failed authentication } LOG.trace("User {} is authorized", credentials.getUser()); } else { LOG.trace("Missing Authorization HTTP header; Authorization failed"); sendUnauthorized(response); return; // Stop processing filters on failed authentication } } else { LOG.trace("Authentication is NOT enabled (skipping...)"); } chain.doFilter(servletRequest, servletResponse); } private void sendUnauthorized(HttpServletResponse response) throws IOException { sendUnauthorized(response, "Unauthorized"); } private void sendUnauthorized(HttpServletResponse response, String msg) throws IOException { response.setHeader("WWW-Authenticate", "Basic realm=\"" + REALM + "\""); response.sendError(401, msg); } private boolean isAuthorized(Credentials credentials) { Map<String, String> passwords = m_config.getAuthenticationConfig().getCredentials(); String user = credentials.getUser(), pass = credentials.getPass(); if (passwords.containsKey(user) && (passwords.get(user) != null)) { return passwords.get(user).equals(pass); } return false; } private boolean isCorsPreflight(HttpServletRequest request) { return request.getMethod().equals("OPTIONS") && (request.getHeader("Access-Control-Request-Method") != null); } private boolean enabled() { return m_config.getAuthenticationConfig().isEnabled(); } private static Optional<String> getAuthorizationHeader(HttpServletRequest request) { String v = trim(request.getHeader("Authorization")); return (v != null) ? Optional.of(v) : Optional.<String>absent(); } /** Trim string if not null. */ private static String trim(String s) { return (s != null) ? s.trim() : s; } static class Credentials { private static final Pattern s_headerPattern = Pattern.compile("Basic (?<token>.+)", Pattern.CASE_INSENSITIVE); private static final Pattern s_credsPattern = Pattern.compile("(?<user>.+):(?<pass>.+)"); private final String m_user; private final String m_pass; Credentials(String user, String pass) { m_user = checkNotNull(user, "user argument"); m_pass = checkNotNull(pass, "pass argument"); } String getUser() { return m_user; } String getPass() { return m_pass; } /** Creates a {@link Credentials} instance from an HTTP basic authentication header value. */ static Credentials fromHeader(String headerValue) { String encoded, decoded; Matcher matcher; matcher = s_headerPattern.matcher(headerValue); if (matcher.matches()) { encoded = matcher.group("token"); } else { throw new IllegalArgumentException("malformed credentials header"); } try { decoded = new String(Base64.decodeBase64(encoded.getBytes("UTF-8")), "UTF-8"); } catch (UnsupportedEncodingException e) { throw Throwables.propagate(e); } matcher = s_credsPattern.matcher(decoded); if (matcher.matches()) { return new Credentials(matcher.group("user"), matcher.group("pass")); } else { throw new IllegalArgumentException("malformed credentials header"); } } } }