/*
* oxAuth is available under the MIT License (2008). See http://opensource.org/licenses/MIT for full text.
*
* Copyright (c) 2014, Gluu
*/
package org.xdi.oxauth.auth;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.List;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.xdi.model.security.Identity;
import org.xdi.oxauth.model.authorize.AuthorizeRequestParam;
import org.xdi.oxauth.model.common.AuthenticationMethod;
import org.xdi.oxauth.model.common.Prompt;
import org.xdi.oxauth.model.common.SessionIdState;
import org.xdi.oxauth.model.common.SessionState;
import org.xdi.oxauth.model.configuration.AppConfiguration;
import org.xdi.oxauth.model.error.ErrorResponseFactory;
import org.xdi.oxauth.model.exception.InvalidJwtException;
import org.xdi.oxauth.model.registration.Client;
import org.xdi.oxauth.model.token.ClientAssertion;
import org.xdi.oxauth.model.token.ClientAssertionType;
import org.xdi.oxauth.model.token.TokenErrorResponseType;
import org.xdi.oxauth.model.util.Util;
import org.xdi.oxauth.service.ClientFilterService;
import org.xdi.oxauth.service.ClientService;
import org.xdi.oxauth.service.SessionStateService;
import org.xdi.oxauth.util.ServerUtil;
import org.xdi.util.StringHelper;
/**
 * Servlet filter that authenticates requests to the oxAuth authorize, token and userinfo endpoints.
 * <p>
 * Token endpoint requests are authenticated with a JWT client assertion, HTTP Basic credentials, or
 * {@code client_id}/{@code client_secret} request parameters (optionally falling back to the configured client
 * authentication filters). For the other endpoints, requests carrying an {@code Authorization} header are
 * handled as follows:
 * <ul>
 * <li>Bearer requests are passed through.</li>
 * <li>Basic requests are authenticated.</li>
 * <li>Any other scheme is rejected with 401.</li>
 * </ul>
 * Requests without that header are authenticated through an AUTHENTICATED session (from the {@code session_state}
 * parameter or cookie) unless {@code prompt=login} is requested. Otherwise they are passed through unchanged.
 *
 * @author Javier Rojas Blum
 * @version March 4, 2016
 */
@WebFilter(asyncSupported = true, urlPatterns = {
        "/seam/resource/restv1/oxauth/authorize", "/seam/resource/restv1/oxauth/token",
        "/seam/resource/restv1/oxauth/userinfo"
}, displayName = "oxAuth")
public class AuthenticationFilter implements Filter {

    /** Realm advertised in {@code WWW-Authenticate} challenges when no realm has been set explicitly. */
    public static final String REALM = "oxAuth";

    private static final String AUTHORIZATION_HEADER = "Authorization";
    private static final String BASIC_PREFIX = "Basic ";
    private static final String BEARER_PREFIX = "Bearer ";

    @Inject
    private Logger log;

    @Inject
    private Authenticator authenticator;

    @Inject
    private SessionStateService sessionStateService;

    @Inject
    private ClientService clientService;

    @Inject
    private ClientFilterService clientFilterService;

    @Inject
    private ErrorResponseFactory errorResponseFactory;

    @Inject
    private AppConfiguration appConfiguration;

    @Inject
    private Identity identity;

