/*
 * Copyright 2013-2014 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.aws.core.io.s3;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.AmazonS3URI;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Field;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Factory for {@link AmazonS3} clients that target the region of a given endpoint URL.
 *
 * <p>Clients are derived from a caller-supplied prototype client and cached per region
 * name, so repeated requests for the same region return the same client instance.
 * Because the cache key is the region name only, the prototype passed on later calls
 * for an already cached region is not consulted.
 *
 * @author Agim Emruli
 * @since 1.2
 */
public class AmazonS3ClientFactory {

    /** Name of the field on {@link AmazonS3Client} that is read reflectively for credentials. */
    private static final String CREDENTIALS_PROVIDER_FIELD_NAME = "awsCredentialsProvider";

    /** Host that is mapped to {@link Regions#DEFAULT_REGION} without further URL parsing. */
    private static final String GLOBAL_S3_HOST = "s3.amazonaws.com";

    /** Region name to client; entries are only ever added, never removed. */
    private final ConcurrentHashMap<String, AmazonS3> clientCache =
            new ConcurrentHashMap<>(Regions.values().length);

    /** Reflective handle on the credentials provider field, made accessible in the constructor. */
    private final Field credentialsProviderField;

    /**
     * Looks up the credentials provider field on {@link AmazonS3Client} and makes it accessible.
     *
     * @throws IllegalArgumentException if the field cannot be found on {@link AmazonS3Client}
     */
    public AmazonS3ClientFactory() {
        Field providerField =
                ReflectionUtils.findField(AmazonS3Client.class, CREDENTIALS_PROVIDER_FIELD_NAME);
        Assert.notNull(providerField,
                "Credentials Provider field not found, this class does not work with the current "
                        + "AWS SDK release");
        ReflectionUtils.makeAccessible(providerField);
        this.credentialsProviderField = providerField;
    }

    /**
     * Returns the client for the region that the given endpoint URL resolves to, building and
     * caching one on first use.
     *
     * <p>Concurrent callers asking for a region that is not cached yet may each build a client;
     * only the one stored first is kept in the cache and returned to all of them.
     *
     * @param prototype   client whose credentials provider is reused when it is an
     *                    {@link AmazonS3Client}; must not be {@code null}
     * @param endpointUrl endpoint URL from which the region is derived; must not be {@code null}
     * @return the cached client for the resolved region
     * @throws IllegalArgumentException if an argument is {@code null} or no region could be
     *                                  determined from {@code endpointUrl}
     */
    public AmazonS3 createClientForEndpointUrl(AmazonS3 prototype, String endpointUrl) {
        Assert.notNull(prototype, "AmazonS3 must not be null");
        Assert.notNull(endpointUrl, "Endpoint Url must not be null");

        String region = getRegion(endpointUrl);
        Assert.notNull(region, "Error detecting region from endpoint url:'" + endpointUrl + "'");

        AmazonS3 cachedClient = this.clientCache.get(region);
        if (cachedClient != null) {
            return cachedClient;
        }

        AmazonS3 newClient = buildAmazonS3ForRegion(prototype, region).build();
        // putIfAbsent returns the previously stored client when another caller stored one first;
        // since entries are never removed, that is exactly what a follow-up lookup would return.
        AmazonS3 earlierClient = this.clientCache.putIfAbsent(region, newClient);
        return (earlierClient != null) ? earlierClient : newClient;
    }

    /**
     * Derives the region name from an endpoint URL.
     *
     * @param endpointUrl the endpoint URL to inspect; must not be {@code null}
     * @return the name of {@link Regions#DEFAULT_REGION} when the host is exactly
     *         {@code s3.amazonaws.com}, otherwise whatever {@link AmazonS3URI#getRegion()}
     *         reports for the URL (the caller guards against a {@code null} result)
     * @throws RuntimeException if {@code endpointUrl} is not a syntactically valid URI
     */
    private static String getRegion(String endpointUrl) {
        Assert.notNull(endpointUrl, "Endpoint Url must not be null");

        URI endpoint;
        try {
            endpoint = new URI(endpointUrl);
        }
        catch (URISyntaxException e) {
            throw new RuntimeException("Malformed URL received for endpoint", e);
        }

        if (GLOBAL_S3_HOST.equals(endpoint.getHost())) {
            return Regions.DEFAULT_REGION.getName();
        }
        return new AmazonS3URI(endpointUrl).getRegion();
    }

    /**
     * Prepares a client builder for the given region.
     *
     * <p>When the prototype is an {@link AmazonS3Client}, its credentials provider is read
     * reflectively and applied to the builder; for any other {@link AmazonS3} implementation
     * no credentials are set on the builder here.
     *
     * @param prototype client to take the credentials provider from, if possible
     * @param region    name of the region to configure on the builder
     * @return the configured builder, not yet built
     */
    private AmazonS3ClientBuilder buildAmazonS3ForRegion(AmazonS3 prototype, String region) {
        AmazonS3ClientBuilder builder = AmazonS3ClientBuilder.standard();

        if (prototype instanceof AmazonS3Client) {
            Object provider = ReflectionUtils.getField(this.credentialsProviderField, prototype);
            builder.withCredentials((AWSCredentialsProvider) provider);
        }

        return builder.withRegion(region);
    }
}