/*
* Copyright 2013-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cloudfoundry.reactor.tokenprovider;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.netty.util.AsciiString;
import org.cloudfoundry.reactor.ConnectionContext;
import org.cloudfoundry.reactor.TokenProvider;
import org.cloudfoundry.reactor.util.ErrorPayloadMapper;
import org.cloudfoundry.reactor.util.JsonCodec;
import org.cloudfoundry.reactor.util.NetworkLogging;
import org.cloudfoundry.reactor.util.UserAgent;
import org.cloudfoundry.uaa.UaaException;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.ReplayProcessor;
import reactor.ipc.netty.http.client.HttpClientRequest;
import reactor.ipc.netty.http.client.HttpClientResponse;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Base64;
import java.util.Date;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.function.Function;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT;
import static io.netty.handler.codec.http.HttpHeaderNames.AUTHORIZATION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED;
import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
/**
 * An abstract base class for all token providers that interact with the UAA.
 *
 * <p>Access tokens are negotiated lazily and cached per {@link ConnectionContext}. When a cached token is {@link #invalidate(ConnectionContext) invalidated}, its replacement is negotiated with the
 * refresh token most recently received for that connection, if there is one. If there is no refresh token, or the UAA rejects it with {@code 401 Unauthorized}, the subclass-specific token request
 * ({@link #tokenRequestTransformer(Mono)}) is used instead.
 */
public abstract class AbstractUaaTokenProvider implements TokenProvider {

    private static final Logger LOGGER = LoggerFactory.getLogger("cloudfoundry-client.token");

    private static final String ACCESS_TOKEN = "access_token";

    private static final String AUTHORIZATION_ENDPOINT = "authorization_endpoint";

    private static final String REFRESH_TOKEN = "refresh_token";

    private static final String TOKEN_TYPE = "token_type";

    private static final ZoneId UTC = ZoneId.of("UTC");

    /** The cached access token negotiation for each connection. */
    private final ConcurrentMap<ConnectionContext, Mono<String>> accessTokens = new ConcurrentHashMap<>(1);

    /** The stream of refresh tokens exposed through {@link #getRefreshTokens(ConnectionContext)} for each connection. */
    private final ConcurrentMap<ConnectionContext, ReplayProcessor<String>> refreshTokenStreams = new ConcurrentHashMap<>(1);

    /** The refresh token most recently received for each connection. */
    private final ConcurrentMap<ConnectionContext, Mono<String>> refreshTokens = new ConcurrentHashMap<>(1);

    /**
     * The client id. Defaults to {@code cf}.
     */
    @Value.Default
    public String getClientId() {
        return "cf";
    }

    /**
     * The client secret. Defaults to {@code ""}.
     */
    @Value.Default
    public String getClientSecret() {
        return "";
    }

    /**
     * Returns a {@link Flux} of refresh tokens for a connection
     *
     * @param connectionContext A {@link ConnectionContext} used to identify which connection the refresh tokens should be retrieved for
     * @return a {@link Flux} that emits the last token on subscribe and new refresh tokens as they are negotiated
     */
    public Flux<String> getRefreshTokens(ConnectionContext connectionContext) {
        return getRefreshTokenStream(connectionContext);
    }

    /**
     * Returns the cached token negotiation for a connection, creating it on first use.
     */
    @Override
    public final Mono<String> getToken(ConnectionContext connectionContext) {
        return this.accessTokens.computeIfAbsent(connectionContext, this::token);
    }

    /**
     * Replaces the cached token negotiation for a connection with a fresh one, so that subsequent calls to {@link #getToken(ConnectionContext)} negotiate a new token.
     */
    @Override
    public void invalidate(ConnectionContext connectionContext) {
        this.accessTokens.put(connectionContext, token(connectionContext));
    }

    /**
     * Transforms a {@code Mono} in order to make a request to negotiate an access token
     *
     * @param outbound the {@link Mono} to transform to perform the token request
     * @return the {@link Mono} derived from {@code outbound}; it is used as the result of the request handler passed to the HTTP client
     */
    abstract Mono<Void> tokenRequestTransformer(Mono<HttpClientRequest> outbound);

