/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ambari.logsearch.web.filters;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.SignatureException;
import org.apache.ambari.logsearch.conf.AuthPropsConfig;
import org.apache.ambari.logsearch.web.model.JWTAuthenticationToken;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
public class LogsearchJWTFilter extends AbstractAuthenticationProcessingFilter {
private static final Logger LOG = LoggerFactory.getLogger(LogsearchJWTFilter.class);
private static final String PEM_HEADER = "-----BEGIN CERTIFICATE-----\n";
private static final String PEM_FOOTER = "\n-----END CERTIFICATE-----";
private AuthPropsConfig authPropsConfig;
public LogsearchJWTFilter(RequestMatcher requestMatcher, AuthPropsConfig authPropsConfig) {
super(new NegatedRequestMatcher(requestMatcher));
this.authPropsConfig = authPropsConfig;
}
@Override
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
if (StringUtils.isEmpty(authPropsConfig.getProvidedUrl())) {
throw new BadCredentialsException("Authentication provider URL must not be null or empty.");
}
if (StringUtils.isEmpty(authPropsConfig.getPublicKey())) {
throw new BadCredentialsException("Public key for signature validation must be provisioned.");
}
try {
Claims claims = Jwts
.parser()
.setSigningKey(parseRSAPublicKey(authPropsConfig.getPublicKey()))
.parseClaimsJws(getJWTFromCookie(request))
.getBody();
String userName = claims.getSubject();
LOG.info("USERNAME: " + userName);
LOG.info("URL = " + request.getRequestURL());
if (StringUtils.isNotEmpty(claims.getAudience()) && !authPropsConfig.getAudiences().contains(claims.getAudience())) {
throw new IllegalArgumentException(String.format("Audience validation failed. (Not found: %s)", claims.getAudience()));
}
Authentication authentication = new JWTAuthenticationToken(userName, authPropsConfig.getPublicKey());
authentication.setAuthenticated(true);
SecurityContextHolder.getContext().setAuthentication(authentication);
return authentication;
} catch (ExpiredJwtException | MalformedJwtException | SignatureException | IllegalArgumentException e) {
LOG.info("URL = " + request.getRequestURL());
LOG.warn("Error during JWT authentication: ", e.getMessage());
throw new BadCredentialsException(e.getMessage(), e);
}
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (!authPropsConfig.isAuthJwtEnabled() || isAuthenticated(authentication)) {
chain.doFilter(req, res);
return;
}
super.doFilter(req, res, chain);
}
@Override
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authResult) throws IOException, ServletException {
super.successfulAuthentication(request, response, chain, authResult);
response.sendRedirect(request.getRequestURL().toString() + getOriginalQueryString(request));
}
@Override
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException {
super.unsuccessfulAuthentication(request, response, failed);
String loginUrl = constructLoginURL(request);
response.sendRedirect(loginUrl);
}
private String getJWTFromCookie(HttpServletRequest req) {
String serializedJWT = null;
Cookie[] cookies = req.getCookies();
if (cookies != null) {
for (Cookie cookie : cookies) {
if (authPropsConfig.getCookieName().equals(cookie.getName())) {
LOG.info(authPropsConfig.getCookieName() + " cookie has been found and is being processed");
serializedJWT = cookie.getValue();
break;
}
}
}
return serializedJWT;
}
private RSAPublicKey parseRSAPublicKey(String pem) throws ServletException {
String fullPem = PEM_HEADER + pem + PEM_FOOTER;
try {
CertificateFactory fact = CertificateFactory.getInstance("X.509");
ByteArrayInputStream is = new ByteArrayInputStream(fullPem.getBytes("UTF8"));
X509Certificate cer = (X509Certificate) fact.generateCertificate(is);
return (RSAPublicKey) cer.getPublicKey();
} catch (CertificateException ce) {
String message;
if (pem.startsWith(PEM_HEADER)) {
message = "CertificateException - be sure not to include PEM header "
+ "and footer in the PEM configuration element.";
} else {
message = "CertificateException - PEM may be corrupt";
}
throw new ServletException(message, ce);
} catch (UnsupportedEncodingException uee) {
throw new ServletException(uee);
}
}
private String constructLoginURL(HttpServletRequest request) {
String delimiter = "?";
if (authPropsConfig.getProvidedUrl().contains("?")) {
delimiter = "&";
}
return authPropsConfig.getProvidedUrl() + delimiter
+ authPropsConfig.getOriginalUrlQueryParam() + "="
+ request.getRequestURL().toString() + getOriginalQueryString(request);
}
private String getOriginalQueryString(HttpServletRequest request) {
String originalQueryString = request.getQueryString();
return (originalQueryString == null) ? "" : "?" + originalQueryString;
}
private boolean isAuthenticated(Authentication authentication) {
return authentication != null && !(authentication instanceof AnonymousAuthenticationToken) && authentication.isAuthenticated();
}
}