/************************************************************************* * (c) Copyright 2016 Hewlett Packard Enterprise Development Company LP * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; version 3 of the License. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see http://www.gnu.org/licenses/. ************************************************************************/ package com.eucalyptus.tokens.oidc; import static java.lang.System.getProperty; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.primitives.Ints.tryParse; import java.io.IOException; import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; import java.nio.charset.StandardCharsets; import java.security.cert.Certificate; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.HttpsURLConnection; import org.apache.commons.io.input.BoundedInputStream; import com.eucalyptus.crypto.util.SslSetup; import com.eucalyptus.util.Pair; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheBuilderSpec; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import com.google.common.net.HttpHeaders; import javaslang.control.Option; /** * */ public class OidcDiscoveryCache { private static final int CONNECT_TIMEOUT = firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.connectTimeout", "" ) ), 20_000 ); private static final int READ_TIMEOUT = firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.readTimeout", "" ) ), 30_000 ); private static final int MAX_LENGTH = firstNonNull( tryParse( getProperty( "com.eucalyptus.tokens.oidc.maxLength", "" ) ), 128 * 1024 ); private final AtomicReference<Pair<String,Cache<String,OidcDiscoveryCachedResource>>> cacheReference = new AtomicReference<>( ); public Pair<String, Certificate[]> get( final String cacheSpec, final long minimumRefreshInterval, final long timeNow, final String url ) throws IOException { final Cache<String,OidcDiscoveryCachedResource> cache = cache( cacheSpec ); final OidcDiscoveryCachedResource cachedResource = cache.getIfPresent( url ); final OidcDiscoveryCachedResource resource; if ( cachedResource == null ) { // not cached resource = fetchResource( url, timeNow, null ); } else if ( cachedResource.needsRefresh( minimumRefreshInterval, timeNow ) ) { // cache refresh expired, check if current resource = fetchResource( url, timeNow, cachedResource ); } else { // use existing resource = cachedResource; } if ( resource != cachedResource ) { cache.put( url, resource ); } return resource.contentPair( ); } private OidcDiscoveryCachedResource fetchResource( final String url, final long timeNow, final OidcDiscoveryCachedResource cached ) throws IOException { final URL location = new URL( url ); final OidcResource oidcResource; { // setup url connection and resolve final HttpURLConnection conn = (HttpURLConnection) location.openConnection( ); conn.setAllowUserInteraction( false ); conn.setInstanceFollowRedirects( false ); conn.setConnectTimeout( CONNECT_TIMEOUT ); conn.setReadTimeout( READ_TIMEOUT ); conn.setUseCaches( false ); if ( cached != null ) { if ( cached.lastModified.isDefined( ) ) { conn.setRequestProperty( HttpHeaders.IF_MODIFIED_SINCE, cached.lastModified.get( ) ); } if ( cached.etag.isDefined( ) ) { conn.setRequestProperty( HttpHeaders.IF_NONE_MATCH, cached.etag.get( ) ); } } oidcResource = resolve( conn ); } // build cache entry from resource if ( oidcResource.statusCode == 304 ) { return new OidcDiscoveryCachedResource( timeNow, cached ); } else { return new OidcDiscoveryCachedResource( timeNow, Option.of( oidcResource.lastModifiedHeader ), Option.of( oidcResource.etagHeader ), ImmutableList.copyOf( oidcResource.certs ), url, new String( oidcResource.content, StandardCharsets.UTF_8 ) ); } } protected OidcResource resolve( final HttpURLConnection conn ) throws IOException { SslSetup.configureHttpsUrlConnection( conn ); try ( final InputStream istr = conn.getInputStream( ) ) { final int statusCode = conn.getResponseCode( ); if ( statusCode == 304 ) { return new OidcResource( statusCode ); } else { Certificate[] certs = new Certificate[0]; if (conn instanceof HttpsURLConnection ) { certs = ((HttpsURLConnection)conn).getServerCertificates(); } final long contentLength = conn.getContentLengthLong( ); if ( contentLength > MAX_LENGTH) { throw new IOException( conn.getURL( ) + " content exceeds maximum size, " + MAX_LENGTH ); } final byte[] content = ByteStreams.toByteArray( new BoundedInputStream( istr, MAX_LENGTH + 1 ) ); if ( content.length > MAX_LENGTH) { throw new IOException( conn.getURL( ) + " content exceeds maximum size, " + MAX_LENGTH ); } return new OidcResource( statusCode, conn.getHeaderField( HttpHeaders.LAST_MODIFIED ), conn.getHeaderField( HttpHeaders.ETAG ), certs, content ); } } } private Cache<String,OidcDiscoveryCachedResource> cache( final String cacheSpec ) { Cache<String,OidcDiscoveryCachedResource> cache; final Pair<String,Cache<String,OidcDiscoveryCachedResource>> cachePair = cacheReference.get( ); if ( cachePair == null || !cacheSpec.equals( cachePair.getLeft( ) ) ) { final Pair<String,Cache<String,OidcDiscoveryCachedResource>> newCachePair = Pair.pair( cacheSpec, CacheBuilder.from( CacheBuilderSpec.parse( cacheSpec ) ).build( ) ); if ( cacheReference.compareAndSet( cachePair, newCachePair ) || cachePair == null ) { cache = newCachePair.getRight( ); } else { cache = cachePair.getRight( ); } } else { cache = cachePair.getRight( ); } return cache; } protected static class OidcResource { private final int statusCode; private final String lastModifiedHeader; private final String etagHeader; private final Certificate[] certs; private final byte[] content; protected OidcResource( final int statusCode ) { this( statusCode, null, null, null, null ); } protected OidcResource( final int statusCode, final String lastModifiedHeader, final String etagHeader, final Certificate[] certs, final byte[] content ) { this.statusCode = statusCode; this.lastModifiedHeader = lastModifiedHeader; this.etagHeader = etagHeader; this.certs = certs; this.content = content; } } private static class OidcDiscoveryCachedResource { private final long cached; private final Option<String> lastModified; private final Option<String> etag; private final ImmutableList<Certificate> certificateChain; private final String url; private final String resource; private OidcDiscoveryCachedResource( final long timeNow, final OidcDiscoveryCachedResource from ) { this.cached = timeNow; this.lastModified = from.lastModified; this.etag = from.etag; this.certificateChain = from.certificateChain; this.url = from.url; this.resource = from.resource; } private OidcDiscoveryCachedResource( final long cached, final Option<String> lastModified, final Option<String> etag, final ImmutableList<Certificate> certificateChain, final String url, final String resource ) { this.cached = cached; this.lastModified = lastModified; this.etag = etag; this.certificateChain = certificateChain; this.url = url; this.resource = resource; } private boolean needsRefresh( final long minimumRefreshInterval, final long timeNow ) { return timeNow > ( cached + minimumRefreshInterval ); } private Pair<String, Certificate[]> contentPair( ) { return Pair.pair( resource, certificateChain.toArray( new Certificate[ certificateChain.size( ) ] ) ); } } }