    /** Explicitly configured realm; {@code null} means {@link #REALM} is used. */
    private String realm;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialize: all collaborators are injected.
    }

    /**
     * Dispatches the request to the authentication strategy that matches its target endpoint and the
     * credentials it carries. Exceptions raised while filtering are logged and not propagated to the container.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, final FilterChain filterChain)
            throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
        final HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;

        try {
            final String requestUrl = httpRequest.getRequestURL().toString();
            log.trace("Get request to: '{}'", requestUrl);

            final String authorizationHeader = httpRequest.getHeader(AUTHORIZATION_HEADER);

            if (requestUrl.endsWith("/token")
                    && ServerUtil.isSameRequestPath(requestUrl, appConfiguration.getTokenEndpoint())) {
                log.debug("Starting token endpoint authentication");

                if (httpRequest.getParameter("client_assertion") != null
                        && httpRequest.getParameter("client_assertion_type") != null) {
                    log.debug("Starting JWT token endpoint authentication");
                    processJwtAuth(httpRequest, httpResponse, filterChain);
                } else if (authorizationHeader != null && authorizationHeader.startsWith(BASIC_PREFIX)) {
                    log.debug("Starting Basic Auth token endpoint authentication");
                    processBasicAuth(clientService, errorResponseFactory, httpRequest, httpResponse, filterChain);
                } else {
                    log.debug("Starting POST Auth token endpoint authentication");
                    processPostAuth(clientService, clientFilterService, errorResponseFactory, httpRequest,
                            httpResponse, filterChain);
                }
            } else if (authorizationHeader != null) {
                if (authorizationHeader.startsWith(BEARER_PREFIX)) {
                    processBearerAuth(httpRequest, httpResponse, filterChain);
                } else if (authorizationHeader.startsWith(BASIC_PREFIX)) {
                    processBasicAuth(clientService, errorResponseFactory, httpRequest, httpResponse, filterChain);
                } else {
                    httpResponse.addHeader("WWW-Authenticate", "Basic realm=\"" + getRealm() + "\"");
                    httpResponse.sendError(401, "Not authorized");
                }
            } else {
                processSessionStateOrContinue(httpRequest, httpResponse, filterChain);
            }
        } catch (Exception ex) {
            // Deliberately best-effort: failures are logged, not rethrown to the container.
            log.error(ex.getMessage(), ex);
        }
    }

    /**
     * Authenticates the request through its session when that session is AUTHENTICATED and {@code prompt=login}
     * was not requested. Otherwise the request continues down the chain unchanged.
     */
    private void processSessionStateOrContinue(HttpServletRequest httpRequest, HttpServletResponse httpResponse,
            FilterChain filterChain) throws IOException, ServletException {
        String sessionState = httpRequest.getParameter(AuthorizeRequestParam.SESSION_STATE);
        List<Prompt> prompts = Prompt.fromString(httpRequest.getParameter(AuthorizeRequestParam.PROMPT), " ");

        if (StringUtils.isBlank(sessionState)) {
            // OXAUTH-297: fall back to the session_state stored in the cookie
            sessionState = sessionStateService.getSessionStateFromCookie(httpRequest);
        }

        SessionState sessionStateObject = null;
        if (StringUtils.isNotBlank(sessionState)) {
            sessionStateObject = sessionStateService.getSessionState(sessionState);
        }

        if (sessionStateObject != null && SessionIdState.AUTHENTICATED == sessionStateObject.getState()
                && !prompts.contains(Prompt.LOGIN)) {
            processSessionAuth(errorResponseFactory, sessionState, httpRequest, httpResponse, filterChain);
        } else {
            filterChain.doFilter(httpRequest, httpResponse);
        }
    }

    /**
     * Authenticates by session state and, on success, continues the chain. Sends a 401 {@code invalid_client}
     * response (unless the response is already committed) when authentication fails or the chain throws.
     */
    private void processSessionAuth(ErrorResponseFactory errorResponseFactory, String sessionState,
            HttpServletRequest httpRequest, HttpServletResponse httpResponse, FilterChain filterChain)
            throws IOException, ServletException {
        boolean requireAuth = !authenticator.authenticateBySessionState(sessionState);
        log.trace("Process Session Auth, sessionState = {}, requireAuth = {}", sessionState, requireAuth);

        if (!requireAuth) {
            try {
                filterChain.doFilter(httpRequest, httpResponse);
            } catch (Exception ex) {
                log.error("Failed to process session authentication", ex);
                requireAuth = true;
            }
        }

        if (requireAuth) {
            sendError(httpResponse);
        }
    }

    /**
     * Authenticates the request from an {@code Authorization: Basic base64(username:password)} header.
     * <p>
     * On the token endpoint the user name must identify a client registered with the
     * {@code client_secret_basic} method. Continues the chain on success. Otherwise sends a 401
     * {@code invalid_client} response.
     */
    private void processBasicAuth(ClientService clientService, ErrorResponseFactory errorResponseFactory,
            HttpServletRequest servletRequest, HttpServletResponse servletResponse, FilterChain filterChain) {
        boolean requireAuth = true;

        try {
            String header = servletRequest.getHeader(AUTHORIZATION_HEADER);
            if (header != null && header.startsWith(BASIC_PREFIX)) {
                String base64Token = header.substring(BASIC_PREFIX.length());
                String token = new String(Base64.decodeBase64(base64Token), Util.UTF8_STRING_ENCODING);

                // Credentials without a ':' separator yield an empty user name and password
                String username = "";
                String password = "";
                int delim = token.indexOf(':');
                if (delim != -1) {
                    username = token.substring(0, delim);
                    password = token.substring(delim + 1);
                }

                // Only authenticate if username doesn't match Identity.username or the identity isn't logged in
                requireAuth = !StringHelper.equals(username, identity.getCredentials().getUsername())
                        || !identity.isLoggedIn();

                if (requireAuth) {
                    boolean methodAllowed = true;
                    if (servletRequest.getRequestURI().endsWith("/token")) {
                        // The token endpoint accepts Basic credentials only from client_secret_basic clients
                        Client client = clientService.getClient(username);
                        methodAllowed = client != null
                                && AuthenticationMethod.CLIENT_SECRET_BASIC == client.getAuthenticationMethod();
                    }

                    if (methodAllowed) {
                        identity.getCredentials().setUsername(username);
                        identity.getCredentials().setPassword(password);
                        requireAuth = !authenticator.authenticateWebService();
                    } else {
                        log.info("Basic authentication failed: The Token Authentication Method is not valid for client '{}'",
                                username);
                    }
                }
            }

            if (!requireAuth) {
                filterChain.doFilter(servletRequest, servletResponse);
                return;
            }
        } catch (Exception ex) {
            log.info("Basic authentication failed", ex);
        }

        // The chain has not run on this path, so nothing has been written yet. Always answer 401 instead of
        // leaving an empty response, even if some other identity is still logged in.
        if (requireAuth) {
            sendError(servletResponse);
        }
    }

    /**
     * Passes requests carrying an {@code Authorization: Bearer} header down the chain.
     * <p>
     * The access token is not validated in this filter. NOTE(review): presumably the target endpoint validates
     * it; confirm.
     */
    private void processBearerAuth(HttpServletRequest servletRequest, HttpServletResponse servletResponse,
            FilterChain filterChain) {
        String header = servletRequest.getHeader(AUTHORIZATION_HEADER);
        if (header == null || !header.startsWith(BEARER_PREFIX)) {
            return;
        }

        try {
            filterChain.doFilter(servletRequest, servletResponse);
        } catch (Exception ex) {
            log.info("Bearer authorization failed", ex);
        }
    }

    /**
     * Authenticates a token endpoint request.
     * <p>
     * It uses the {@code client_id} and {@code client_secret} request parameters, which are accepted only for
     * {@code client_secret_post} clients. When those are absent, it uses the client authentication filters,
     * if enabled in the configuration. Continues the chain on success. Otherwise sends a 401
     * {@code invalid_client} response.
     */
    private void processPostAuth(ClientService clientService, ClientFilterService clientFilterService,
            ErrorResponseFactory errorResponseFactory, HttpServletRequest servletRequest,
            HttpServletResponse servletResponse, FilterChain filterChain) {
        boolean requireAuth = true;

        try {
            String clientId = "";
            String clientSecret = "";
            boolean isExistUserPassword = false;
            if (StringHelper.isNotEmpty(servletRequest.getParameter("client_id"))
                    && StringHelper.isNotEmpty(servletRequest.getParameter("client_secret"))) {
                clientId = servletRequest.getParameter("client_id");
                clientSecret = servletRequest.getParameter("client_secret");
                isExistUserPassword = true;
            }
            log.trace("isExistUserPassword: {}", isExistUserPassword);

            // Only authenticate if clientId doesn't match Identity.username or the identity isn't logged in
            requireAuth = !StringHelper.equals(clientId, identity.getCredentials().getUsername())
                    || !identity.isLoggedIn();
            log.debug("requireAuth: '{}'", requireAuth);

            if (requireAuth) {
                if (isExistUserPassword) {
                    Client client = clientService.getClient(clientId);
                    if (client != null && AuthenticationMethod.CLIENT_SECRET_POST == client.getAuthenticationMethod()) {
                        identity.logout();
                        identity.getCredentials().setUsername(clientId);
                        identity.getCredentials().setPassword(clientSecret);
                        requireAuth = !authenticator.authenticateWebService();
                    }
                } else if (Boolean.TRUE.equals(appConfiguration.getClientAuthenticationFiltersEnabled())) {
                    String clientDn = clientFilterService
                            .processAuthenticationFilters(servletRequest.getParameterMap());
                    if (clientDn != null) {
                        Client client = clientService.getClientByDn(clientDn);
                        if (client == null) {
                            log.warn("No client found for DN '{}' returned by client authentication filters",
                                    clientDn);
                        } else {
                            identity.logout();
                            identity.getCredentials().setUsername(client.getClientId());
                            identity.getCredentials().setPassword(null);
                            requireAuth = !authenticator.authenticateWebService(true);
                        }
                    }
                }
            }

            if (!requireAuth) {
                filterChain.doFilter(servletRequest, servletResponse);
                return;
            }
        } catch (Exception ex) {
            log.error("Post authentication failed", ex);
        }

        // The chain has not run on this path, so nothing has been written yet. Answer 401 instead of
        // leaving an empty response.
        if (requireAuth) {
            sendError(servletResponse);
        }
    }

    /**
     * Authenticates a token endpoint request with a JWT bearer {@code client_assertion}.
     * <p>
     * The chain continues only when the assertion's subject is authenticated, or the identity is already logged
     * in as that subject. Otherwise (an invalid JWT, an unsupported assertion type, or failed authentication) a
     * 401 {@code invalid_client} response is sent and the chain is not run.
     */
    private void processJwtAuth(HttpServletRequest servletRequest, HttpServletResponse servletResponse,
            FilterChain filterChain) {
        boolean authorized = false;

        try {
            String encodedAssertion = servletRequest.getParameter("client_assertion");
            String assertionTypeParam = servletRequest.getParameter("client_assertion_type");

            if (encodedAssertion != null && assertionTypeParam != null) {
                ClientAssertionType clientAssertionType = ClientAssertionType.fromString(assertionTypeParam);

                if (clientAssertionType == ClientAssertionType.JWT_BEARER) {
                    String clientId = servletRequest.getParameter("client_id");
                    ClientAssertion clientAssertion = new ClientAssertion(appConfiguration, clientId,
                            clientAssertionType, encodedAssertion);

                    String username = clientAssertion.getSubjectIdentifier();
                    String password = clientAssertion.getClientSecret();

                    if (StringHelper.equals(username, identity.getCredentials().getUsername())
                            && identity.isLoggedIn()) {
                        // Already logged in as the asserted client: no need to authenticate again
                        authorized = true;
                    } else {
                        identity.getCredentials().setUsername(username);
                        identity.getCredentials().setPassword(password);
                        // Honour the authentication outcome instead of assuming success
                        authorized = authenticator.authenticateWebService(true);
                    }
                }
            }
        } catch (InvalidJwtException ex) {
            log.info("JWT authentication failed: invalid client assertion", ex);
        } catch (Exception ex) {
            log.info("JWT authentication failed", ex);
        }

        if (!authorized) {
            // Reject before the chain runs so the error is never appended to an endpoint's own response
            sendError(servletResponse);
            return;
        }

        try {
            filterChain.doFilter(servletRequest, servletResponse);
        } catch (Exception ex) {
            log.error("Failed to process JWT authenticated request", ex);
        }
    }

    /**
     * Writes a 401 response with a {@code WWW-Authenticate: Basic} challenge and an {@code invalid_client} JSON
     * error body. If the response is already committed, it only logs a warning.
     */
    private void sendError(HttpServletResponse servletResponse) {
        if (servletResponse.isCommitted()) {
            log.warn("Unable to send authentication error: response is already committed");
            return;
        }

        // Set status, headers and content type before getWriter(). A charset set after the writer has been
        // obtained is ignored, so the JSON body would otherwise not be encoded as UTF-8.
        servletResponse.setStatus(401);
        servletResponse.addHeader("WWW-Authenticate", "Basic realm=\"" + getRealm() + "\"");
        servletResponse.setContentType("application/json;charset=UTF-8");

        PrintWriter out = null;
        try {
            out = servletResponse.getWriter();
            out.write(errorResponseFactory.getErrorAsJson(TokenErrorResponseType.INVALID_CLIENT));
        } catch (IOException ex) {
            log.error(ex.getMessage(), ex);
        } catch (IllegalStateException ex) {
            // getWriter() refuses to run if the output stream was already obtained for this response
            log.error(ex.getMessage(), ex);
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Returns the realm used in {@code WWW-Authenticate} challenges.
     *
     * @return the configured realm, or {@link #REALM} when none has been set
     */
    public String getRealm() {
        return realm != null ? realm : REALM;
    }

    /**
     * Sets the realm used in {@code WWW-Authenticate} challenges.
     *
     * @param realm the realm; {@code null} restores the default {@link #REALM}
     */
    public void setRealm(String realm) {
        this.realm = realm;
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }
}