package io.airlift.airship.coordinator.auth;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.ListObjectsRequest;
import com.amazonaws.services.s3.model.ObjectListing;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectSummary;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.collect.AbstractSequentialIterator;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import com.google.common.io.CharSource;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.airlift.airship.coordinator.AwsProvisionerConfig;
import io.airlift.units.Duration;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static com.google.common.collect.Lists.newArrayList;
public class S3AuthorizedKeyStore
implements AuthorizedKeyStore
{
private final AmazonS3 s3Client;
private final String bucketName;
private final String path;
private final AtomicReference<Map<Fingerprint, AuthorizedKey>> authorizedKeys = new AtomicReference<Map<Fingerprint, AuthorizedKey>>(ImmutableMap.<Fingerprint, AuthorizedKey>of());
private Map<String, KeyFile> keyFiles = new TreeMap<String, KeyFile>();
private final ScheduledExecutorService executor;
private final Duration refreshInterval;
@Inject
public S3AuthorizedKeyStore(AmazonS3 s3Client, AwsProvisionerConfig awsProvisionerConfig)
{
this(s3Client, awsProvisionerConfig.getS3KeystoreBucket(), awsProvisionerConfig.getS3KeystorePath(), awsProvisionerConfig.getS3KeystoreRefreshInterval());
}
public S3AuthorizedKeyStore(AmazonS3 s3Client, String bucketName, String path, Duration refreshInterval)
{
this.s3Client = s3Client;
this.bucketName = bucketName;
if (bucketName != null) {
if (path == null) {
path = "/";
}
else if (!path.endsWith("/")) {
path = path + "/";
}
}
this.path = path;
this.refreshInterval = refreshInterval;
if (refreshInterval != null) {
executor = Executors.newSingleThreadScheduledExecutor(new ThreadFactoryBuilder().setDaemon(true).setNameFormat("S3AuthorizedKeyStore-%s").build());
}
else {
executor = null;
}
refreshKeys();
}
@PostConstruct
public void start()
{
if (executor != null) {
executor.scheduleWithFixedDelay(new Runnable()
{
@Override
public void run()
{
refreshKeys();
}
}, 0, (long) refreshInterval.toMillis(), TimeUnit.MILLISECONDS);
}
}
@PreDestroy
public void stop()
{
if (executor != null) {
executor.shutdownNow();
}
}
@Override
public AuthorizedKey get(Fingerprint fingerprint)
{
return authorizedKeys.get().get(fingerprint);
}
@VisibleForTesting
synchronized Map<Fingerprint, AuthorizedKey> refreshKeys()
{
ImmutableMap.Builder<Fingerprint, AuthorizedKey> newAuthorizedKeys = ImmutableMap.builder();
if (bucketName == null) {
return newAuthorizedKeys.build();
}
Map<String, KeyFile> newKeyFiles = new TreeMap<String, KeyFile>();
for (S3ObjectSummary objectSummary : new S3ObjectListing(s3Client, new ListObjectsRequest(bucketName, path, null, "/", null))) {
KeyFile keyFile = keyFiles.get(objectSummary.getKey());
// only load s3 data if the file is new or has changed
if (keyFile == null || !keyFile.etag.equals(objectSummary.getETag())) {
try {
String userId = objectSummary.getKey().substring(path.length());
if (userId.isEmpty()) {
// sometimes s3 returns the directory as an object
continue;
}
if (userId.endsWith(".pub")) {
userId = userId.substring(0, userId.length() - ".pub".length());
}
List<AuthorizedKey> keys = newArrayList();
for (String line : new S3InputSupplier(s3Client, bucketName, objectSummary.getKey()).readLines()) {
line = line.trim();
if (!line.isEmpty()) {
PublicKey key = PublicKey.valueOf(line);
keys.add(new AuthorizedKey(userId, key));
}
}
keyFile = new KeyFile(objectSummary.getKey(), objectSummary.getETag(), keys);
}
catch (IOException e) {
// assume key file was removed between listing and fetch
keyFile = null;
}
catch (Exception e) {
// corrupt key file
// todo warn?
keyFile = null;
}
}
if (keyFile != null) {
newKeyFiles.put(keyFile.s3Key, keyFile);
for (AuthorizedKey authorizedKey : keyFile.authorizedKeys) {
newAuthorizedKeys.put(authorizedKey.getPublicKey().getFingerprint(), authorizedKey);
}
}
}
keyFiles = newKeyFiles;
authorizedKeys.set(newAuthorizedKeys.build());
return newAuthorizedKeys.build();
}
private static class KeyFile
{
private final String s3Key;
private final String etag;
private final List<AuthorizedKey> authorizedKeys;
private KeyFile(String s3Key, String etag, List<AuthorizedKey> authorizedKeys)
{
this.s3Key = s3Key;
this.etag = etag;
this.authorizedKeys = authorizedKeys;
}
@Override
public String toString()
{
final StringBuilder sb = new StringBuilder();
sb.append("KeyFile");
sb.append("{s3Key='").append(s3Key).append('\'');
sb.append(", etag='").append(etag).append('\'');
sb.append(", authorizedKeys=").append(authorizedKeys);
sb.append('}');
return sb.toString();
}
}
private static class S3ObjectListing
implements Iterable<S3ObjectSummary>
{
private final AmazonS3 s3Client;
private final ListObjectsRequest listObjectsRequest;
public S3ObjectListing(AmazonS3 s3Client, ListObjectsRequest listObjectsRequest)
{
this.s3Client = s3Client;
this.listObjectsRequest = listObjectsRequest;
}
@Override
public Iterator<S3ObjectSummary> iterator()
{
Iterator<ObjectListing> objectListings = new AbstractSequentialIterator<ObjectListing>(s3Client.listObjects(listObjectsRequest))
{
@Override
protected ObjectListing computeNext(ObjectListing previous)
{
if (!previous.isTruncated()) {
return null;
}
return s3Client.listNextBatchOfObjects(previous);
}
};
return Iterators.concat(Iterators.transform(objectListings, new Function<ObjectListing, Iterator<S3ObjectSummary>>()
{
@Override
public Iterator<S3ObjectSummary> apply(ObjectListing input)
{
return input.getObjectSummaries().iterator();
}
}));
}
}
private static class S3InputSupplier
extends CharSource
{
private final AmazonS3 s3Client;
private final String bucketName;
private final String key;
public S3InputSupplier(AmazonS3 s3Client, String bucketName, String key)
{
this.s3Client = s3Client;
this.bucketName = bucketName;
this.key = key;
}
@Override
public Reader openStream()
throws IOException
{
S3Object object = s3Client.getObject(bucketName, key);
return new InputStreamReader(object.getObjectContent(), StandardCharsets.UTF_8);
}
}
}