/* * Copyright 2016 Red Hat, Inc. and/or its affiliates * and other contributors as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.keycloak.models.sessions.infinispan; import org.infinispan.Cache; import org.infinispan.CacheStream; import org.infinispan.context.Flag; import org.jboss.logging.Logger; import org.keycloak.common.util.Time; import org.keycloak.models.ClientInitialAccessModel; import org.keycloak.models.AuthenticatedClientSessionModel; import org.keycloak.models.ClientModel; import org.keycloak.models.KeycloakSession; import org.keycloak.models.RealmModel; import org.keycloak.models.UserLoginFailureModel; import org.keycloak.models.UserModel; import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionProvider; import org.keycloak.models.session.UserSessionPersisterProvider; import org.keycloak.models.sessions.infinispan.entities.ClientInitialAccessEntity; import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity; import org.keycloak.models.sessions.infinispan.entities.LoginFailureEntity; import org.keycloak.models.sessions.infinispan.entities.LoginFailureKey; import org.keycloak.models.sessions.infinispan.entities.SessionEntity; import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity; import org.keycloak.models.sessions.infinispan.stream.ClientInitialAccessPredicate; import org.keycloak.models.sessions.infinispan.stream.Comparators; import org.keycloak.models.sessions.infinispan.stream.Mappers; import org.keycloak.models.sessions.infinispan.stream.SessionPredicate; import org.keycloak.models.sessions.infinispan.stream.UserLoginFailurePredicate; import org.keycloak.models.sessions.infinispan.stream.UserSessionPredicate; import org.keycloak.models.utils.KeycloakModelUtils; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; /** * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> */ public class InfinispanUserSessionProvider implements UserSessionProvider { private static final Logger log = Logger.getLogger(InfinispanUserSessionProvider.class); protected final KeycloakSession session; protected final Cache<String, SessionEntity> sessionCache; protected final Cache<String, SessionEntity> offlineSessionCache; protected final Cache<LoginFailureKey, LoginFailureEntity> loginFailureCache; protected final InfinispanKeycloakTransaction tx; public InfinispanUserSessionProvider(KeycloakSession session, Cache<String, SessionEntity> sessionCache, Cache<String, SessionEntity> offlineSessionCache, Cache<LoginFailureKey, LoginFailureEntity> loginFailureCache) { this.session = session; this.sessionCache = sessionCache; this.offlineSessionCache = offlineSessionCache; this.loginFailureCache = loginFailureCache; this.tx = new InfinispanKeycloakTransaction(); session.getTransactionManager().enlistAfterCompletion(tx); } protected Cache<String, SessionEntity> getCache(boolean offline) { return offline ? offlineSessionCache : sessionCache; } @Override public AuthenticatedClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession) { AuthenticatedClientSessionEntity entity = new AuthenticatedClientSessionEntity(); AuthenticatedClientSessionAdapter adapter = new AuthenticatedClientSessionAdapter(entity, client, (UserSessionAdapter) userSession, this, sessionCache); adapter.setUserSession(userSession); return adapter; } @Override public UserSessionModel createUserSession(String id, RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId) { UserSessionEntity entity = new UserSessionEntity(); entity.setId(id); updateSessionEntity(entity, realm, user, loginUsername, ipAddress, authMethod, rememberMe, brokerSessionId, brokerUserId); tx.putIfAbsent(sessionCache, id, entity); return wrap(realm, entity, false); } void updateSessionEntity(UserSessionEntity entity, RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId) { entity.setRealm(realm.getId()); entity.setUser(user.getId()); entity.setLoginUsername(loginUsername); entity.setIpAddress(ipAddress); entity.setAuthMethod(authMethod); entity.setRememberMe(rememberMe); entity.setBrokerSessionId(brokerSessionId); entity.setBrokerUserId(brokerUserId); int currentTime = Time.currentTime(); entity.setStarted(currentTime); entity.setLastSessionRefresh(currentTime); } @Override public UserSessionModel getUserSession(RealmModel realm, String id) { return getUserSession(realm, id, false); } protected UserSessionAdapter getUserSession(RealmModel realm, String id, boolean offline) { Cache<String, SessionEntity> cache = getCache(offline); UserSessionEntity entity = (UserSessionEntity) tx.get(cache, id); // Chance created in this transaction if (entity == null) { entity = (UserSessionEntity) cache.get(id); } return wrap(realm, entity, offline); } protected List<UserSessionModel> getUserSessions(RealmModel realm, Predicate<Map.Entry<String, SessionEntity>> predicate, boolean offline) { CacheStream<Map.Entry<String, SessionEntity>> cacheStream = getCache(offline).entrySet().stream(); Iterator<Map.Entry<String, SessionEntity>> itr = cacheStream.filter(predicate).iterator(); List<UserSessionModel> sessions = new LinkedList<>(); while (itr.hasNext()) { UserSessionEntity e = (UserSessionEntity) itr.next().getValue(); sessions.add(wrap(realm, e, offline)); } return sessions; } @Override public List<UserSessionModel> getUserSessions(final RealmModel realm, UserModel user) { return getUserSessions(realm, UserSessionPredicate.create(realm.getId()).user(user.getId()), false); } @Override public List<UserSessionModel> getUserSessionByBrokerUserId(RealmModel realm, String brokerUserId) { return getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerUserId(brokerUserId), false); } @Override public UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) { List<UserSessionModel> userSessions = getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerSessionId(brokerSessionId), false); return userSessions.isEmpty() ? null : userSessions.get(0); } @Override public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) { return getUserSessions(realm, client, -1, -1); } @Override public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) { return getUserSessions(realm, client, firstResult, maxResults, false); } protected List<UserSessionModel> getUserSessions(final RealmModel realm, ClientModel client, int firstResult, int maxResults, final boolean offline) { final Cache<String, SessionEntity> cache = getCache(offline); Stream<UserSessionEntity> stream = cache.entrySet().stream() .filter(UserSessionPredicate.create(realm.getId()).client(client.getId())) .map(Mappers.userSessionEntity()) .sorted(Comparators.userSessionLastSessionRefresh()); // Doesn't work due to ISPN-6575 . TODO Fix once infinispan upgraded to 8.2.2.Final or 9.0 // if (firstResult > 0) { // stream = stream.skip(firstResult); // } // // if (maxResults > 0) { // stream = stream.limit(maxResults); // } // // List<UserSessionEntity> entities = stream.collect(Collectors.toList()); // Workaround for ISPN-6575 TODO Fix once infinispan upgraded to 8.2.2.Final or 9.0 and replace with the more effective code above if (firstResult < 0) { firstResult = 0; } if (maxResults < 0) { maxResults = Integer.MAX_VALUE; } int count = firstResult + maxResults; if (count > 0) { stream = stream.limit(count); } List<UserSessionEntity> entities = stream.collect(Collectors.toList()); if (firstResult > entities.size()) { return Collections.emptyList(); } maxResults = Math.min(maxResults, entities.size() - firstResult); entities = entities.subList(firstResult, firstResult + maxResults); final List<UserSessionModel> sessions = new LinkedList<>(); entities.stream().forEach(new Consumer<UserSessionEntity>() { @Override public void accept(UserSessionEntity userSessionEntity) { sessions.add(wrap(realm, userSessionEntity, offline)); } }); return sessions; } @Override public long getActiveUserSessions(RealmModel realm, ClientModel client) { return getUserSessionsCount(realm, client, false); } protected long getUserSessionsCount(RealmModel realm, ClientModel client, boolean offline) { return getCache(offline).entrySet().stream() .filter(UserSessionPredicate.create(realm.getId()).client(client.getId())) .count(); } @Override public void removeUserSession(RealmModel realm, UserSessionModel session) { UserSessionEntity entity = getUserSessionEntity(session, false); if (entity != null) { removeUserSession(realm, entity, false); } } @Override public void removeUserSessions(RealmModel realm, UserModel user) { removeUserSessions(realm, user, false); } protected void removeUserSessions(RealmModel realm, UserModel user, boolean offline) { Cache<String, SessionEntity> cache = getCache(offline); Iterator<SessionEntity> itr = cache.entrySet().stream().filter(UserSessionPredicate.create(realm.getId()).user(user.getId())).map(Mappers.sessionEntity()).iterator(); while (itr.hasNext()) { UserSessionEntity userSessionEntity = (UserSessionEntity) itr.next(); removeUserSession(realm, userSessionEntity, offline); } } @Override public void removeExpired(RealmModel realm) { log.debugf("Removing expired sessions"); removeExpiredUserSessions(realm); removeExpiredOfflineUserSessions(realm); removeExpiredClientInitialAccess(realm); } private void removeExpiredUserSessions(RealmModel realm) { int expired = Time.currentTime() - realm.getSsoSessionMaxLifespan(); int expiredRefresh = Time.currentTime() - realm.getSsoSessionIdleTimeout(); // Each cluster node cleanups just local sessions, which are those owned by himself (+ few more taking l1 cache into account) Iterator<Map.Entry<String, SessionEntity>> itr = sessionCache.getAdvancedCache().withFlags(Flag.CACHE_MODE_LOCAL) .entrySet().stream().filter(UserSessionPredicate.create(realm.getId()).expired(expired, expiredRefresh)).iterator(); int counter = 0; while (itr.hasNext()) { counter++; UserSessionEntity entity = (UserSessionEntity) itr.next().getValue(); tx.remove(sessionCache, entity.getId()); } log.debugf("Removed %d expired user sessions for realm '%s'", counter, realm.getName()); } private void removeExpiredOfflineUserSessions(RealmModel realm) { UserSessionPersisterProvider persister = session.getProvider(UserSessionPersisterProvider.class); int expiredOffline = Time.currentTime() - realm.getOfflineSessionIdleTimeout(); // Each cluster node cleanups just local sessions, which are those owned by himself (+ few more taking l1 cache into account) UserSessionPredicate predicate = UserSessionPredicate.create(realm.getId()).expired(null, expiredOffline); Iterator<Map.Entry<String, SessionEntity>> itr = offlineSessionCache.getAdvancedCache().withFlags(Flag.CACHE_MODE_LOCAL) .entrySet().stream().filter(predicate).iterator(); int counter = 0; while (itr.hasNext()) { counter++; UserSessionEntity entity = (UserSessionEntity) itr.next().getValue(); tx.remove(offlineSessionCache, entity.getId()); persister.removeUserSession(entity.getId(), true); for (String clientUUID : entity.getAuthenticatedClientSessions().keySet()) { persister.removeClientSession(entity.getId(), clientUUID, true); } } log.debugf("Removed %d expired offline user sessions for realm '%s'", counter, realm.getName()); } private void removeExpiredClientInitialAccess(RealmModel realm) { Iterator<String> itr = sessionCache.getAdvancedCache().withFlags(Flag.CACHE_MODE_LOCAL) .entrySet().stream().filter(ClientInitialAccessPredicate.create(realm.getId()).expired(Time.currentTime())).map(Mappers.sessionId()).iterator(); while (itr.hasNext()) { tx.remove(sessionCache, itr.next()); } } @Override public void removeUserSessions(RealmModel realm) { removeUserSessions(realm, false); } protected void removeUserSessions(RealmModel realm, boolean offline) { Cache<String, SessionEntity> cache = getCache(offline); Iterator<String> itr = cache.entrySet().stream().filter(SessionPredicate.create(realm.getId())).map(Mappers.sessionId()).iterator(); while (itr.hasNext()) { cache.remove(itr.next()); } } @Override public UserLoginFailureModel getUserLoginFailure(RealmModel realm, String userId) { LoginFailureKey key = new LoginFailureKey(realm.getId(), userId); return wrap(key, loginFailureCache.get(key)); } @Override public UserLoginFailureModel addUserLoginFailure(RealmModel realm, String userId) { LoginFailureKey key = new LoginFailureKey(realm.getId(), userId); LoginFailureEntity entity = new LoginFailureEntity(); entity.setRealm(realm.getId()); entity.setUserId(userId); tx.put(loginFailureCache, key, entity); return wrap(key, entity); } @Override public void removeUserLoginFailure(RealmModel realm, String userId) { tx.remove(loginFailureCache, new LoginFailureKey(realm.getId(), userId)); } @Override public void removeAllUserLoginFailures(RealmModel realm) { Iterator<LoginFailureKey> itr = loginFailureCache.entrySet().stream().filter(UserLoginFailurePredicate.create(realm.getId())).map(Mappers.loginFailureId()).iterator(); while (itr.hasNext()) { LoginFailureKey key = itr.next(); tx.remove(loginFailureCache, key); } } @Override public void onRealmRemoved(RealmModel realm) { removeUserSessions(realm, true); removeUserSessions(realm, false); removeAllUserLoginFailures(realm); } @Override public void onClientRemoved(RealmModel realm, ClientModel client) { // Nothing for now. userSession.getAuthenticatedClientSessions() will check lazily if particular client exists and update userSession on-the-fly. } protected void onUserRemoved(RealmModel realm, UserModel user) { removeUserSessions(realm, user, true); removeUserSessions(realm, user, false); loginFailureCache.remove(new LoginFailureKey(realm.getId(), user.getUsername())); loginFailureCache.remove(new LoginFailureKey(realm.getId(), user.getEmail())); } @Override public void close() { } protected void removeUserSession(RealmModel realm, UserSessionEntity sessionEntity, boolean offline) { Cache<String, SessionEntity> cache = getCache(offline); tx.remove(cache, sessionEntity.getId()); } InfinispanKeycloakTransaction getTx() { return tx; } UserSessionAdapter wrap(RealmModel realm, UserSessionEntity entity, boolean offline) { Cache<String, SessionEntity> cache = getCache(offline); return entity != null ? new UserSessionAdapter(session, this, cache, realm, entity, offline) : null; } List<UserSessionModel> wrapUserSessions(RealmModel realm, Collection<UserSessionEntity> entities, boolean offline) { List<UserSessionModel> models = new LinkedList<>(); for (UserSessionEntity e : entities) { models.add(wrap(realm, e, offline)); } return models; } List<ClientInitialAccessModel> wrapClientInitialAccess(RealmModel realm, Collection<ClientInitialAccessEntity> entities) { List<ClientInitialAccessModel> models = new LinkedList<>(); for (ClientInitialAccessEntity e : entities) { models.add(wrap(realm, e)); } return models; } ClientInitialAccessAdapter wrap(RealmModel realm, ClientInitialAccessEntity entity) { Cache<String, SessionEntity> cache = getCache(false); return entity != null ? new ClientInitialAccessAdapter(session, this, cache, realm, entity) : null; } UserLoginFailureModel wrap(LoginFailureKey key, LoginFailureEntity entity) { return entity != null ? new UserLoginFailureAdapter(this, loginFailureCache, key, entity) : null; } UserSessionEntity getUserSessionEntity(UserSessionModel userSession, boolean offline) { if (userSession instanceof UserSessionAdapter) { return ((UserSessionAdapter) userSession).getEntity(); } else { Cache<String, SessionEntity> cache = getCache(offline); return cache != null ? (UserSessionEntity) cache.get(userSession.getId()) : null; } } @Override public UserSessionModel createOfflineUserSession(UserSessionModel userSession) { UserSessionAdapter offlineUserSession = importUserSession(userSession, true, false); // started and lastSessionRefresh set to current time int currentTime = Time.currentTime(); offlineUserSession.getEntity().setStarted(currentTime); offlineUserSession.setLastSessionRefresh(currentTime); return offlineUserSession; } @Override public UserSessionAdapter getOfflineUserSession(RealmModel realm, String userSessionId) { return getUserSession(realm, userSessionId, true); } @Override public void removeOfflineUserSession(RealmModel realm, UserSessionModel userSession) { UserSessionEntity userSessionEntity = getUserSessionEntity(userSession, true); if (userSessionEntity != null) { removeUserSession(realm, userSessionEntity, true); } } @Override public AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession) { UserSessionAdapter userSessionAdapter = (offlineUserSession instanceof UserSessionAdapter) ? (UserSessionAdapter) offlineUserSession : getOfflineUserSession(offlineUserSession.getRealm(), offlineUserSession.getId()); AuthenticatedClientSessionAdapter offlineClientSession = importClientSession(userSessionAdapter, clientSession); // update timestamp to current time offlineClientSession.setTimestamp(Time.currentTime()); return offlineClientSession; } @Override public List<UserSessionModel> getOfflineUserSessions(RealmModel realm, UserModel user) { Iterator<Map.Entry<String, SessionEntity>> itr = offlineSessionCache.entrySet().stream().filter(UserSessionPredicate.create(realm.getId()).user(user.getId())).iterator(); List<UserSessionModel> userSessions = new LinkedList<>(); while(itr.hasNext()) { UserSessionEntity entity = (UserSessionEntity) itr.next().getValue(); UserSessionModel userSession = wrap(realm, entity, true); userSessions.add(userSession); } return userSessions; } @Override public long getOfflineSessionsCount(RealmModel realm, ClientModel client) { return getUserSessionsCount(realm, client, true); } @Override public List<UserSessionModel> getOfflineUserSessions(RealmModel realm, ClientModel client, int first, int max) { return getUserSessions(realm, client, first, max, true); } @Override public UserSessionAdapter importUserSession(UserSessionModel userSession, boolean offline, boolean importAuthenticatedClientSessions) { UserSessionEntity entity = new UserSessionEntity(); entity.setId(userSession.getId()); entity.setRealm(userSession.getRealm().getId()); entity.setAuthMethod(userSession.getAuthMethod()); entity.setBrokerSessionId(userSession.getBrokerSessionId()); entity.setBrokerUserId(userSession.getBrokerUserId()); entity.setIpAddress(userSession.getIpAddress()); entity.setLoginUsername(userSession.getLoginUsername()); entity.setNotes(userSession.getNotes()== null ? new ConcurrentHashMap<>() : userSession.getNotes()); entity.setAuthenticatedClientSessions(new ConcurrentHashMap<>()); entity.setRememberMe(userSession.isRememberMe()); entity.setState(userSession.getState()); entity.setUser(userSession.getUser().getId()); entity.setStarted(userSession.getStarted()); entity.setLastSessionRefresh(userSession.getLastSessionRefresh()); Cache<String, SessionEntity> cache = getCache(offline); tx.put(cache, userSession.getId(), entity); UserSessionAdapter importedSession = wrap(userSession.getRealm(), entity, offline); // Handle client sessions if (importAuthenticatedClientSessions) { for (AuthenticatedClientSessionModel clientSession : userSession.getAuthenticatedClientSessions().values()) { importClientSession(importedSession, clientSession); } } return importedSession; } private AuthenticatedClientSessionAdapter importClientSession(UserSessionAdapter importedUserSession, AuthenticatedClientSessionModel clientSession) { AuthenticatedClientSessionEntity entity = new AuthenticatedClientSessionEntity(); entity.setAction(clientSession.getAction()); entity.setAuthMethod(clientSession.getProtocol()); entity.setNotes(clientSession.getNotes()); entity.setProtocolMappers(clientSession.getProtocolMappers()); entity.setRedirectUri(clientSession.getRedirectUri()); entity.setRoles(clientSession.getRoles()); entity.setTimestamp(clientSession.getTimestamp()); Map<String, AuthenticatedClientSessionEntity> clientSessions = importedUserSession.getEntity().getAuthenticatedClientSessions(); clientSessions.put(clientSession.getClient().getId(), entity); importedUserSession.update(); return new AuthenticatedClientSessionAdapter(entity, clientSession.getClient(), importedUserSession, this, importedUserSession.getCache()); } @Override public ClientInitialAccessModel createClientInitialAccessModel(RealmModel realm, int expiration, int count) { String id = KeycloakModelUtils.generateId(); ClientInitialAccessEntity entity = new ClientInitialAccessEntity(); entity.setId(id); entity.setRealm(realm.getId()); entity.setTimestamp(Time.currentTime()); entity.setExpiration(expiration); entity.setCount(count); entity.setRemainingCount(count); tx.put(sessionCache, id, entity); return wrap(realm, entity); } @Override public ClientInitialAccessModel getClientInitialAccessModel(RealmModel realm, String id) { Cache<String, SessionEntity> cache = getCache(false); ClientInitialAccessEntity entity = (ClientInitialAccessEntity) tx.get(cache, id); // Chance created in this transaction if (entity == null) { entity = (ClientInitialAccessEntity) cache.get(id); } return wrap(realm, entity); } @Override public void removeClientInitialAccessModel(RealmModel realm, String id) { tx.remove(getCache(false), id); } @Override public List<ClientInitialAccessModel> listClientInitialAccess(RealmModel realm) { Iterator<Map.Entry<String, SessionEntity>> itr = sessionCache.entrySet().stream().filter(ClientInitialAccessPredicate.create(realm.getId())).iterator(); List<ClientInitialAccessModel> list = new LinkedList<>(); while (itr.hasNext()) { list.add(wrap(realm, (ClientInitialAccessEntity) itr.next().getValue())); } return list; } }