/*******************************************************************************
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
* You may not use this product except in compliance with the License.
*
* This product includes a number of subcomponents with
* separate copyright notices and license terms. Your use of these
* subcomponents is subject to the terms and conditions of the
* subcomponent's license, as noted in the LICENSE file.
*******************************************************************************/
package org.cloudfoundry.identity.uaa.authentication;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCode;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeStore;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.login.PasscodeInformation;
import org.cloudfoundry.identity.uaa.user.UaaUser;
import org.cloudfoundry.identity.uaa.user.UaaUserDatabase;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.hsqldb.lib.StringUtil;
import org.springframework.http.HttpMethod;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.InsufficientAuthenticationException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.oauth2.provider.OAuth2RequestFactory;

import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * Authentication filter that verifies one-time passwords (passcodes) against
 * what is cached in the one-time password store.
 */
public class PasscodeAuthenticationFilter extends BackwardsCompatibleTokenEndpointAuthenticationFilter {

    private final Log logger = LogFactory.getLog(getClass());

    /** Names of the request parameters that {@link #getCredentials(HttpServletRequest)} inspects. */
    private List<String> parameterNames = Collections.emptyList();

    /**
     * Creates the filter, wrapping {@code authenticationManager} in an
     * {@link ExpiringCodeAuthenticationManager} that only accepts passcodes sent via POST.
     *
     * @param uaaUserDatabase       used to look up the user referenced by a passcode
     * @param authenticationManager manager that handles every non-passcode authentication
     * @param oAuth2RequestFactory  passed through to the superclass
     * @param expiringCodeStore     store the passcodes are retrieved from
     */
    public PasscodeAuthenticationFilter(UaaUserDatabase uaaUserDatabase, AuthenticationManager authenticationManager, OAuth2RequestFactory oAuth2RequestFactory, ExpiringCodeStore expiringCodeStore) {
        super(
            new ExpiringCodeAuthenticationManager(
                uaaUserDatabase,
                authenticationManager,
                LogFactory.getLog(PasscodeAuthenticationFilter.class),
                expiringCodeStore,
                Collections.singleton(HttpMethod.POST.toString())),
            oAuth2RequestFactory);
    }

    /**
     * Wraps the incoming request in a {@link PasscodeHttpServletRequest} before delegating to
     * the superclass, so that the authentication manager can attach extra parameters to it.
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        PasscodeHttpServletRequest request = new PasscodeHttpServletRequest((HttpServletRequest) req);
        super.doFilter(request, res, chain);
    }

    /**
     * Unauthenticated token that carries a passcode together with the request it arrived on.
     * It exposes no principal, credentials or details and always reports itself as not
     * authenticated.
     */
    protected static class ExpiringCodeAuthentication implements Authentication {
        private final HttpServletRequest request;
        private final String passcode;

        public ExpiringCodeAuthentication(HttpServletRequest request, String passcode) {
            this.request = request;
            this.passcode = passcode;
        }

        /**
         * @return an empty collection; the {@link Authentication} contract requires a
         *         non-null value even for tokens that have not been authenticated
         */
        @Override
        public Collection<? extends GrantedAuthority> getAuthorities() {
            return Collections.<GrantedAuthority>emptyList();
        }

        @Override
        public Object getCredentials() {
            return null;
        }

        @Override
        public Object getDetails() {
            return null;
        }

        @Override
        public Object getPrincipal() {
            return null;
        }

        /** @return always {@code false} */
        @Override
        public boolean isAuthenticated() {
            return false;
        }

        /** Ignores the argument; this token never becomes authenticated. */
        @Override
        public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
        }

        public HttpServletRequest getRequest() {
            return request;
        }

        public String getPasscode() {
            return passcode;
        }

