/*
* Copyright 2016 The Simple File Server Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sfs.encryption;
import com.google.common.base.Optional;
import io.vertx.core.logging.Logger;
import org.sfs.Server;
import org.sfs.VertxContext;
import org.sfs.elasticsearch.containerkey.GetNewestContainerKey;
import org.sfs.elasticsearch.containerkey.ListReEncryptableContainerKeys;
import org.sfs.elasticsearch.containerkey.LoadContainerKey;
import org.sfs.elasticsearch.containerkey.PersistContainerKey;
import org.sfs.elasticsearch.containerkey.UpdateContainerKey;
import org.sfs.rx.Holder2;
import org.sfs.rx.Holder3;
import org.sfs.rx.ToType;
import org.sfs.rx.ToVoid;
import org.sfs.vo.ObjectPath;
import org.sfs.vo.PersistentContainer;
import org.sfs.vo.PersistentContainerKey;
import org.sfs.vo.TransientContainerKey;
import org.sfs.vo.TransientServiceDef;
import rx.Observable;
import java.util.Calendar;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.padStart;
import static com.google.common.math.LongMath.checkedAdd;
import static io.vertx.core.logging.LoggerFactory.getLogger;
import static java.lang.Long.parseLong;
import static java.lang.String.valueOf;
import static java.lang.System.currentTimeMillis;
import static java.util.Arrays.fill;
import static java.util.Calendar.getInstance;
import static java.util.concurrent.TimeUnit.DAYS;
import static org.sfs.encryption.AlgorithmDef.getPreferred;
import static org.sfs.rx.Defer.aVoid;
import static org.sfs.rx.Defer.just;
import static org.sfs.vo.ObjectPath.fromPaths;
import static rx.Observable.defer;
import static rx.Observable.using;
public class ContainerKeys {
private static final Logger LOGGER = getLogger(ContainerKeys.class);
private static final int DEFAULT_PAD = 19;
private static final long DEFAULT_RE_ENCRYPT_AGE = DAYS.toMillis(30);
private static final long DEFAULT_ROTATE_AGE = DAYS.toMillis(30);
private AtomicBoolean closed = new AtomicBoolean(true);
private VertxContext<Server> startedVertxContext;
private Set<Long> timerIds = new ConcurrentSkipListSet<>();
public Observable<Void> start(VertxContext<Server> vertxContext) {
return aVoid()
.filter(aVoid -> closed.compareAndSet(true, false))
.map(aVoid -> {
startedVertxContext = vertxContext;
return (Void) null;
})
.singleOrDefault(null);
}
public Observable<Void> stop(VertxContext<Server> vertxContext) {
return aVoid()
.filter(aVoid -> closed.compareAndSet(false, true))
.map(aVoid -> {
Iterator<Long> i = timerIds.iterator();
while (i.hasNext()) {
long timerId = i.next();
i.remove();
if (startedVertxContext != null) {
startedVertxContext.vertx().cancelTimer(timerId);
}
}
return (Void) null;
})
.singleOrDefault(null);
}
public Observable<KeyResponse> algorithm(VertxContext<Server> vertxContext, PersistentContainer persistentContainer, String keyId, byte[] salt) {
return defer(() -> {
checkOpen();
return just(new Holder2<>(persistentContainer, keyId))
.flatMap(new LoadContainerKey(vertxContext))
.map(holder -> {
Optional<PersistentContainerKey> oPersistentContainerKey = holder.value2();
checkState(oPersistentContainerKey.isPresent(), "ContainerKey %s not found", keyId);
return oPersistentContainerKey.get();
})
.flatMap(persistentContainerKey -> {
MasterKeys masterKeys = vertxContext.verticle().masterKeys();
return masterKeys.decrypt(vertxContext,
new MasterKeys.Encrypted(
persistentContainerKey.getKeyStoreKeyId().get(),
persistentContainerKey.getCipherSalt().get(),
persistentContainerKey.getEncryptedKey().get()))
.map(Optional::get)
.map(clearContainerKey -> {
try {
AlgorithmDef algorithmDef = persistentContainerKey.getAlgorithmDef().get();
Algorithm algorithm = algorithmDef.create(clearContainerKey, salt);
return new KeyResponse(persistentContainerKey.getId(), salt, algorithm);
} finally {
fill(clearContainerKey, (byte) 0);
}
});
});
});
}
public Observable<KeyResponse> preferredAlgorithm(VertxContext<Server> vertxContext, PersistentContainer persistentContainer) {
return defer(() -> {
checkOpen();
return just(persistentContainer)
.flatMap(new GetNewestContainerKey(vertxContext))
.flatMap(persistentContainerKeyOptional -> {
if (persistentContainerKeyOptional.isPresent()) {
return rotateIfRequired(vertxContext, persistentContainerKeyOptional.get());
} else {
return newIfAbsent(vertxContext, persistentContainer);
}
})
.flatMap(persistentContainerKey -> {
MasterKeys masterKeys = vertxContext.verticle().masterKeys();
return masterKeys.decrypt(vertxContext,
new MasterKeys.Encrypted(
persistentContainerKey.getKeyStoreKeyId().get(),
persistentContainerKey.getCipherSalt().get(),
persistentContainerKey.getEncryptedKey().get()))
.map(Optional::get)
.flatMap(clearContainerKey ->
Observable.using(
() -> null,
aVoid -> {
AlgorithmDef algorithmDef = persistentContainerKey.getAlgorithmDef().get();
return algorithmDef.generateSalt(vertxContext.vertx())
.map(salt -> {
Algorithm algorithm = algorithmDef.create(clearContainerKey, salt);
return new KeyResponse(persistentContainerKey.getId(), salt, algorithm);
});
},
resource -> fill(clearContainerKey, (byte) 0)));
});
});
}
public Observable<Void> maintain(VertxContext<Server> vertxContext) {
return defer(() -> {
if (vertxContext.verticle().nodes().isDataNode()) {
Calendar threshold = getInstance();
threshold.setTimeInMillis(currentTimeMillis() - DEFAULT_RE_ENCRYPT_AGE);
return aVoid()
.flatMap(new ListReEncryptableContainerKeys(vertxContext, threshold))
.flatMap(pck -> reEncrypt(vertxContext, pck))
.count()
.map(new ToVoid<>())
.singleOrDefault(null);
} else {
return aVoid();
}
});
}
protected Observable<Void> reEncrypt(VertxContext<Server> vertxContext, PersistentContainerKey persistentContainerKey) {
return defer(() -> {
boolean isDebugEnabled = LOGGER.isDebugEnabled();
if (isDebugEnabled) {
LOGGER.debug("Starting reEncrypt of key " + persistentContainerKey.getId());
}
MasterKeys masterKeys = vertxContext.verticle().masterKeys();
return masterKeys.decrypt(vertxContext,
new MasterKeys.Encrypted(
persistentContainerKey.getKeyStoreKeyId().get(),
persistentContainerKey.getCipherSalt().get(),
persistentContainerKey.getEncryptedKey().get()))
.map(Optional::get)
.flatMap(clearContainerKey ->
using(
() -> clearContainerKey,
bytes -> masterKeys.encrypt(vertxContext, bytes),
bytes -> fill(bytes, (byte) 0))
.map(encrypted -> persistentContainerKey
.setCipherSalt(encrypted.getSalt())
.setEncryptedKey(encrypted.getData())
.setKeyStoreKeyId(encrypted.getKeyId())
.setReEncryptTs(getInstance())
.setUpdateTs(getInstance())))
.flatMap(new UpdateContainerKey(vertxContext))
.onErrorResumeNext(throwable -> {
LOGGER.warn("Failed to reEncrypt key " + persistentContainerKey.getId(), throwable);
return just(null);
})
.map(new ToType<>((Void) null))
.map(aVoid -> {
if (isDebugEnabled) {
LOGGER.debug("Finished reEncrypt of key " + persistentContainerKey.getId());
}
return (Void) null;
});
});
}
protected Observable<PersistentContainerKey> rotateIfRequired(VertxContext<Server> vertxContext, PersistentContainerKey existingPersistentContainerKey) {
return defer(() -> {
Calendar createTs = existingPersistentContainerKey.getCreateTs();
AlgorithmDef preferredAlgorithmDef = getPreferred();
boolean shouldRotate = shouldRotate(createTs, existingPersistentContainerKey.getAlgorithmDef().get(), preferredAlgorithmDef);
if (shouldRotate) {
boolean isDebugEnabled = LOGGER.isDebugEnabled();
if (isDebugEnabled) {
LOGGER.debug("Starting Rotate of key " + existingPersistentContainerKey.getId());
}
PersistentContainer persistentContainer = existingPersistentContainerKey.getPersistentContainer();
ObjectPath objectPath = fromPaths(existingPersistentContainerKey.getId());
String existingKey = objectPath.objectName().get();
ObjectPath id =
fromPaths(
persistentContainer.getId(),
nextKey(existingKey));
MasterKeys masterKeys = vertxContext.verticle().masterKeys();
return preferredAlgorithmDef.generateKey(vertxContext.vertx())
.flatMap(clearContainerSecret ->
using(
() -> clearContainerSecret,
bytes -> masterKeys.encrypt(vertxContext, bytes),
bytes -> fill(bytes, (byte) 0))
.map(encrypted -> {
TransientContainerKey containerKey =
new TransientContainerKey(persistentContainer, id);
containerKey.setAlgorithmDef(preferredAlgorithmDef)
.setCipherSalt(encrypted.getSalt())
.setEncryptedKey(encrypted.getData())
.setKeyStoreKeyId(encrypted.getKeyId())
.setReEncryptTs(getInstance())
.setCreateTs(getInstance())
.setUpdateTs(getInstance());
return containerKey;
})
.doOnNext(transientContainerKey -> {
Optional<TransientServiceDef> currentMaintainerNode = vertxContext.verticle().getClusterInfo().getCurrentMaintainerNode();
if (currentMaintainerNode.isPresent()) {
transientContainerKey.setNodeId(currentMaintainerNode.get().getId());
}
})
.flatMap(new PersistContainerKey(vertxContext))
.map(Holder2::value1)
.map(newPersistentContainerKeyOptional -> {
// if this failed to persist another thread
// rotated the key so return the original
// since next time the key persisted by another
// thread will be used
if (newPersistentContainerKeyOptional.isPresent()) {
if (isDebugEnabled) {
LOGGER.debug("Finished Rotate of key " + existingPersistentContainerKey.getId() + ". New key is " + newPersistentContainerKeyOptional.get().getId());
}
return newPersistentContainerKeyOptional.get();
} else {
if (isDebugEnabled) {
LOGGER.debug("Finished Rotate of key " + existingPersistentContainerKey.getId() + ". Another thread completed the rotation");
}
return existingPersistentContainerKey;
}
}));
} else {
return just(existingPersistentContainerKey);
}
});
}
protected Observable<PersistentContainerKey> newIfAbsent(VertxContext<Server> vertxContext, PersistentContainer persistentContainer) {
return defer(() -> {
AlgorithmDef preferredAlgorithmDef = getPreferred();
boolean isDebugEnabled = LOGGER.isDebugEnabled();
if (isDebugEnabled) {
LOGGER.debug("Starting Create of new key");
}
ObjectPath id =
fromPaths(
persistentContainer.getId(),
firstKey());
MasterKeys masterKeys = vertxContext.verticle().masterKeys();
return preferredAlgorithmDef.generateKey(vertxContext.vertx())
.flatMap(clearContainerSecret ->
using(
() -> clearContainerSecret,
bytes -> masterKeys.encrypt(vertxContext, bytes),
bytes -> fill(bytes, (byte) 0))
.map(encrypted -> {
TransientContainerKey containerKey =
new TransientContainerKey(persistentContainer, id);
containerKey.setAlgorithmDef(preferredAlgorithmDef)
.setCipherSalt(encrypted.getSalt())
.setEncryptedKey(encrypted.getData())
.setKeyStoreKeyId(encrypted.getKeyId())
.setReEncryptTs(getInstance())
.setCreateTs(getInstance())
.setUpdateTs(getInstance());
return containerKey;
})
.doOnNext(transientContainerKey -> {
Optional<TransientServiceDef> currentMaintainerNode = vertxContext.verticle().getClusterInfo().getCurrentMaintainerNode();
if (currentMaintainerNode.isPresent()) {
transientContainerKey.setNodeId(currentMaintainerNode.get().getId());
}
})
.flatMap(new PersistContainerKey(vertxContext))
.map(Holder2::value1)
.flatMap(newPersistentContainerKey -> {
if (newPersistentContainerKey.isPresent()) {
return just(newPersistentContainerKey.get());
} else {
return just(new Holder2<>(persistentContainer, id.objectPath().get()))
.flatMap(new LoadContainerKey(vertxContext))
.map(Holder3::value2)
.map(Optional::get);
}
})
.map(newPersistentMasterKey -> {
if (isDebugEnabled) {
LOGGER.debug("Finished Create of key " + newPersistentMasterKey.getId());
}
return newPersistentMasterKey;
}));
});
}
protected boolean shouldRotate(Calendar createTs, AlgorithmDef currentAlgorithmDef, AlgorithmDef preferredAlgorithmDef) {
return createTs.getTimeInMillis() <= currentTimeMillis() - DEFAULT_ROTATE_AGE
|| !currentAlgorithmDef.equals(preferredAlgorithmDef);
}
protected void checkOpen() {
checkState(!closed.get(), "Already close");
}
protected String nextKey(String value) {
return pad(valueOf(checkedAdd(parseLong(value), 1)));
}
protected String firstKey() {
return pad("0");
}
protected String pad(String unpadded) {
return padStart(unpadded, DEFAULT_PAD, '0');
}
public static class KeyResponse {
private final String keyId;
private final byte[] salt;
private final Algorithm data;
public KeyResponse(String keyId, byte[] salt, Algorithm data) {
this.data = data;
this.keyId = keyId;
this.salt = salt;
}
public Algorithm getData() {
return data;
}
public String getKeyId() {
return keyId;
}
public byte[] getSalt() {
return salt;
}
}
}