    /**
     * Declares that the request body is form encoded and that a JSON response is expected.
     */
    private static HttpClientRequest addContentTypes(HttpClientRequest request) {
        return request
            .header(ACCEPT, APPLICATION_JSON)
            .header(CONTENT_TYPE, APPLICATION_X_WWW_FORM_URLENCODED);
    }

    private static HttpClientRequest disableChunkedTransfer(HttpClientRequest request) {
        return request.chunkedTransfer(false);
    }

    private static HttpClientRequest disableFailOnError(HttpClientRequest request) {
        return request
            .failOnClientError(false)
            .failOnServerError(false);
    }

    /**
     * Builds the token value handed to callers from a token response payload.
     *
     * @param payload the decoded token response
     * @return the {@code token_type} and {@code access_token} entries of the payload, separated by a single space
     */
    private static String extractAccessToken(Map<String, String> payload) {
        String accessToken = payload.get(ACCESS_TOKEN);

        // NOTE: with debug logging enabled the raw access token is written to the log
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Access Token: {}", accessToken);

            parseToken(accessToken)
                .ifPresent(claims -> {
                    LOGGER.debug("Access Token Issued At: {} UTC", toLocalDateTime(claims.getIssuedAt()));
                    LOGGER.debug("Access Token Expires At: {} UTC", toLocalDateTime(claims.getExpiration()));
                });
        }

