/**
 * Mule Development Kit
 * Copyright 2010-2011 (c) MuleSoft, Inc. All rights reserved. http://www.mulesoft.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.mule.devkit.generation.adapter;

import org.apache.commons.lang.StringUtils;
import org.mule.api.annotations.oauth.OAuth2;
import org.mule.api.annotations.oauth.OAuthConsumerKey;
import org.mule.api.annotations.oauth.OAuthConsumerSecret;
import org.mule.api.annotations.oauth.OAuthScope;
import org.mule.api.oauth.OAuth2Adapter;
import org.mule.api.oauth.UnableToAcquireAccessTokenException;
import org.mule.devkit.generation.AbstractOAuthAdapterGenerator;
import org.mule.devkit.generation.DevKitTypeElement;
import org.mule.devkit.generation.GenerationException;
import org.mule.devkit.model.code.Block;
import org.mule.devkit.model.code.CatchBlock;
import org.mule.devkit.model.code.Conditional;
import org.mule.devkit.model.code.DefinedClass;
import org.mule.devkit.model.code.ExpressionFactory;
import org.mule.devkit.model.code.FieldVariable;
import org.mule.devkit.model.code.Invocation;
import org.mule.devkit.model.code.Method;
import org.mule.devkit.model.code.Modifier;
import org.mule.devkit.model.code.Op;
import org.mule.devkit.model.code.TryStatement;
import org.mule.devkit.model.code.Variable;
import org.mule.devkit.model.code.builders.FieldBuilder;
import org.mule.util.IOUtils;

import java.io.OutputStreamWriter;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.Date;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Generates the "OAuth2Adapter" class for every type annotated with {@link OAuth2}.
 * <p>
 * The generated adapter gets pattern constants, state fields, lifecycle methods and the
 * methods that build the authorization URL, restore / fetch the access token, report
 * token expiration and reset the stored state. Everything in this class emits code through
 * the DevKit code model; nothing here performs HTTP calls itself.
 */
public class OAuth2AdapterGenerator extends AbstractOAuthAdapterGenerator {

    /**
     * @param typeElement the type being processed
     * @return {@code true} only when the type carries the {@link OAuth2} annotation
     */
    @Override
    protected boolean shouldGenerate(DevKitTypeElement typeElement) {
        return typeElement.hasAnnotation(OAuth2.class);
    }

    /**
     * Builds the complete adapter class: constants and fields first, then lifecycle methods,
     * then the OAuth2-specific methods, and finally the overrides.
     *
     * @param typeElement the {@link OAuth2}-annotated type
     * @throws GenerationException declared by the overridden method
     */
    @Override
    protected void doGenerate(DevKitTypeElement typeElement) throws GenerationException {
        DefinedClass oauthAdapter = getOAuthAdapterClass(typeElement, "OAuth2Adapter", OAuth2Adapter.class);
        OAuth2 oauth2 = typeElement.getAnnotation(OAuth2.class);

        authorizationCodePatternConstant(oauthAdapter, oauth2.verifierRegex());
        accessTokenPatternConstant(oauthAdapter, oauth2);
        expirationPatternConstant(oauthAdapter, oauth2);

        muleContextField(oauthAdapter);
        authorizationCodeField(oauthAdapter);
        redirectUrlField(oauthAdapter);
        oauthCallbackField(oauthAdapter);
        FieldVariable oauthAccessToken = accessTokenField(oauthAdapter);
        FieldVariable saveAccessTokenCallback = saveAccessTokenCallbackField(oauthAdapter);
        FieldVariable restoreAccessTokenCallback = restoreAccessTokenCallbackField(oauthAdapter);
        // Reuse the annotation already looked up above instead of querying the type again.
        expirationField(oauthAdapter, oauth2);

        DefinedClass messageProcessor = generateMessageProcessorInnerClass(oauthAdapter);

        generateStartMethod(oauthAdapter);
        generateStopMethod(oauthAdapter);
        generateInitialiseMethod(oauthAdapter, messageProcessor, oauth2.callbackPath());

        FieldVariable logger = FieldBuilder.newLoggerField(oauthAdapter);

        generateGetAuthorizationUrlMethod(oauthAdapter, typeElement, oauth2, logger);
        generateRestoreAccessTokenMethod(oauthAdapter, restoreAccessTokenCallback, logger);
        generateFetchAccessTokenMethod(oauthAdapter, typeElement, oauth2, saveAccessTokenCallback, logger);
        generateHasTokenExpiredMethod(oauthAdapter, oauth2);
        generateResetMethod(oauthAdapter, oauth2);
        generateHasBeenAuthorizedMethod(oauthAdapter, oauthAccessToken);
        generateOverrides(typeElement, oauthAdapter, oauthAccessToken, null);
    }

