/*************************************************************************
* (c) Copyright 2016 Hewlett Packard Enterprise Development Company LP
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; version 3 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
************************************************************************/
package com.eucalyptus.tokens.oidc;
import static java.lang.System.getProperty;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.primitives.Ints.tryParse;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.cert.Certificate;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.HttpsURLConnection;
import org.apache.commons.io.input.BoundedInputStream;
import com.eucalyptus.crypto.util.SslSetup;
import com.eucalyptus.util.Pair;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheBuilderSpec;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import com.google.common.net.HttpHeaders;
import javaslang.control.Option;
/**
 * Cache for resources fetched over HTTP(S) by URL, keeping the response body
 * together with the server certificates presented when it was retrieved.
 *
 * <p>Entries record the {@code Last-Modified} and {@code ETag} response headers
 * so that a refresh can be performed as a conditional request; a
 * {@code 304 Not Modified} reply re-uses the previously cached content.</p>
 *
 * <p>The following system properties are read once at class initialization;
 * missing or non-integer values fall back to the stated default:</p>
 * <ul>
 *   <li>{@code com.eucalyptus.tokens.oidc.connectTimeout} - passed to
 *       {@link HttpURLConnection#setConnectTimeout(int)}, default 20000</li>
 *   <li>{@code com.eucalyptus.tokens.oidc.readTimeout} - passed to
 *       {@link HttpURLConnection#setReadTimeout(int)}, default 30000</li>
 *   <li>{@code com.eucalyptus.tokens.oidc.maxLength} - maximum accepted
 *       response size in bytes, default 131072</li>
 * </ul>
 */
public class OidcDiscoveryCache {

  private static final int CONNECT_TIMEOUT =
      firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.connectTimeout", "" ) ), 20_000 );

  private static final int READ_TIMEOUT =
      firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.readTimeout", "" ) ), 30_000 );

  private static final int MAX_LENGTH =
      firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.maxLength", "" ) ), 128 * 1024 );

  /** The cache specification currently in effect paired with the cache built from it. */
  private final AtomicReference<Pair<String,Cache<String,OidcDiscoveryCachedResource>>> cacheReference =
      new AtomicReference<>( );

  /**
   * Get the content for a URL, fetching or revalidating it when necessary.
   *
   * <p>If no entry is cached the resource is fetched. If an entry is cached
   * but older than the refresh interval a conditional request is made.
   * Otherwise the cached entry is returned without any network access.</p>
   *
   * @param cacheSpec Guava {@link CacheBuilderSpec} string used to build the cache;
   *     a new cache is built when this differs from the spec last used
   * @param minimumRefreshInterval minimum age of an entry before it is revalidated,
   *     in the same unit as {@code timeNow}
   * @param timeNow the current time (presumably epoch milliseconds; confirm against callers)
   * @param url the location of the resource, also used as the cache key
   * @return the resource text and the server certificates seen when the content was
   *     fetched (an empty array when the connection was not HTTPS)
   * @throws IOException if the resource cannot be retrieved, is too large, or the
   *     server replies with an unexpected status
   */
  public Pair<String, Certificate[]> get(
      final String cacheSpec,
      final long minimumRefreshInterval,
      final long timeNow,
      final String url
  ) throws IOException {
    final Cache<String,OidcDiscoveryCachedResource> cache = cache( cacheSpec );
    final OidcDiscoveryCachedResource cachedResource = cache.getIfPresent( url );
    final OidcDiscoveryCachedResource resource;
    if ( cachedResource == null ) { // not cached
      resource = fetchResource( url, timeNow, null );
    } else if ( cachedResource.needsRefresh( minimumRefreshInterval, timeNow ) ) { // refresh due, check if current
      resource = fetchResource( url, timeNow, cachedResource );
    } else { // use existing
      resource = cachedResource;
    }
    if ( resource != cachedResource ) {
      cache.put( url, resource );
    }
    return resource.contentPair( );
  }

