package com.feth.play.module.pa.providers.oauth2; import com.feth.play.module.pa.PlayAuthenticate; import com.feth.play.module.pa.exceptions.*; import com.feth.play.module.pa.providers.ext.ExternalAuthProvider; import com.feth.play.module.pa.user.AuthUser; import com.feth.play.module.pa.user.AuthUserIdentity; import org.apache.http.NameValuePair; import org.apache.http.client.utils.URLEncodedUtils; import org.apache.http.message.BasicNameValuePair; import play.Configuration; import play.Logger; import play.i18n.Messages; import play.inject.ApplicationLifecycle; import play.libs.ws.WSClient; import play.libs.ws.WSRequest; import play.libs.ws.WSResponse; import play.mvc.Http.Context; import play.mvc.Http.Request; import play.mvc.Http.Session; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; import static java.util.concurrent.TimeUnit.MILLISECONDS; public abstract class OAuth2AuthProvider<U extends AuthUserIdentity, I extends OAuth2AuthInfo> extends ExternalAuthProvider { protected class QueryParam { private String param; private String value; public QueryParam(String param, String value) { this.param = param; this.value = value; } } private static final String STATE_TOKEN = "pa.oauth2.state"; protected static final String CONTENT_TYPE = "Content-Type"; protected final WSClient wsClient; public OAuth2AuthProvider(final PlayAuthenticate auth, final ApplicationLifecycle lifecycle, final WSClient wsClient) { super(auth, lifecycle); this.wsClient = wsClient; } @Override protected List<String> neededSettingKeys() { final List<String> ret = new ArrayList<String>(); ret.addAll(super.neededSettingKeys()); ret.add(SettingKeys.ACCESS_TOKEN_URL); ret.add(SettingKeys.AUTHORIZATION_URL); ret.add(SettingKeys.CLIENT_ID); ret.add(SettingKeys.CLIENT_SECRET); return ret; } public static abstract class SettingKeys { public static final String AUTHORIZATION_URL = "authorizationUrl"; public static final String ACCESS_TOKEN_URL = "accessTokenUrl"; public static final String CLIENT_ID = "clientId"; public static final String CLIENT_SECRET = "clientSecret"; public static final String SCOPE = "scope"; public static final String ACCESS_TYPE = "accessType"; public static final String APPROVAL_PROMPT = "approvalPrompt"; } public static abstract class Constants { public static final String CLIENT_ID = "client_id"; public static final String CLIENT_SECRET = "client_secret"; public static final String REDIRECT_URI = "redirect_uri"; public static final String SCOPE = "scope"; public static final String ACCESS_TYPE = "access_type"; public static final String APPROVAL_PROMPT = "approval_prompt"; public static final String RESPONSE_TYPE = "response_type"; public static final String STATE = "state"; public static final String GRANT_TYPE = "grant_type"; public static final String AUTHORIZATION_CODE = "authorization_code"; public static final String ACCESS_TOKEN = "access_token"; public static final String ERROR = "error"; public static final String CODE = "code"; public static final String TOKEN_TYPE = "token_type"; public static final String EXPIRES_IN = "expires_in"; public static final String REFRESH_TOKEN = "refresh_token"; public static final String ACCESS_DENIED = "access_denied"; public static final String REDIRECT_URI_MISMATCH = "redirect_uri_mismatch"; } protected WSResponse fetchAuthResponse(String url, QueryParam...params) throws AuthException { final List<QueryParam> queryParams = Arrays.asList(params); final WSRequest request = wsClient.url(url); for(QueryParam param : queryParams) { request.setQueryParameter(param.param, param.value); } try { return request.get() .toCompletableFuture() .get(getTimeout(), MILLISECONDS); } catch(InterruptedException | ExecutionException | TimeoutException e) { throw new AuthException(e.getMessage(), e); } } protected String getAccessTokenParams(final Configuration c, final String code, Request request) throws ResolverMissingException { final List<NameValuePair> params = getParams(request, c); params.add(new BasicNameValuePair(Constants.CLIENT_SECRET, c .getString(SettingKeys.CLIENT_SECRET))); params.add(new BasicNameValuePair(Constants.GRANT_TYPE, Constants.AUTHORIZATION_CODE)); params.add(new BasicNameValuePair(Constants.CODE, code)); return URLEncodedUtils.format(params, "UTF-8"); } protected Map<String, String> getHeaders() { return Collections.<String, String>emptyMap(); } protected I getAccessToken(final String code, final Request request) throws AccessTokenException, ResolverMissingException { final Configuration c = getConfiguration(); final String params = getAccessTokenParams(c, code, request); final String url = c.getString(SettingKeys.ACCESS_TOKEN_URL); final WSRequest wrh = wsClient.url(url); wrh.setHeader(CONTENT_TYPE, "application/x-www-form-urlencoded"); for(final Map.Entry<String, String> header : getHeaders().entrySet()) { wrh.setHeader(header.getKey(), header.getValue()); } try { final WSResponse r = wrh.post(params).toCompletableFuture().get(PlayAuthenticate.TIMEOUT, MILLISECONDS); return buildInfo(r); } catch(InterruptedException | ExecutionException | TimeoutException e) { throw new AccessTokenException(e.getMessage(), e); } } protected abstract I buildInfo(final WSResponse r) throws AccessTokenException; protected String getAuthUrl(final Request request, final String state) throws AuthException { final Configuration c = getConfiguration(); final List<NameValuePair> params = getAuthParams(c, request, state); return generateURI(c.getString(SettingKeys.AUTHORIZATION_URL), params); } protected List<NameValuePair> getAuthParams(final Configuration c, final Request request, final String state) throws AuthException { final List<NameValuePair> params = getParams(request, c); if (c.getString(SettingKeys.SCOPE) != null) { params.add(new BasicNameValuePair(Constants.SCOPE, c .getString(SettingKeys.SCOPE))); } params.add(new BasicNameValuePair(Constants.RESPONSE_TYPE, Constants.CODE)); if (c.getString(SettingKeys.ACCESS_TYPE) != null) { params.add(new BasicNameValuePair(Constants.ACCESS_TYPE, c .getString(SettingKeys.ACCESS_TYPE))); } if (c.getString(SettingKeys.APPROVAL_PROMPT) != null) { params.add(new BasicNameValuePair(Constants.APPROVAL_PROMPT, c .getString(SettingKeys.APPROVAL_PROMPT))); } if (state != null) { params.add(new BasicNameValuePair(Constants.STATE, state)); } return params; } protected List<NameValuePair> getParams(final Request request, final Configuration c) throws ResolverMissingException { final List<NameValuePair> params = new ArrayList<NameValuePair>(); params.add(new BasicNameValuePair(Constants.CLIENT_ID, c .getString(SettingKeys.CLIENT_ID))); params.add(new BasicNameValuePair(getRedirectUriKey(), getRedirectUrl(request))); return params; } protected String getRedirectUriKey() { return Constants.REDIRECT_URI; } @Override public Object authenticate(final Context context, final Object payload) throws AuthException { final Request request = context.request(); if (Logger.isDebugEnabled()) { Logger.debug("Returned with URL: '" + request.uri() + "'"); } final String error = request.getQueryString(getErrorParameterKey()); if (error != null) { if (error.equals(Constants.ACCESS_DENIED)) { throw new AccessDeniedException(getKey()); } else if (error.equals(Constants.REDIRECT_URI_MISMATCH)) { Logger.error("You must set the redirect URI for your provider to whatever you defined in your routes file." + "For this provider it is: '" + getRedirectUrl(request) + "'"); throw new RedirectUriMismatch(); } else { throw new AuthException(error); } } else if (isCallbackRequest(context)) { // second step in auth process final UUID storedState = this.auth.getFromCache(context.session(), STATE_TOKEN); if(storedState == null) { Logger.warn("Cache either timed out, or you are using a setup with multiple servers and a non-shared cache implementation"); // we will just behave as if there was no auth, yet... return generateRedirectUrl(context, request); } final String callbackState = request.getQueryString(Constants.STATE); if(!storedState.equals(UUID.fromString(callbackState))) { // the return callback may have been forged throw new AuthException(Messages.get("playauthenticate.core.exception.oauth2.state_param_forged")); } final String code = request.getQueryString(Constants.CODE); final I info = getAccessToken(code, request); return transform(info, callbackState); } else { // no auth, yet return generateRedirectUrl(context, request); } } private String generateRedirectUrl(Context context, Request request) throws AuthException { final UUID state = UUID.randomUUID(); this.auth.storeInCache(context.session(), STATE_TOKEN, state); final String url = getAuthUrl(request, state.toString()); Logger.debug("generated redirect URL for dialog: " + url); return url; } protected boolean isCallbackRequest(final Context context) { return context.request().queryString().containsKey(Constants.CODE); } protected String getErrorParameterKey() { return Constants.ERROR; } @Override public void afterSave(final AuthUser user, final Object identity, final Session session) { this.auth.removeFromCache(session, STATE_TOKEN); } /** * This allows custom implementations to enrich an AuthUser object or * provide their own implementation * * @param info * @param state * @return * @throws AuthException */ protected abstract AuthUserIdentity transform(final I info, final String state) throws AuthException; }