    /**
     * Tells whether the annotation declares a non-empty expiration regex; all expiration
     * related members are generated only in that case.
     */
    private static boolean hasExpirationRegex(OAuth2 oauth2) {
        return !StringUtils.isEmpty(oauth2.expirationRegex());
    }

    /**
     * Adds the static final {@link Pattern} constant compiled from {@code accessTokenRegex()}.
     */
    private void accessTokenPatternConstant(DefinedClass oauthAdapter, OAuth2 oauth2) {
        new FieldBuilder(oauthAdapter)
                .type(Pattern.class)
                .name(ACCESS_CODE_PATTERN_FIELD_NAME)
                .staticField()
                .finalField()
                .initialValue(ref(Pattern.class).staticInvoke("compile").arg(oauth2.accessTokenRegex()))
                .build();
    }

    /**
     * Adds the static final {@link Pattern} constant compiled from {@code expirationRegex()},
     * when such a regex is configured.
     */
    private void expirationPatternConstant(DefinedClass oauthAdapter, OAuth2 oauth2) {
        if (hasExpirationRegex(oauth2)) {
            new FieldBuilder(oauthAdapter)
                    .type(Pattern.class)
                    .name(EXPIRATION_TIME_PATTERN_FIELD_NAME)
                    .staticField()
                    .finalField()
                    .initialValue(ref(Pattern.class).staticInvoke("compile").arg(oauth2.expirationRegex()))
                    .build();
        }
    }

    /**
     * Adds the {@link Date} expiration field (with setter) when an expiration regex is configured.
     */
    private void expirationField(DefinedClass oauthAdapter, OAuth2 oauth2) {
        if (hasExpirationRegex(oauth2)) {
            new FieldBuilder(oauthAdapter).type(Date.class).name(EXPIRATION_FIELD_NAME).setter().build();
        }
    }

    /**
     * Opens an {@code if (logger.isDebugEnabled())} guard inside {@code parent} and returns
     * the guard's "then" block.
     */
    private Block openDebugGuard(Block parent, FieldVariable logger) {
        return parent._if(logger.invoke("isDebugEnabled"))._then();
    }

    /**
     * Declares the local {@code StringBuilder messageStringBuilder} inside a debug guard.
     */
    private Variable declareDebugMessageBuilder(Block debugBlock) {
        return debugBlock.decl(ref(StringBuilder.class), "messageStringBuilder", ExpressionFactory._new(ref(StringBuilder.class)));
    }

    /**
     * Closes a debug guard by emitting {@code logger.debug(messageStringBuilder.toString())}.
     */
    private void emitDebugMessage(Block debugBlock, FieldVariable logger, Variable messageBuilder) {
        debugBlock.add(logger.invoke("debug").arg(messageBuilder.invoke("toString")));
    }

