/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.keys.infinispan;
import java.security.PublicKey;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import org.infinispan.Cache;
import org.jboss.logging.Logger;
import org.keycloak.cluster.ClusterProvider;
import org.keycloak.common.util.Time;
import org.keycloak.keys.PublicKeyLoader;
import org.keycloak.keys.PublicKeyStorageProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.cache.infinispan.ClearCacheEvent;
import org.keycloak.models.cache.infinispan.InfinispanCacheRealmProviderFactory;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvider {
private static final Logger log = Logger.getLogger(InfinispanPublicKeyStorageProvider.class);
private final KeycloakSession session;
private final Cache<String, PublicKeysEntry> keys;
private final Map<String, FutureTask<PublicKeysEntry>> tasksInProgress;
private final int minTimeBetweenRequests ;
private Set<String> invalidations = new HashSet<>();
private boolean transactionEnlisted = false;
public InfinispanPublicKeyStorageProvider(KeycloakSession session, Cache<String, PublicKeysEntry> keys, Map<String, FutureTask<PublicKeysEntry>> tasksInProgress, int minTimeBetweenRequests) {
this.session = session;
this.keys = keys;
this.tasksInProgress = tasksInProgress;
this.minTimeBetweenRequests = minTimeBetweenRequests;
}
@Override
public void clearCache() {
keys.clear();
ClusterProvider cluster = session.getProvider(ClusterProvider.class);
cluster.notify(InfinispanPublicKeyStorageProviderFactory.KEYS_CLEAR_CACHE_EVENTS, new ClearCacheEvent(), true);
}
void addInvalidation(String cacheKey) {
if (!transactionEnlisted) {
session.getTransactionManager().enlistAfterCompletion(getAfterTransaction());
transactionEnlisted = true;
}
this.invalidations.add(cacheKey);
}
protected KeycloakTransaction getAfterTransaction() {
return new KeycloakTransaction() {
@Override
public void begin() {
}
@Override
public void commit() {
runInvalidations();
}
@Override
public void rollback() {
runInvalidations();
}
@Override
public void setRollbackOnly() {
}
@Override
public boolean getRollbackOnly() {
return false;
}
@Override
public boolean isActive() {
return true;
}
};
}
protected void runInvalidations() {
ClusterProvider cluster = session.getProvider(ClusterProvider.class);
for (String cacheKey : invalidations) {
keys.remove(cacheKey);
cluster.notify(cacheKey, PublicKeyStorageInvalidationEvent.create(cacheKey), true);
}
}
@Override
public PublicKey getPublicKey(String modelKey, String kid, PublicKeyLoader loader) {
// Check if key is in cache
PublicKeysEntry entry = keys.get(modelKey);
if (entry != null) {
PublicKey publicKey = getPublicKey(entry.getCurrentKeys(), kid);
if (publicKey != null) {
return publicKey;
}
}
int lastRequestTime = entry==null ? 0 : entry.getLastRequestTime();
int currentTime = Time.currentTime();
// Check if we are allowed to send request
if (currentTime > lastRequestTime + minTimeBetweenRequests) {
WrapperCallable wrapperCallable = new WrapperCallable(modelKey, loader);
FutureTask<PublicKeysEntry> task = new FutureTask<>(wrapperCallable);
FutureTask<PublicKeysEntry> existing = tasksInProgress.putIfAbsent(modelKey, task);
if (existing == null) {
task.run();
} else {
task = existing;
}
try {
entry = task.get();
// Computation finished. Let's see if key is available
PublicKey publicKey = getPublicKey(entry.getCurrentKeys(), kid);
if (publicKey != null) {
return publicKey;
}
} catch (ExecutionException ee) {
throw new RuntimeException("Error when loading public keys", ee);
} catch (InterruptedException ie) {
throw new RuntimeException("Error. Interrupted when loading public keys", ie);
} finally {
// Our thread inserted the task. Let's clean
if (existing == null) {
tasksInProgress.remove(modelKey);
}
}
} else {
log.warnf("Won't load the keys for model '%s' . Last request time was %d", modelKey, lastRequestTime);
}
Set<String> availableKids = entry==null ? Collections.emptySet() : entry.getCurrentKeys().keySet();
log.warnf("PublicKey wasn't found in the storage. Requested kid: '%s' . Available kids: '%s'", kid, availableKids);
return null;
}
private PublicKey getPublicKey(Map<String, PublicKey> publicKeys, String kid) {
// Backwards compatibility
if (kid == null && !publicKeys.isEmpty()) {
return publicKeys.values().iterator().next();
} else {
return publicKeys.get(kid);
}
}
@Override
public void close() {
}
private class WrapperCallable implements Callable<PublicKeysEntry> {
private final String modelKey;
private final PublicKeyLoader delegate;
public WrapperCallable(String modelKey, PublicKeyLoader delegate) {
this.modelKey = modelKey;
this.delegate = delegate;
}
@Override
public PublicKeysEntry call() throws Exception {
PublicKeysEntry entry = keys.get(modelKey);
int lastRequestTime = entry==null ? 0 : entry.getLastRequestTime();
int currentTime = Time.currentTime();
// Check again if we are allowed to send request. There is a chance other task was already finished and removed from tasksInProgress in the meantime.
if (currentTime > lastRequestTime + minTimeBetweenRequests) {
Map<String, PublicKey> publicKeys = delegate.loadKeys();
if (log.isDebugEnabled()) {
log.debugf("Public keys retrieved successfully for model %s. New kids: %s", modelKey, publicKeys.keySet().toString());
}
entry = new PublicKeysEntry(currentTime, publicKeys);
keys.put(modelKey, entry);
}
return entry;
}
}
}