/*
* Copyright 2016 The Simple File Server Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sfs.encryption;
import com.google.protobuf.InvalidProtocolBufferException;
import com.microsoft.aad.adal4j.AuthenticationContext;
import com.microsoft.aad.adal4j.AuthenticationResult;
import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.azure.keyvault.KeyVaultClient;
import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials;
import com.microsoft.azure.keyvault.models.KeyOperationResult;
import com.microsoft.windowsazure.Configuration;
import com.microsoft.windowsazure.core.pipeline.filter.ServiceRequestContext;
import com.microsoft.windowsazure.credentials.CloudCredentials;
import io.vertx.core.Context;
import io.vertx.core.json.JsonObject;
import io.vertx.core.logging.Logger;
import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
import org.sfs.Server;
import org.sfs.SfsVertx;
import org.sfs.VertxContext;
import org.sfs.rx.RxHelper;
import org.sfs.util.ConfigHelper;
import rx.Observable;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.protobuf.ByteString.copyFrom;
import static com.microsoft.azure.keyvault.KeyVaultClientService.create;
import static com.microsoft.azure.keyvault.KeyVaultConfiguration.configure;
import static com.microsoft.azure.keyvault.extensions.cryptography.algorithms.RsaOaep.AlgorithmName;
import static io.vertx.core.logging.LoggerFactory.getLogger;
import static java.lang.String.format;
import static java.lang.System.getProperty;
import static java.lang.System.getenv;
import static java.util.Arrays.fill;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.sfs.protobuf.AzureProtoBuff.CipherText;
import static org.sfs.protobuf.AzureProtoBuff.CipherText.newBuilder;
import static org.sfs.protobuf.AzureProtoBuff.CipherText.parseFrom;
import static org.sfs.rx.Defer.aVoid;
import static rx.Observable.defer;
import static rx.Observable.using;
/**
 * {@link Kms} implementation that delegates encryption and decryption to a
 * key stored in Azure Key Vault.
 *
 * <p>Configuration is read through {@link ConfigHelper#getFieldOrEnv} using the keys
 * {@code keystore.azure.kms.endpoint}, {@code keystore.azure.kms.key_id},
 * {@code keystore.azure.kms.access_key_id} and {@code keystore.azure.kms.secret_key}.
 *
 * <p>Cipher text produced by {@link #encrypt} is a serialized {@code CipherText}
 * protobuf message carrying the algorithm name, the key identifier returned by the
 * vault and the encrypted data; {@link #decrypt} reads those fields back to select
 * the key and algorithm.
 */
public class AzureKms implements Kms {

    private static final Logger LOGGER = getLogger(AzureKms.class);

    /** Upper bound, in seconds, for any single blocking wait on a remote call. */
    private static final long OPERATION_TIMEOUT_SECONDS = 60;

    private Properties properties;
    private KeyVaultClient kms;
    private String endpoint;
    private String keyId;
    private String accessKeyId;
    private String secretKey;
    private String azureKeyIdentifier;
    private final AtomicBoolean started = new AtomicBoolean(false);
    private ExecutorService executorService;

    public AzureKms() {
    }