    /**
     * Generates the method that assembles the authorization URL from the annotation's
     * authorization URL, the consumer key getter, the redirect URL field and, when the type
     * has a field annotated with {@link OAuthScope}, a non-null scope.
     *
     * @param oauthAdapter the adapter class receiving the method
     * @param typeElement  the annotated type, used to locate the key / scope getters
     * @param oauth2       the annotation providing the authorization URL
     * @param logger       the adapter's logger field
     */
    private void generateGetAuthorizationUrlMethod(DefinedClass oauthAdapter, DevKitTypeElement typeElement, OAuth2 oauth2, FieldVariable logger) {
        Method getAuthorizationUrl = oauthAdapter.method(Modifier.PUBLIC, context.getCodeModel().VOID, GET_AUTHORIZATION_URL_METHOD_NAME);
        // The method is created as void and then retyped to String.
        getAuthorizationUrl.type(ref(String.class));
        Block body = getAuthorizationUrl.body();

        Variable urlBuilder = body.decl(ref(StringBuilder.class), "urlBuilder", ExpressionFactory._new(ref(StringBuilder.class)));
        body.invoke(urlBuilder, "append").arg(oauth2.authorizationUrl());
        body.invoke(urlBuilder, "append").arg("?");
        body.invoke(urlBuilder, "append").arg("response_type=code&");
        body.invoke(urlBuilder, "append").arg("client_id=");
        body.invoke(urlBuilder, "append").arg(ExpressionFactory.invoke(getterMethodForFieldAnnotatedWith(typeElement, OAuthConsumerKey.class)));
        body.invoke(urlBuilder, "append").arg("&redirect_uri=");
        body.invoke(urlBuilder, "append").arg(oauthAdapter.fields().get(REDIRECT_URL_FIELD_NAME));

        if (typeElement.hasFieldAnnotatedWith(OAuthScope.class)) {
            Variable scope = body.decl(ref(String.class), "scope", ExpressionFactory.invoke(getterMethodForFieldAnnotatedWith(typeElement, OAuthScope.class)));
            Block ifScopeNotNull = body._if(Op.ne(scope, ExpressionFactory._null()))._then();
            ifScopeNotNull.invoke(urlBuilder, "append").arg("&scope=");
            ifScopeNotNull.invoke(urlBuilder, "append").arg(scope);
        }

        body.invoke(logger, "debug").arg(ExpressionFactory.direct("\"OAUth 2 authorization url: \" + urlBuilder"));
        body._return(urlBuilder.invoke("toString"));
    }