  /**
   * Fetch a resource, making the request conditional when a cached entry is supplied.
   *
   * @param url the location to fetch
   * @param timeNow timestamp recorded on the returned entry
   * @param cached the existing entry to revalidate, or null for an unconditional fetch
   * @return a new cache entry; on 304 it carries the content of {@code cached}
   * @throws IOException on retrieval failure, or if 304 is reported when there
   *     is no cached entry to fall back on
   */
  private OidcDiscoveryCachedResource fetchResource(
      final String url,
      final long timeNow,
      final OidcDiscoveryCachedResource cached
  ) throws IOException {
    final URL location = new URL( url );
    final HttpURLConnection conn = (HttpURLConnection) location.openConnection( );
    conn.setAllowUserInteraction( false );
    conn.setInstanceFollowRedirects( false );
    conn.setConnectTimeout( CONNECT_TIMEOUT );
    conn.setReadTimeout( READ_TIMEOUT );
    conn.setUseCaches( false );
    if ( cached != null ) {
      // send validators so that an unchanged resource yields 304 with no body
      if ( cached.lastModified.isDefined( ) ) {
        conn.setRequestProperty( HttpHeaders.IF_MODIFIED_SINCE, cached.lastModified.get( ) );
      }
      if ( cached.etag.isDefined( ) ) {
        conn.setRequestProperty( HttpHeaders.IF_NONE_MATCH, cached.etag.get( ) );
      }
    }
    final OidcResource oidcResource = resolve( conn );

    // build cache entry from resource
    if ( oidcResource.statusCode == HttpURLConnection.HTTP_NOT_MODIFIED ) {
      if ( cached == null ) {
        // nothing to re-use, fail cleanly instead of dereferencing null
        throw new IOException( url + " unexpected not modified response, no cached content" );
      }
      return new OidcDiscoveryCachedResource( timeNow, cached );
    }
    return new OidcDiscoveryCachedResource(
        timeNow,
        Option.of( oidcResource.lastModifiedHeader ),
        Option.of( oidcResource.etagHeader ),
        ImmutableList.copyOf( oidcResource.certs ),
        url,
        new String( oidcResource.content, StandardCharsets.UTF_8 )
    );
  }

  /**
   * Perform the request on a prepared connection and read the response.
   *
   * <p>Protected so that subclasses can substitute the network access.</p>
   *
   * @param conn the configured connection, not yet connected
   * @return a resource holding only the status for 304, otherwise the status,
   *     caching headers, server certificates and content
   * @throws IOException if the request fails, the status is neither 2xx nor 304,
   *     or the content exceeds the maximum length
   */
  protected OidcResource resolve( final HttpURLConnection conn ) throws IOException {
    SslSetup.configureHttpsUrlConnection( conn );
    try ( final InputStream istr = conn.getInputStream( ) ) {
      final int statusCode = conn.getResponseCode( );
      if ( statusCode == HttpURLConnection.HTTP_NOT_MODIFIED ) {
        return new OidcResource( statusCode );
      }
      if ( statusCode < 200 || statusCode > 299 ) {
        // redirects are not followed, so a 3xx body must not be treated as the resource
        throw new IOException( conn.getURL( ) + " unexpected response status, " + statusCode );
      }
      final Certificate[] certs = conn instanceof HttpsURLConnection ?
          ((HttpsURLConnection)conn).getServerCertificates( ) :
          new Certificate[0];
      // reject early when the server declares an oversized body
      final long contentLength = conn.getContentLengthLong( );
      if ( contentLength > MAX_LENGTH ) {
        throw new IOException( conn.getURL( ) + " content exceeds maximum size, " + MAX_LENGTH );
      }
      // read at most one byte over the limit so an oversized body is detectable
      // long arithmetic avoids overflow when the limit is Integer.MAX_VALUE
      final byte[] content = ByteStreams.toByteArray( new BoundedInputStream( istr, MAX_LENGTH + 1L ) );
      if ( content.length > MAX_LENGTH ) {
        throw new IOException( conn.getURL( ) + " content exceeds maximum size, " + MAX_LENGTH );
      }
      return new OidcResource(
          statusCode,
          conn.getHeaderField( HttpHeaders.LAST_MODIFIED ),
          conn.getHeaderField( HttpHeaders.ETAG ),
          certs,
          content
      );
    }
  }

