/* * Copyright 2017 Red Hat, Inc. and/or its affiliates * and other contributors as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.keycloak.models.sessions.infinispan; import org.keycloak.cluster.ClusterEvent; import org.keycloak.cluster.ClusterProvider; import org.infinispan.context.Flag; import org.keycloak.models.KeycloakTransaction; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.TimeUnit; import org.infinispan.Cache; import org.jboss.logging.Logger; /** * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> */ public class InfinispanKeycloakTransaction implements KeycloakTransaction { private final static Logger log = Logger.getLogger(InfinispanKeycloakTransaction.class); public enum CacheOperation { ADD, ADD_WITH_LIFESPAN, REMOVE, REPLACE, ADD_IF_ABSENT // ADD_IF_ABSENT throws an exception if there is existing value } private boolean active; private boolean rollback; private final Map<Object, CacheTask> tasks = new LinkedHashMap<>(); @Override public void begin() { active = true; } @Override public void commit() { if (rollback) { throw new RuntimeException("Rollback only!"); } tasks.values().forEach(CacheTask::execute); } @Override public void rollback() { tasks.clear(); } @Override public void setRollbackOnly() { rollback = true; } @Override public boolean getRollbackOnly() { return rollback; } @Override public boolean isActive() { return active; } public <K, V> void put(Cache<K, V> cache, K key, V value) { log.tracev("Adding cache operation: {0} on {1}", CacheOperation.ADD, key); Object taskKey = getTaskKey(cache, key); if (tasks.containsKey(taskKey)) { throw new IllegalStateException("Can't add session: task in progress for session"); } else { tasks.put(taskKey, new CacheTaskWithValue<V>(value) { @Override public void execute() { decorateCache(cache).put(key, value); } }); } } public <K, V> void put(Cache<K, V> cache, K key, V value, long lifespan, TimeUnit lifespanUnit) { log.tracev("Adding cache operation: {0} on {1}", CacheOperation.ADD_WITH_LIFESPAN, key); Object taskKey = getTaskKey(cache, key); if (tasks.containsKey(taskKey)) { throw new IllegalStateException("Can't add session: task in progress for session"); } else { tasks.put(taskKey, new CacheTaskWithValue<V>(value) { @Override public void execute() { decorateCache(cache).put(key, value, lifespan, lifespanUnit); } }); } } public <K, V> void putIfAbsent(Cache<K, V> cache, K key, V value) { log.tracev("Adding cache operation: {0} on {1}", CacheOperation.ADD_IF_ABSENT, key); Object taskKey = getTaskKey(cache, key); if (tasks.containsKey(taskKey)) { throw new IllegalStateException("Can't add session: task in progress for session"); } else { tasks.put(taskKey, new CacheTaskWithValue<V>(value) { @Override public void execute() { V existing = cache.putIfAbsent(key, value); if (existing != null) { throw new IllegalStateException("There is already existing value in cache for key " + key); } } }); } } public <K, V> void replace(Cache<K, V> cache, K key, V value) { log.tracev("Adding cache operation: {0} on {1}", CacheOperation.REPLACE, key); Object taskKey = getTaskKey(cache, key); CacheTask current = tasks.get(taskKey); if (current != null) { if (current instanceof CacheTaskWithValue) { ((CacheTaskWithValue<V>) current).setValue(value); } } else { tasks.put(taskKey, new CacheTaskWithValue<V>(value) { @Override public void execute() { decorateCache(cache).replace(key, value); } }); } } public <K, V> void notify(ClusterProvider clusterProvider, String taskKey, ClusterEvent event, boolean ignoreSender) { log.tracev("Adding cache operation SEND_EVENT: {0}", event); String theTaskKey = taskKey; int i = 1; while (tasks.containsKey(theTaskKey)) { theTaskKey = taskKey + "-" + (i++); } tasks.put(taskKey, () -> clusterProvider.notify(taskKey, event, ignoreSender)); } public <K, V> void remove(Cache<K, V> cache, K key) { log.tracev("Adding cache operation: {0} on {1}", CacheOperation.REMOVE, key); Object taskKey = getTaskKey(cache, key); tasks.put(taskKey, () -> decorateCache(cache).remove(key)); } // This is for possibility to lookup for session by id, which was created in this transaction public <K, V> V get(Cache<K, V> cache, K key) { Object taskKey = getTaskKey(cache, key); CacheTask<V> current = tasks.get(taskKey); if (current != null) { if (current instanceof CacheTaskWithValue) { return ((CacheTaskWithValue<V>) current).getValue(); } return null; } // Should we have per-transaction cache for lookups? return cache.get(key); } private static <K, V> Object getTaskKey(Cache<K, V> cache, K key) { if (key instanceof String) { return new StringBuilder(cache.getName()) .append("::") .append(key).toString(); } else { return key; } } public interface CacheTask<V> { void execute(); } public abstract class CacheTaskWithValue<V> implements CacheTask<V> { protected V value; public CacheTaskWithValue(V value) { this.value = value; } public V getValue() { return value; } public void setValue(V value) { this.value = value; } } // Ignore return values. Should have better performance within cluster / cross-dc env private static <K, V> Cache<K, V> decorateCache(Cache<K, V> cache) { return cache.getAdvancedCache() .withFlags(Flag.IGNORE_RETURN_VALUES, Flag.SKIP_REMOTE_LOOKUP); } }