    /**
     * Generates {@code boolean restoreAccessToken()}: when a restore callback is set, it is
     * invoked inside a try block, its access token is copied into the adapter field and
     * {@code true} is returned; a caught exception is logged as an error. In every other
     * case the generated method returns {@code false}.
     *
     * @param oauthAdapter                    the adapter class receiving the method
     * @param restoreAccessTokenCallbackField the field holding the restore callback
     * @param logger                          the adapter's logger field
     */
    private void generateRestoreAccessTokenMethod(DefinedClass oauthAdapter, FieldVariable restoreAccessTokenCallbackField, FieldVariable logger) {
        Method restoreAccessTokenMethod = oauthAdapter.method(Modifier.PUBLIC, context.getCodeModel().BOOLEAN, "restoreAccessToken");
        Conditional ifRestoreCallbackNotNull = restoreAccessTokenMethod.body()._if(Op.ne(restoreAccessTokenCallbackField, ExpressionFactory._null()));
        Block restoreBlock = ifRestoreCallbackNotNull._then();

        Block debug = openDebugGuard(restoreBlock, logger);
        Variable message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Attempting to restore access token..."));
        emitDebugMessage(debug, logger, message);

        TryStatement tryToRestore = restoreBlock._try();
        tryToRestore.body().add(restoreAccessTokenCallbackField.invoke("restoreAccessToken"));
        tryToRestore.body().assign(oauthAdapter.fields().get(OAUTH_ACCESS_TOKEN_FIELD_NAME), restoreAccessTokenCallbackField.invoke("getAccessToken"));

        debug = openDebugGuard(tryToRestore.body(), logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Access token and secret has been restored successfully "));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[accessToken = ")));
        debug.add(message.invoke("append").arg(restoreAccessTokenCallbackField.invoke("getAccessToken")));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);

        tryToRestore.body()._return(ExpressionFactory.TRUE);

        CatchBlock logIfCannotRestore = tryToRestore._catch(ref(Exception.class));
        Variable e = logIfCannotRestore.param("e");
        logIfCannotRestore.body().add(logger.invoke("error").arg("Cannot restore access token, an unexpected error occurred").arg(e));

        restoreAccessTokenMethod.body()._return(ExpressionFactory.FALSE);
    }

    /**
     * Generates {@code void fetchAccessToken()}. The generated method first calls
     * {@code restoreAccessToken()}; if the access token field is still {@code null} it POSTs
     * the authorization code, client credentials, grant type and redirect URI to the access
     * token URL, extracts the token from the response with the access token pattern, hands it
     * to the save callback (if any) and, when an expiration regex is configured, extracts the
     * expiration time as well.
     *
     * @param oauthAdapter            the adapter class receiving the method
     * @param typeElement             the annotated type, used to locate key / secret getters
     * @param oauth2                  the annotation providing the URL and regexes
     * @param saveAccessTokenCallback the field holding the save callback
     * @param logger                  the adapter's logger field
     */
    private void generateFetchAccessTokenMethod(DefinedClass oauthAdapter, DevKitTypeElement typeElement, OAuth2 oauth2, FieldVariable saveAccessTokenCallback, FieldVariable logger) {
        Method fetchAccessToken = oauthAdapter.method(Modifier.PUBLIC, context.getCodeModel().VOID, "fetchAccessToken");
        fetchAccessToken._throws(ref(UnableToAcquireAccessTokenException.class));
        fetchAccessToken.body().invoke("restoreAccessToken");

        Conditional ifAccessTokenNull = fetchAccessToken.body()._if(Op.eq(oauthAdapter.fields().get(OAUTH_ACCESS_TOKEN_FIELD_NAME), ExpressionFactory._null()));
        TryStatement tryStatement = ifAccessTokenNull._then()._try();
        Block body = tryStatement.body();

        Block debug = openDebugGuard(body, logger);
        Variable message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Retrieving access token..."));
        emitDebugMessage(debug, logger, message);

        // Open the connection to the access token endpoint.
        Variable conn = body.decl(ref(HttpURLConnection.class), "conn",
                ExpressionFactory.cast(ref(HttpURLConnection.class), ExpressionFactory._new(ref(URL.class)).arg(oauth2.accessTokenUrl()).invoke("openConnection")));
        body.invoke(conn, "setRequestMethod").arg("POST");
        body.invoke(conn, "setDoOutput").arg(ExpressionFactory.lit(true));

        // Assemble the URL-encoded request content.
        Invocation consumerKey = ExpressionFactory.invoke(getterMethodForFieldAnnotatedWith(typeElement, OAuthConsumerKey.class));
        Invocation consumerSecret = ExpressionFactory.invoke(getterMethodForFieldAnnotatedWith(typeElement, OAuthConsumerSecret.class));
        Variable builder = body.decl(ref(StringBuilder.class), "builder", ExpressionFactory._new(ref(StringBuilder.class)));
        body.invoke(builder, "append").arg("code=");
        body.invoke(builder, "append").arg(ref(URLEncoder.class).staticInvoke("encode").arg(oauthAdapter.fields().get(VERIFIER_FIELD_NAME)).arg(ENCODING));
        body.invoke(builder, "append").arg("&client_id=");
        body.invoke(builder, "append").arg(ref(URLEncoder.class).staticInvoke("encode").arg(consumerKey).arg(ENCODING));
        body.invoke(builder, "append").arg("&client_secret=");
        body.invoke(builder, "append").arg(ref(URLEncoder.class).staticInvoke("encode").arg(consumerSecret).arg(ENCODING));
        body.invoke(builder, "append").arg("&grant_type=");
        body.invoke(builder, "append").arg(ref(URLEncoder.class).staticInvoke("encode").arg(GRANT_TYPE).arg(ENCODING));
        body.invoke(builder, "append").arg("&redirect_uri=");
        body.invoke(builder, "append").arg(ref(URLEncoder.class).staticInvoke("encode").arg(oauthAdapter.fields().get(REDIRECT_URL_FIELD_NAME)).arg(ENCODING));

        // NOTE(review): the logged content contains the client_secret and authorization code
        // appended above; consider whether that is acceptable at debug level.
        debug = openDebugGuard(body, logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Sending request to ["));
        debug.add(message.invoke("append").arg(oauth2.accessTokenUrl()));
        debug.add(message.invoke("append").arg("] using the following as content ["));
        debug.add(message.invoke("append").arg(builder.invoke("toString")));
        debug.add(message.invoke("append").arg("]"));
        emitDebugMessage(debug, logger, message);

        // Send the request and read the whole response.
        Variable out = body.decl(ref(OutputStreamWriter.class), "out", ExpressionFactory._new(ref(OutputStreamWriter.class)).arg(conn.invoke("getOutputStream")));
        body.invoke(out, "write").arg(builder.invoke("toString"));
        body.invoke(out, "close");
        Variable response = body.decl(ref(String.class), "response", ref(IOUtils.class).staticInvoke("toString").arg(conn.invoke("getInputStream")));

        debug = openDebugGuard(body, logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Received response ["));
        debug.add(message.invoke("append").arg(response));
        debug.add(message.invoke("append").arg("]"));
        emitDebugMessage(debug, logger, message);

        // Extract the access token (first capture group) or fail.
        Variable matcher = body.decl(ref(Matcher.class), "matcher", oauthAdapter.fields().get(ACCESS_CODE_PATTERN_FIELD_NAME).invoke("matcher").arg(response));
        Conditional ifAccessTokenFound = body._if(Op.cand(matcher.invoke("find"), Op.gte(matcher.invoke("groupCount"), ExpressionFactory.lit(1))));
        Block tokenFound = ifAccessTokenFound._then();
        Invocation group = matcher.invoke("group").arg(ExpressionFactory.lit(1));
        tokenFound.assign(oauthAdapter.fields().get(ACCESS_TOKEN_FIELD_NAME), ref(URLDecoder.class).staticInvoke("decode").arg(group).arg(ENCODING));

        debug = openDebugGuard(tokenFound, logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Access token retrieved successfully "));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[accessToken = ")));
        debug.add(message.invoke("append").arg(oauthAdapter.fields().get(ACCESS_TOKEN_FIELD_NAME)));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);

        ifAccessTokenFound._else()._throw(ExpressionFactory._new(ref(Exception.class))
                .arg(ref(String.class).staticInvoke("format").arg("OAuth access token could not be extracted from: %s").arg(response)));

        generateSaveAccessTokenCall(oauthAdapter, tokenFound, saveAccessTokenCallback, logger);

        if (hasExpirationRegex(oauth2)) {
            generateExpirationExtraction(oauthAdapter, oauth2, tokenFound, response, logger);
        }

        // NOTE(review): the generated method declares UnableToAcquireAccessTokenException, yet
        // failures caught here are rethrown wrapped in a RuntimeException - confirm intent.
        generateReThrow(tryStatement, Exception.class, RuntimeException.class);
    }

    /**
     * Emits, inside {@code tokenFound}, the null-guarded call to the save callback. The
     * "attempting to save" debug message is emitted before the try block so that the
     * generated code logs the attempt ahead of performing it; a failure of the callback is
     * logged as an error and not propagated.
     */
    private void generateSaveAccessTokenCall(DefinedClass oauthAdapter, Block tokenFound, FieldVariable saveAccessTokenCallback, FieldVariable logger) {
        Conditional ifSaveCallbackNotNull = tokenFound._if(Op.ne(saveAccessTokenCallback, ExpressionFactory._null()));
        Block saveBlock = ifSaveCallbackNotNull._then();

        Block debug = openDebugGuard(saveBlock, logger);
        Variable message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Attempting to save access token..."));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[accessToken = ")));
        debug.add(message.invoke("append").arg(oauthAdapter.fields().get(OAUTH_ACCESS_TOKEN_FIELD_NAME)));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);

        // Statements land in the block in call order, so the try must be created after the log.
        TryStatement tryToSave = saveBlock._try();
        Invocation saveAccessToken = saveAccessTokenCallback.invoke("saveAccessToken")
                .arg(oauthAdapter.fields().get(OAUTH_ACCESS_TOKEN_FIELD_NAME))
                .arg(ExpressionFactory._null());
        tryToSave.body().add(saveAccessToken);

        CatchBlock logIfCannotSave = tryToSave._catch(ref(Exception.class));
        Variable e2 = logIfCannotSave.param("e");
        logIfCannotSave.body().add(logger.invoke("error").arg("Cannot save access token, an unexpected error occurred").arg(e2));
    }

    /**
     * Emits, inside {@code tokenFound}, the code that matches the expiration pattern against
     * the response and, on success, sets the expiration field to "now + captured seconds";
     * both outcomes are reported at debug level.
     */
    private void generateExpirationExtraction(DefinedClass oauthAdapter, OAuth2 oauth2, Block tokenFound, Variable response, FieldVariable logger) {
        Block debug = openDebugGuard(tokenFound, logger);
        Variable message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Attempting to extract expiration time using "));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[expirationPattern = ")));
        debug.add(message.invoke("append").arg(oauth2.expirationRegex()));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);

        Variable expirationMatcher = tokenFound.decl(ref(Matcher.class), "expirationMatcher",
                oauthAdapter.fields().get(EXPIRATION_TIME_PATTERN_FIELD_NAME).invoke("matcher").arg(response));
        Conditional ifExpirationFound = tokenFound._if(Op.cand(expirationMatcher.invoke("find"), Op.gte(expirationMatcher.invoke("groupCount"), ExpressionFactory.lit(1))));

        Block expirationFound = ifExpirationFound._then();
        Variable seconds = expirationFound.decl(ref(Long.class), "expirationSecsAhead",
                ref(Long.class).staticInvoke("parseLong").arg(expirationMatcher.invoke("group").arg(ExpressionFactory.lit(1))));
        // The captured value is multiplied by 1000 and added to the current time in millis.
        expirationFound.assign(oauthAdapter.fields().get(EXPIRATION_FIELD_NAME), ExpressionFactory._new(ref(Date.class)).arg(
                Op.plus(ref(System.class).staticInvoke("currentTimeMillis"), Op.mul(seconds, ExpressionFactory.lit(1000)))));

        debug = openDebugGuard(expirationFound, logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Token expiration extracted successfully "));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[expiration = ")));
        debug.add(message.invoke("append").arg(oauthAdapter.fields().get(EXPIRATION_FIELD_NAME)));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);

        debug = openDebugGuard(ifExpirationFound._else(), logger);
        message = declareDebugMessageBuilder(debug);
        debug.add(message.invoke("append").arg("Token expiration could not be extracted from "));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("[response = ")));
        debug.add(message.invoke("append").arg(response));
        debug.add(message.invoke("append").arg(ExpressionFactory.lit("] ")));
        emitDebugMessage(debug, logger, message);
    }

    /**
     * Generates the token-expired check: with an expiration regex configured it returns
     * {@code expiration != null && expiration.before(new Date())}; otherwise it always
     * returns {@code false}.
     */
    private void generateHasTokenExpiredMethod(DefinedClass oauthAdapter, OAuth2 oauth2) {
        Method hasTokenExpired = oauthAdapter.method(Modifier.PUBLIC, context.getCodeModel().BOOLEAN, HAS_TOKEN_EXPIRED_METHOD_NAME);
        if (hasExpirationRegex(oauth2)) {
            FieldVariable expirationDate = oauthAdapter.fields().get(EXPIRATION_FIELD_NAME);
            hasTokenExpired.body()._return(Op.cand(
                    Op.ne(expirationDate, ExpressionFactory._null()),
                    expirationDate.invoke("before").arg(ExpressionFactory._new(ref(Date.class)))));
        } else {
            hasTokenExpired.body()._return(ExpressionFactory.FALSE);
        }
    }

    /**
     * Generates the reset method, which nulls the expiration field (when present), the
     * verifier field and the access token field.
     */
    private void generateResetMethod(DefinedClass oauthAdapter, OAuth2 oauth2) {
        Method reset = oauthAdapter.method(Modifier.PUBLIC, context.getCodeModel().VOID, RESET_METHOD_NAME);
        if (hasExpirationRegex(oauth2)) {
            reset.body().assign(oauthAdapter.fields().get(EXPIRATION_FIELD_NAME), ExpressionFactory._null());
        }
        reset.body().assign(oauthAdapter.fields().get(VERIFIER_FIELD_NAME), ExpressionFactory._null());
        reset.body().assign(oauthAdapter.fields().get(ACCESS_TOKEN_FIELD_NAME), ExpressionFactory._null());
    }

    /**
     * Appends to {@code tryStatement} a catch clause for {@code exceptionToCatch} whose body
     * throws a new {@code exceptionToThrow} constructed with the caught exception.
     *
     * @param tryStatement     the try statement receiving the catch clause
     * @param exceptionToCatch the exception type to catch
     * @param exceptionToThrow the exception type to instantiate and throw
     */
    private void generateReThrow(TryStatement tryStatement, Class<? extends Exception> exceptionToCatch, Class<? extends Exception> exceptionToThrow) {
        CatchBlock catchBlock = tryStatement._catch(ref(exceptionToCatch));
        Variable caughtException = catchBlock.param("e");
        catchBlock.body()._throw(ExpressionFactory._new(ref(exceptionToThrow)).arg(caughtException));
    }
}