  /**
   * Get the cache for the given specification, building a new one if the
   * specification differs from the one currently in effect.
   *
   * <p>If another thread replaces the reference concurrently this call does
   * not retry: it uses the cache it observed or, if none was observed, the
   * cache it just built, which is then not shared.</p>
   *
   * @param cacheSpec the {@link CacheBuilderSpec} string
   * @return the cache to use for this call
   */
  private Cache<String,OidcDiscoveryCachedResource> cache( final String cacheSpec ) {
    final Pair<String,Cache<String,OidcDiscoveryCachedResource>> current = cacheReference.get( );
    if ( current != null && cacheSpec.equals( current.getLeft( ) ) ) {
      return current.getRight( );
    }
    final Pair<String,Cache<String,OidcDiscoveryCachedResource>> replacement = Pair.pair(
        cacheSpec,
        CacheBuilder.from( CacheBuilderSpec.parse( cacheSpec ) ).build( ) );
    if ( cacheReference.compareAndSet( current, replacement ) || current == null ) {
      return replacement.getRight( );
    }
    return current.getRight( );
  }

  /**
   * The outcome of a single HTTP request as returned by {@link #resolve}.
   */
  protected static class OidcResource {
    private final int statusCode;
    private final String lastModifiedHeader;
    private final String etagHeader;
    private final Certificate[] certs;
    private final byte[] content;

    /**
     * Create a resource that carries only a status, as used for 304 responses.
     *
     * @param statusCode the HTTP status code
     */
    protected OidcResource(
        final int statusCode
    ) {
      this( statusCode, null, null, null, null );
    }

    /**
     * Create a resource with content.
     *
     * @param statusCode the HTTP status code
     * @param lastModifiedHeader value of the Last-Modified header, may be null
     * @param etagHeader value of the ETag header, may be null
     * @param certs the server certificates, empty when not HTTPS
     * @param content the raw response body
     */
    protected OidcResource(
        final int statusCode,
        final String lastModifiedHeader,
        final String etagHeader,
        final Certificate[] certs,
        final byte[] content
    ) {
      this.statusCode = statusCode;
      this.lastModifiedHeader = lastModifiedHeader;
      this.etagHeader = etagHeader;
      this.certs = certs;
      this.content = content;
    }
  }

  /**
   * Immutable cache entry holding the content, the validators needed for a
   * conditional request and the time the entry was last confirmed current.
   */
  private static class OidcDiscoveryCachedResource {
    /** Time this entry was fetched or last revalidated. */
    private final long cached;
    private final Option<String> lastModified;
    private final Option<String> etag;
    private final ImmutableList<Certificate> certificateChain;
    private final String url;
    private final String resource;

    /**
     * Create an entry that copies everything but the timestamp from an existing entry.
     *
     * @param timeNow the new timestamp
     * @param from the entry to copy, must not be null
     */
    private OidcDiscoveryCachedResource(
        final long timeNow,
        final OidcDiscoveryCachedResource from
    ) {
      this( timeNow, from.lastModified, from.etag, from.certificateChain, from.url, from.resource );
    }

    private OidcDiscoveryCachedResource(
        final long cached,
        final Option<String> lastModified,
        final Option<String> etag,
        final ImmutableList<Certificate> certificateChain,
        final String url,
        final String resource
    ) {
      this.cached = cached;
      this.lastModified = lastModified;
      this.etag = etag;
      this.certificateChain = certificateChain;
      this.url = url;
      this.resource = resource;
    }

    /**
     * Check whether the entry is older than the refresh interval.
     *
     * @param minimumRefreshInterval the interval, in the same unit as {@code timeNow}
     * @param timeNow the current time
     * @return true if more than the interval has elapsed since the entry was cached
     */
    private boolean needsRefresh(
        final long minimumRefreshInterval,
        final long timeNow
    ) {
      // compare the elapsed time so that a very large interval cannot overflow
      return ( timeNow - cached ) > minimumRefreshInterval;
    }

    /**
     * @return the resource text paired with a fresh array copy of the certificate chain
     */
    private Pair<String, Certificate[]> contentPair( ) {
      return Pair.pair( resource, certificateChain.toArray( new Certificate[ certificateChain.size( ) ] ) );
    }
  }
}