        return String.format("%s %s", payload.get(TOKEN_TYPE), accessToken);
    }

    /**
     * Appends {@code /oauth/token} to the given root URI.
     */
    private static String getTokenUri(String root) {
        return UriComponentsBuilder.fromUriString(root)
            .pathSegment("oauth", "token")
            .build().encode().toUriString();
    }

    /**
     * Reads the claims of a token for diagnostic purposes. The signature is not verified.
     *
     * @param token the token to read
     * @return the claims, or an empty {@link Optional} if the token could not be parsed for any reason (including {@code token} being {@code null})
     */
    private static Optional<Claims> parseToken(String token) {
        try {
            // Drop everything after the last '.' (the signature) so that the remainder can be parsed as an unsigned JWT
            String jws = token.substring(0, token.lastIndexOf('.') + 1);
            return Optional.of(Jwts.parser().parseClaimsJwt(jws).getBody());
        } catch (Exception e) {
            return Optional.empty();
        }
    }

    /**
     * Converts a claim timestamp to a UTC {@link LocalDateTime} for logging.
     *
     * @param date the timestamp, which is {@code null} when the token does not carry the corresponding claim
     * @return the timestamp in UTC, or {@code null} if {@code date} is {@code null}
     */
    private static LocalDateTime toLocalDateTime(Date date) {
        // Tokens are not required to contain "iat" or "exp". Tolerating their absence keeps debug logging from failing the token negotiation.
        if (date == null) {
            return null;
        }

        return LocalDateTime.ofInstant(date.toInstant(), UTC);
    }

    /**
     * Adds an HTTP Basic {@code Authorization} header built from the client id and client secret.
     */
    private HttpClientRequest addAuthorization(HttpClientRequest request) {
        String encoded = Base64.getEncoder().encodeToString(new AsciiString(getClientId()).concat(":").concat(getClientSecret()).toByteArray());
        return request.header(AUTHORIZATION, String.format("Basic %s", encoded));
    }

    /**
     * Creates a consumer that, when the token response payload contains a refresh token, remembers it for the connection and publishes it on the connection's refresh token stream.
     */
    private Consumer<Map<String, String>> extractRefreshToken(ConnectionContext connectionContext) {
        return payload -> Optional.ofNullable(payload.get(REFRESH_TOKEN))
            .ifPresent(refreshToken -> {
                // NOTE: with debug logging enabled the raw refresh token is written to the log
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("Refresh Token: {}", refreshToken);

                    parseToken(refreshToken)
                        .ifPresent(claims -> {
                            LOGGER.debug("Refresh Token Issued At: {} UTC", toLocalDateTime(claims.getIssuedAt()));
                            LOGGER.debug("Refresh Token Expires At: {} UTC", toLocalDateTime(claims.getExpiration()));
                        });
                }

                this.refreshTokens.put(connectionContext, Mono.just(refreshToken));
                getRefreshTokenStream(connectionContext).onNext(refreshToken);
            });
    }

    /**
     * Creates the transformation from a token response to the access token value, recording any refresh token in the response along the way.
     */
    @SuppressWarnings("unchecked")
    private Function<Mono<HttpClientResponse>, Mono<String>> extractTokens(ConnectionContext connectionContext) {
        return inbound -> inbound
            .transform(JsonCodec.decode(connectionContext.getObjectMapper(), Map.class))
            .map(payload -> (Map<String, String>) payload)
            .doOnNext(extractRefreshToken(connectionContext))
            .map(AbstractUaaTokenProvider::extractAccessToken);
    }

    /**
     * Returns the refresh token stream of a connection, creating it on first use. The stream replays its latest element to new subscribers.
     */
    private ReplayProcessor<String> getRefreshTokenStream(ConnectionContext connectionContext) {
        return this.refreshTokenStreams.computeIfAbsent(connectionContext, c -> ReplayProcessor.create(1));
    }

    /**
     * Requests a token using the subclass-specific token request.
     */
    private Mono<HttpClientResponse> primaryToken(ConnectionContext connectionContext) {
        return requestToken(connectionContext, this::tokenRequestTransformer);
    }

    /**
     * Requests a token using the refresh token grant. A {@link UaaException} with status {@code 401 Unauthorized} is turned into an empty result so that the caller can fall back to another grant.
     */
    private Mono<HttpClientResponse> refreshToken(ConnectionContext connectionContext, String refreshToken) {
        return requestToken(connectionContext, refreshTokenGrantTokenRequestTransformer(refreshToken))
            .onErrorResume(t -> t instanceof UaaException && ((UaaException) t).getStatusCode() == UNAUTHORIZED.code(), t -> Mono.empty());
    }

    /**
     * Creates the request transformation that sends a {@code refresh_token} grant form for the given refresh token.
     */
    private Function<Mono<HttpClientRequest>, Mono<Void>> refreshTokenGrantTokenRequestTransformer(String refreshToken) {
        return outbound -> outbound
            .then(request -> request
                .sendForm(form -> form
                    .multipart(false)
                    .attr("client_id", getClientId())
                    .attr("client_secret", getClientSecret())
                    .attr("grant_type", "refresh_token")
                    .attr("refresh_token", refreshToken))
                .then());
    }

    /**
     * Posts a token request to the {@code /oauth/token} resource of the connection's authorization endpoint.
     *
     * @param connectionContext       the connection to request a token for
     * @param tokenRequestTransformer the transformation that writes the grant-specific request body
     * @return the response, with UAA error payloads mapped by {@link ErrorPayloadMapper#uaa}
     */
    private Mono<HttpClientResponse> requestToken(ConnectionContext connectionContext, Function<Mono<HttpClientRequest>, Mono<Void>> tokenRequestTransformer) {
        return connectionContext
            .getRoot(AUTHORIZATION_ENDPOINT)
            .map(AbstractUaaTokenProvider::getTokenUri)
            .then(uri -> connectionContext.getHttpClient()
                .post(uri, request -> Mono.just(request)
                    .map(AbstractUaaTokenProvider::disableChunkedTransfer)
                    .map(AbstractUaaTokenProvider::disableFailOnError)
                    .map(this::addAuthorization)
                    .map(UserAgent::addUserAgent)
                    .map(AbstractUaaTokenProvider::addContentTypes)
                    .transform(tokenRequestTransformer))
                .doOnSubscribe(NetworkLogging.post(uri))
                .transform(NetworkLogging.response(uri)))
            .transform(ErrorPayloadMapper.uaa(connectionContext.getObjectMapper()));
    }

    /**
     * Assembles the cached negotiation of an access token for a connection.
     *
     * <p>The refresh token known for the connection at the time this method is called is tried first. If there is none, or the refresh attempt yields no response, the subclass-specific token
     * request is used.
     */
    private Mono<String> token(ConnectionContext connectionContext) {
        return this.refreshTokens.getOrDefault(connectionContext, Mono.empty())
            .then(refreshToken -> refreshToken(connectionContext, refreshToken)
                .doOnSubscribe(s -> LOGGER.debug("Negotiating using refresh token")))
            .switchIfEmpty(primaryToken(connectionContext)
                .doOnSubscribe(s -> LOGGER.debug("Negotiating using token provider")))
            .transform(ErrorPayloadMapper.fallback())
            .transform(extractTokens(connectionContext))
            .cache()
            .checkpoint();
    }

}