        /** @return the passcode, which doubles as the name of this token */
        @Override
        public String getName() {
            return getPasscode();
        }
    }

    /**
     * Request wrapper that allows additional parameters to be attached to a request.
     * Only {@link #getParameterMap()} reflects the added parameters; the other parameter
     * accessors are not overridden here.
     */
    protected static class PasscodeHttpServletRequest extends HttpServletRequestWrapper {
        Map<String, String[]> extendedParameters = new HashMap<>();

        public PasscodeHttpServletRequest(HttpServletRequest request) {
            super(request);
        }

        /** Registers (or replaces) an additional parameter under {@code name}. */
        public void addParameter(String name, String[] values) {
            extendedParameters.put(name, values);
        }

        /**
         * @return a new map holding the added parameters overlaid with the wrapped request's
         *         parameters; on a name clash the wrapped request's value wins
         */
        @Override
        public Map<String, String[]> getParameterMap() {
            Map<String, String[]> result = new HashMap<>(extendedParameters);
            result.putAll(super.getParameterMap());
            return result;
        }
    }

    /**
     * Authentication manager that resolves {@link ExpiringCodeAuthentication} tokens against
     * the expiring code store and hands every other token to the parent manager.
     */
    protected static class ExpiringCodeAuthenticationManager implements AuthenticationManager {
        private final Log logger;
        private final ExpiringCodeStore expiringCodeStore;
        private final Set<String> methods;
        private final AuthenticationManager parent;
        private final UaaUserDatabase uaaUserDatabase;

        /**
         * @param uaaUserDatabase   used to load the user named in the passcode information
         * @param parent            manager used for tokens that are not passcode tokens
         * @param logger            log to write to
         * @param expiringCodeStore store the passcodes are retrieved from
         * @param methods           upper-case HTTP methods a passcode may be sent with, or
         *                          {@code null} to skip the method check
         */
        public ExpiringCodeAuthenticationManager(UaaUserDatabase uaaUserDatabase, AuthenticationManager parent, Log logger, ExpiringCodeStore expiringCodeStore, Set<String> methods) {
            this.logger = logger;
            this.expiringCodeStore = expiringCodeStore;
            this.methods = methods;
            this.parent = parent;
            this.uaaUserDatabase = uaaUserDatabase;
        }

        /** Retrieves the code from the store; kept as a separate overridable step. */
        protected ExpiringCode doRetrieveCode(String code) {
            return expiringCodeStore.retrieveCode(code);
        }

        /**
         * Authenticates a passcode token, or delegates to the parent manager for any other
         * kind of token.
         *
         * @return a {@link UsernamePasswordAuthenticationToken} for the user the passcode
         *         refers to, carrying the authorities stored with the passcode when there
         *         are any and the user's own authorities otherwise
         * @throws BadCredentialsException if the HTTP method is not allowed or the user
         *         cannot be found
         * @throws InsufficientAuthenticationException if the passcode is blank, unknown or
         *         cannot be deserialized
         */
        @Override
        public Authentication authenticate(Authentication authentication) throws AuthenticationException {
            if (!(authentication instanceof ExpiringCodeAuthentication)) {
                return parent.authenticate(authentication);
            }
            ExpiringCodeAuthentication expiringCodeAuthentication = (ExpiringCodeAuthentication) authentication;
            HttpServletRequest request = expiringCodeAuthentication.getRequest();

            // Validate passcode
            logger.debug("Located credentials in request, with passcode");
            // Locale.ROOT keeps the comparison independent of the JVM's default locale.
            if (methods != null && !methods.contains(request.getMethod().toUpperCase(Locale.ROOT))) {
                throw new BadCredentialsException("Credentials must be sent by (one of methods): " + methods);
            }

            PasscodeInformation pi = retrievePasscodeInformation(expiringCodeAuthentication.getPasscode());
            logger.debug("Successful passcode authentication request for " + pi.getUsername());

            Collection<GrantedAuthority> externalAuthorities = externalAuthorities(pi);

            UaaPrincipal principal = new UaaPrincipal(pi.getUserId(), pi.getUsername(), null, pi.getOrigin(), null,
                IdentityZoneHolder.get().getId());

            List<? extends GrantedAuthority> authorities;
            try {
                UaaUser user = uaaUserDatabase.retrieveUserById(pi.getUserId());
                authorities = user.getAuthorities();
            } catch (UsernameNotFoundException x) {
                // Keep the original exception as the cause so the failure stays diagnosable.
                throw new BadCredentialsException("Invalid user.", x);
            }

            Authentication result = new UsernamePasswordAuthenticationToken(
                principal,
                null,
                externalAuthorities == null || externalAuthorities.isEmpty() ? authorities : externalAuthorities
            );

            // Add additional parameters for backwards compatibility (username and origin only;
            // a user_id parameter is not added). This is only possible when the request is the
            // wrapper installed by the filter, so any other request type is left untouched.
            if (request instanceof PasscodeHttpServletRequest) {
                PasscodeHttpServletRequest pcRequest = (PasscodeHttpServletRequest) request;
                pcRequest.addParameter("username", new String[] {pi.getUsername()});
                pcRequest.addParameter(OriginKeys.ORIGIN, new String[] {pi.getOrigin()});
            } else {
                logger.debug("Request is not a PasscodeHttpServletRequest, skipping backwards compatible parameters");
            }
            return result;
        }

        /** Looks up the passcode and deserializes the information stored with it. */
        private PasscodeInformation retrievePasscodeInformation(String passcode) {
            if (isBlank(passcode)) {
                throw new InsufficientAuthenticationException("Passcode information is missing.");
            }
            ExpiringCode eCode = doRetrieveCode(passcode);
            if (eCode == null || eCode.getData() == null) {
                throw new InsufficientAuthenticationException("Invalid passcode");
            }
            PasscodeInformation pi;
            try {
                pi = JsonUtils.readValue(eCode.getData(), PasscodeInformation.class);
            } catch (JsonUtils.JsonUtilException e) {
                throw new InsufficientAuthenticationException("Unable to deserialize passcode object.", e);
            }
            if (pi == null) {
                throw new InsufficientAuthenticationException("Invalid passcode");
            }
            return pi;
        }

        /** Returns the "authorities" entry of the authorization parameters, or null if there are none. */
        @SuppressWarnings("unchecked") // the stored value is an untyped object; the cast cannot be checked
        private static Collection<GrantedAuthority> externalAuthorities(PasscodeInformation pi) {
            if (pi.getAuthorizationParameters() == null) {
                return null;
            }
            return (Collection<GrantedAuthority>) pi.getAuthorizationParameters().get("authorities");
        }

        /** Returns true for null and for strings made up solely of characters {@code <= ' '}. */
        private static boolean isBlank(String s) {
            return s == null || s.trim().isEmpty();
        }
    }

    /**
     * Extracts the credentials for a {@code grant_type=password} request.
     *
     * @return an {@link ExpiringCodeAuthentication} when a {@code passcode} credential is
     *         present, the superclass's result when it is not, and {@code null} for any
     *         other grant type
     */
    @Override
    protected Authentication extractCredentials(HttpServletRequest request) {
        if (!"password".equals(request.getParameter("grant_type"))) {
            return null;
        }
        String passcode = getCredentials(request).get("passcode");
        if (passcode != null) {
            return new ExpiringCodeAuthentication(request, passcode);
        }
        return super.extractCredentials(request);
    }

    /**
     * Collects the configured parameters from the request. A value that starts with
     * <code>{</code> is parsed as a JSON object whose entries are merged into the result;
     * any other value is stored under its parameter name.
     */
    private Map<String, String> getCredentials(HttpServletRequest request) {
        Map<String, String> credentials = new HashMap<>();
        for (String paramName : parameterNames) {
            String value = request.getParameter(paramName);
            if (value == null) {
                continue;
            }
            if (value.startsWith("{")) {
                try {
                    Map<String, String> jsonCredentials = JsonUtils.readValue(value,
                        new TypeReference<Map<String, String>>() {
                        });
                    credentials.putAll(jsonCredentials);
                } catch (JsonUtils.JsonUtilException e) {
                    // Best effort: an unparseable value is skipped. The exception is not attached
                    // to the log entry. NOTE(review): parser messages may echo the raw value, which
                    // can hold credentials - confirm before adding it.
                    logger.warn("Unknown format of value for request param: " + paramName + ". Ignoring.");
                }
            } else {
                credentials.put(paramName, value);
            }
        }
        return credentials;
    }

    /** No initialization is performed. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /** No cleanup is performed. */
    @Override
    public void destroy() {
    }

    /**
     * Sets the names of the request parameters to inspect for credentials.
     *
     * @param parameterNames the parameter names; {@code null} is treated as an empty list
     */
    public void setParameterNames(List<String> parameterNames) {
        this.parameterNames = parameterNames == null ? Collections.<String>emptyList() : parameterNames;
    }
}