package kickr.security; import java.util.stream.Stream; import javax.servlet.http.Cookie; import kickr.security.service.AuthenticationService; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.Response; import kickr.db.entity.user.User; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import support.security.AuthenticationException; import support.security.SecurityContextFactory; import support.security.TypedSecurityContext; import support.transactional.WithTransaction; /** * * @author nikku */ public class UserSecurityContextFactory extends SecurityContextFactory<User> { private static final Logger LOGGER = LoggerFactory.getLogger(UserSecurityContextFactory.class); private static final String SCHEME = "Token"; private static final String CHALLENGE_FORMAT = SCHEME + " realm=\"%s\""; private final String realm; private final WithTransaction transactional; private final AuthenticationService authenticationService; public UserSecurityContextFactory( String realm, WithTransaction transactional, AuthenticationService authenticationService) { this.realm = realm; this.transactional = transactional; this.authenticationService = authenticationService; } @Override public TypedSecurityContext<User> createSecurityContext(HttpServletRequest request) { if (request == null) { return null; } String address = request.getRemoteAddr(); boolean secure = request.isSecure(); String headerToken = extractHeaderToken(request); if (headerToken != null) { // if header token is provided, // we must be able to authenticate TokenCredentials headerCredentials = new TokenCredentials(headerToken, address); try { // create context based on header token auth return createContext(secure, SCHEME, authenticate(headerCredentials)); } catch (AuthenticationException ex) { LOGGER.warn("failed to authenticate <" + headerCredentials + ">"); // indicate authentication failure throw new NotAuthorizedException( "failed to authenticate", Response .status(Response.Status.UNAUTHORIZED) .header(HttpHeaders.WWW_AUTHENTICATE, String.format(CHALLENGE_FORMAT, realm)) .entity("invalid credentials") .build()); } } String cookieToken = extractCookieToken(request); if (cookieToken != null) { // if cookie token is provided, // authentication is optional TokenCredentials cookieCredentials = new TokenCredentials(cookieToken, address); try { // create context based on header token auth return createContext(secure, SCHEME, authenticate(cookieCredentials)); } catch (AuthenticationException ex) { LOGGER.warn("failed to authenticate <" + cookieCredentials + ">"); } } return createContext(secure, SCHEME, null); } protected String extractHeaderToken(HttpServletRequest request) { final String header = request.getHeader(HttpHeaders.AUTHORIZATION); if (header != null) { final int space = header.indexOf(" "); if (space > 0) { final String method = header.substring(0, space); if (SCHEME.equalsIgnoreCase(method)) { return header.substring(space + 1); } } } return null; } protected String extractCookieToken(HttpServletRequest request) { Cookie[] cookies = request.getCookies(); if (cookies != null) { return Stream.of(cookies) .filter(c -> "__sid".equals(c.getName())) .map(c -> c.getValue()) .findAny() .orElse(null); } return null; } protected User authenticate(TokenCredentials credentials) { return transactional.get(() -> { return authenticationService.authenticate(credentials.token); }); } private TypedSecurityContext<User> createContext(boolean secure, String scheme, User principal) { return new UserSecurityContext(secure, scheme, principal); } protected static class TokenCredentials { public final String token; public final String address; public TokenCredentials(String token, String address) { this.token = token; this.address = address; } } }