    /**
     * Reads and validates the configuration and creates the Key Vault client on the
     * background pool. Does nothing when this instance is already started.
     *
     * @param vertxContext the server context supplying the vertx instance and background pool
     * @param config       the configuration to read the {@code keystore.azure.kms.*} values from
     * @return an observable that emits a single {@code null} once start-up has finished, or
     *         fails with {@link IllegalArgumentException} when a required value is missing
     */
    public Observable<Void> start(VertxContext<Server> vertxContext,
                                  JsonObject config) {
        SfsVertx sfsVertx = vertxContext.vertx();
        Context context = sfsVertx.getOrCreateContext();
        return aVoid()
                .filter(aVoid -> started.compareAndSet(false, true))
                .flatMap(aVoid -> {
                    endpoint = ConfigHelper.getFieldOrEnv(config, "keystore.azure.kms.endpoint");
                    checkArgument(endpoint != null, "keystore.azure.kms.endpoint is required");
                    keyId = ConfigHelper.getFieldOrEnv(config, "keystore.azure.kms.key_id");
                    checkArgument(keyId != null, "keystore.azure.kms.key_id is required");
                    accessKeyId = ConfigHelper.getFieldOrEnv(config, "keystore.azure.kms.access_key_id");
                    checkArgument(accessKeyId != null, "keystore.azure.kms.access_key_id is required");
                    secretKey = ConfigHelper.getFieldOrEnv(config, "keystore.azure.kms.secret_key");
                    checkArgument(secretKey != null, "keystore.azure.kms.secret_key is required");
                    azureKeyIdentifier = format("%s/keys/%s", endpoint, keyId);
                    // Allocate the pool only after validation so a bad configuration
                    // does not leave an executor behind.
                    executorService = newCachedThreadPool();
                    return RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
                        try {
                            kms = createKeyVaultClient(vertxContext);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                        return (Void) null;
                    });
                })
                .singleOrDefault(null);
    }

    /**
     * Builds the Azure client configuration using the credentials held by this instance.
     *
     * @param vertxContext the server context, passed through to the credential callback
     * @return the configuration used to create the Key Vault client
     * @throws Exception if the configuration cannot be created
     */
    public Configuration createConfiguration(VertxContext<Server> vertxContext) throws Exception {
        return configure(null, createCredentials(vertxContext));
    }

    /**
     * Creates the Key Vault client. Protected so that subclasses can supply a different client.
     *
     * @param vertxContext the server context
     * @return a new Key Vault client
     * @throws Exception if the client cannot be created
     */
    protected KeyVaultClient createKeyVaultClient(VertxContext<Server> vertxContext) throws Exception {
        Configuration config = createConfiguration(vertxContext);
        return create(config);
    }

    /**
     * Creates credentials that answer an authentication challenge by acquiring an access
     * token for the configured access key id / secret key and returning it as an
     * {@code Authorization} header.
     */
    private CloudCredentials createCredentials(VertxContext<Server> vertxContext) throws Exception {
        return new KeyVaultCredentials() {
            @Override
            public Header doAuthenticate(ServiceRequestContext request, Map<String, String> challenge) {
                try {
                    String authorization = challenge.get("authorization");
                    String resource = challenge.get("resource");
                    AuthenticationResult authResult = getAccessToken(vertxContext, accessKeyId, secretKey, authorization, resource);
                    return new BasicHeader("Authorization", authResult.getAccessTokenType() + " " + authResult.getAccessToken());
                } catch (Exception ex) {
                    throw new RuntimeException(ex);
                }
            }
        };
    }

    /**
     * Acquires an access token for {@code resource} from the given authority, waiting at
     * most {@link #OPERATION_TIMEOUT_SECONDS} seconds for the result.
     *
     * @throws Exception if the token cannot be acquired in time, the wait is interrupted,
     *                   or the result is {@code null}
     */
    private AuthenticationResult getAccessToken(VertxContext<Server> vertxContext, String clientId, String clientKey, String authorization, String resource) throws Exception {
        AuthenticationContext context = new AuthenticationContext(authorization, false, executorService);
        ClientCredential credentials = new ClientCredential(clientId, clientKey);
        Future<AuthenticationResult> pending = context.acquireToken(resource, credentials, null);
        AuthenticationResult result;
        try {
            // Bounded like the key operations so an unreachable authority cannot
            // park the calling thread forever.
            result = pending.get(OPERATION_TIMEOUT_SECONDS, SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw e;
        }
        checkNotNull(result, "AuthenticationResult was null");
        return result;
    }

    /**
     * @return the configured key id ({@code keystore.azure.kms.key_id}), or {@code null}
     *         before {@link #start} has read it
     */
    public String getKeyId() {
        return keyId;
    }

    /**
     * Encrypts {@code plainBytes} with the configured vault key on the background pool.
     *
     * @param vertxContext the server context supplying the vertx instance and background pool
     * @param plainBytes   the bytes to encrypt
     * @return an observable emitting the serialized {@code CipherText} together with a key
     *         id of the form {@code xppsazure:<key identifier>}; fails with a
     *         {@link RuntimeException} wrapping the cause when the remote call fails,
     *         times out or is interrupted
     */
    @Override
    public Observable<Encrypted> encrypt(VertxContext<Server> vertxContext, byte[] plainBytes) {
        SfsVertx sfsVertx = vertxContext.vertx();
        Context context = sfsVertx.getOrCreateContext();
        return defer(() -> RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
            String algorithm = AlgorithmName;
            Future<KeyOperationResult> encrypted = kms.encryptAsync(azureKeyIdentifier, algorithm, plainBytes);
            try {
                KeyOperationResult result = encrypted.get(OPERATION_TIMEOUT_SECONDS, SECONDS);
                CipherText instance =
                        newBuilder()
                                .setAlgorithm(algorithm)
                                .setKeyIdentifier(result.getKid())
                                .setData(copyFrom(result.getResult()))
                                .build();
                return new Encrypted(instance.toByteArray(), format("xppsazure:%s", azureKeyIdentifier));
            } catch (InterruptedException e) {
                // Restore the flag so the pool thread's owner can observe the interrupt.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException | TimeoutException e) {
                throw new RuntimeException(e);
            }
        }));
    }

    /**
     * Decrypts {@code cipherBytes} and encrypts the result again with the currently
     * configured key. The intermediate clear bytes are zeroed when the encryption
     * observable is released.
     *
     * @param vertxContext the server context
     * @param cipherBytes  a serialized {@code CipherText} message
     * @return an observable emitting the newly encrypted value
     */
    @Override
    public Observable<Encrypted> reencrypt(VertxContext<Server> vertxContext, byte[] cipherBytes) {
        return decrypt(vertxContext, cipherBytes)
                .flatMap(clearBytes -> using(
                        () -> clearBytes,
                        bytes -> encrypt(vertxContext, bytes),
                        bytes -> fill(bytes, (byte) 0)));
    }

    /**
     * Decrypts a serialized {@code CipherText} message on the background pool, using the
     * key identifier and algorithm recorded inside the message.
     *
     * @param vertxContext the server context supplying the vertx instance and background pool
     * @param cipherBytes  a serialized {@code CipherText} message
     * @return an observable emitting the clear bytes; fails with a {@link RuntimeException}
     *         wrapping the cause when the message cannot be parsed or the remote call
     *         fails, times out or is interrupted
     */
    @Override
    public Observable<byte[]> decrypt(VertxContext<Server> vertxContext, byte[] cipherBytes) {
        SfsVertx sfsVertx = vertxContext.vertx();
        Context context = sfsVertx.getOrCreateContext();
        return defer(() -> RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
            try {
                CipherText instance = parseFrom(cipherBytes.clone());
                String keyIdentifier = instance.getKeyIdentifier();
                String algorithm = instance.getAlgorithm();
                byte[] data = instance.getData().toByteArray();
                Future<KeyOperationResult> future = kms.decryptAsync(keyIdentifier, algorithm, data);
                KeyOperationResult result = future.get(OPERATION_TIMEOUT_SECONDS, SECONDS);
                return result.getResult();
            } catch (InterruptedException e) {
                // Restore the flag so the pool thread's owner can observe the interrupt.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (InvalidProtocolBufferException | ExecutionException | TimeoutException e) {
                throw new RuntimeException(e);
            }
        }));
    }

    /**
     * Closes the Key Vault client on the background pool and then shuts down the executor
     * used for token acquisition. Does nothing when this instance is not started. Failures
     * while closing are logged and do not fail the returned observable.
     *
     * @param vertxContext the server context supplying the vertx instance and background pool
     * @return an observable that emits a single {@code null} once shutdown has finished
     */
    public Observable<Void> stop(VertxContext<Server> vertxContext) {
        SfsVertx sfsVertx = vertxContext.vertx();
        Context context = sfsVertx.getOrCreateContext();
        return aVoid()
                .filter(aVoid -> started.compareAndSet(true, false))
                .flatMap(aVoid -> {
                    if (properties != null) {
                        properties.clear();
                        properties = null;
                    }
                    Observable<Void> closed;
                    KeyVaultClient client = kms;
                    if (client != null) {
                        // Drop the reference so a closed client is never reused or closed twice.
                        kms = null;
                        closed = RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
                            try {
                                client.close();
                            } catch (Throwable e) {
                                LOGGER.warn("Unhandled Exception", e);
                            }
                            return (Void) null;
                        });
                    } else {
                        closed = aVoid();
                    }
                    // The executor must be released whether or not a client existed, and only
                    // after the client close has terminated.
                    return closed.doOnTerminate(this::shutdownExecutor);
                })
                .singleOrDefault(null);
    }

    /** Shuts down the token-acquisition executor, if any, logging rather than propagating failures. */
    private void shutdownExecutor() {
        ExecutorService executor = executorService;
        if (executor != null) {
            try {
                executor.shutdown();
            } catch (Throwable e) {
                LOGGER.warn("Unhandled Exception", e);
            } finally {
                executorService = null;
            }
        }
    }
}