/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.services;
import org.keycloak.component.ComponentFactory;
import org.keycloak.component.ComponentModel;
import org.keycloak.credential.UserCredentialStoreManager;
import org.keycloak.keys.DefaultKeyManager;
import org.keycloak.models.KeycloakContext;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakTransactionManager;
import org.keycloak.models.KeyManager;
import org.keycloak.models.RealmProvider;
import org.keycloak.models.UserCredentialManager;
import org.keycloak.models.UserProvider;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.models.cache.UserCache;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.sessions.AuthenticationSessionProvider;
import org.keycloak.storage.UserStorageManager;
import org.keycloak.storage.federated.UserFederatedStorageProvider;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class DefaultKeycloakSession implements KeycloakSession {

    /**
     * Default {@link KeycloakSession} implementation.
     *
     * <p>Providers obtained through {@link #getProvider(Class)} and
     * {@link #getProvider(Class, String)} are created lazily from the owning
     * {@link DefaultKeycloakSessionFactory} and cached for the lifetime of this
     * session. Component-backed providers are cached as session attributes keyed
     * by component id. All cached providers, plus anything passed to
     * {@link #enlistForClose(Provider)}, are closed by {@link #close()}.
     *
     * <p>State is held in plain {@code HashMap}/{@code LinkedList} instances with
     * no synchronization.
     *
     * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
     */

    private final DefaultKeycloakSessionFactory factory;

    /**
     * Per-session provider cache. Keys are either a {@code Class} (no-id lookups)
     * or a {@link ProviderKey} (lookups by class and id). Using real objects as
     * keys, rather than summed hash codes, means distinct lookups can never
     * collide and return a provider of the wrong type.
     */
    private final Map<Object, Provider> providers = new HashMap<>();

    /** Extra providers closed together with the cached ones in {@link #close()}. */
    private final List<Provider> closable = new LinkedList<Provider>();

    private final DefaultKeycloakTransactionManager transactionManager;
    private final Map<String, Object> attributes = new HashMap<>();

    // Lazily initialised, session-scoped singletons.
    private RealmProvider model;
    private UserStorageManager userStorageManager;
    private UserCredentialStoreManager userCredentialStorageManager;
    private UserSessionProvider sessionProvider;
    private AuthenticationSessionProvider authenticationSessionProvider;
    private UserFederatedStorageProvider userFederatedStorageProvider;
    private KeycloakContext context;
    private KeyManager keyManager;

    /**
     * Creates a session bound to the given factory, with its own transaction
     * manager and context.
     *
     * @param factory the session factory used to resolve provider factories
     */
    public DefaultKeycloakSession(DefaultKeycloakSessionFactory factory) {
        this.factory = factory;
        this.transactionManager = new DefaultKeycloakTransactionManager(this);
        context = new DefaultKeycloakContext(this);
    }

    @Override
    public KeycloakContext getContext() {
        return context;
    }

    /**
     * Returns the caching realm provider when one is available, otherwise the
     * plain {@link RealmProvider}.
     */
    private RealmProvider getRealmProvider() {
        CacheRealmProvider cache = getProvider(CacheRealmProvider.class);
        if (cache != null) {
            return cache;
        }
        return getProvider(RealmProvider.class);
    }

    @Override
    public UserCache userCache() {
        return getProvider(UserCache.class);
    }

    /**
     * Registers a provider to be closed when this session is closed.
     *
     * @param provider the provider to close later
     */
    @Override
    public void enlistForClose(Provider provider) {
        closable.add(provider);
    }

    @Override
    public Object getAttribute(String attribute) {
        return attributes.get(attribute);
    }

    @Override
    public Object removeAttribute(String attribute) {
        return attributes.remove(attribute);
    }

    @Override
    public void setAttribute(String name, Object value) {
        attributes.put(name, value);
    }

    @Override
    public KeycloakTransactionManager getTransactionManager() {
        return transactionManager;
    }

    @Override
    public KeycloakSessionFactory getKeycloakSessionFactory() {
        return factory;
    }

    @Override
    public UserFederatedStorageProvider userFederatedStorage() {
        if (userFederatedStorageProvider == null) {
            userFederatedStorageProvider = getProvider(UserFederatedStorageProvider.class);
        }
        return userFederatedStorageProvider;
    }

    @Override
    public UserProvider userLocalStorage() {
        return getProvider(UserProvider.class);
    }

    @Override
    public UserProvider userStorageManager() {
        if (userStorageManager == null) {
            userStorageManager = new UserStorageManager(this);
        }
        return userStorageManager;
    }

    /**
     * Returns the user cache if one is available, otherwise the user storage
     * manager.
     */
    @Override
    public UserProvider users() {
        UserCache cache = getProvider(UserCache.class);
        if (cache != null) {
            return cache;
        }
        return userStorageManager();
    }

    @Override
    public UserCredentialManager userCredentialManager() {
        if (userCredentialStorageManager == null) {
            userCredentialStorageManager = new UserCredentialStoreManager(this);
        }
        return userCredentialStorageManager;
    }

    /**
     * Returns the default provider for the given SPI class, creating and caching
     * it on first use.
     *
     * @param clazz the provider SPI class
     * @return the cached or newly created provider, or {@code null} if the
     *         session factory has no factory for {@code clazz}
     */
    @SuppressWarnings("unchecked") // the cache only ever stores a T under key clazz
    public <T extends Provider> T getProvider(Class<T> clazz) {
        // The Class object itself is the key, so no two SPI classes can collide.
        T provider = (T) providers.get(clazz);
        if (provider == null) {
            ProviderFactory<T> providerFactory = factory.getProviderFactory(clazz);
            if (providerFactory != null) {
                provider = providerFactory.create(this);
                providers.put(clazz, provider);
            }
        }
        return provider;
    }

    /**
     * Returns the provider with the given id for the given SPI class, creating
     * and caching it on first use.
     *
     * @param clazz the provider SPI class
     * @param id    the provider id; must not be {@code null}
     * @return the cached or newly created provider, or {@code null} if the
     *         session factory has no matching factory
     * @throws NullPointerException if {@code id} is {@code null}
     */
    @SuppressWarnings("unchecked") // the cache only ever stores a T under key (clazz, id)
    public <T extends Provider> T getProvider(Class<T> clazz, String id) {
        ProviderKey key = new ProviderKey(clazz, id);
        T provider = (T) providers.get(key);
        if (provider == null) {
            ProviderFactory<T> providerFactory = factory.getProviderFactory(clazz, id);
            if (providerFactory != null) {
                provider = providerFactory.create(this);
                providers.put(key, provider);
            }
        }
        return provider;
    }

    /**
     * Returns the provider for a configured component. The provider is cached as
     * a session attribute keyed by the component id and enlisted for close.
     *
     * @param clazz          the provider SPI class
     * @param componentModel the component whose provider id selects the factory
     * @return the cached or newly created provider, or {@code null} if no factory
     *         matches the component's provider id
     */
    @Override
    @SuppressWarnings("unchecked") // factories registered for component providers are ComponentFactory instances
    public <T extends Provider> T getProvider(Class<T> clazz, ComponentModel componentModel) {
        String modelId = componentModel.getId();
        Object found = getAttribute(modelId);
        if (found != null) {
            return clazz.cast(found);
        }
        ProviderFactory<T> providerFactory = factory.getProviderFactory(clazz, componentModel.getProviderId());
        if (providerFactory == null) {
            return null;
        }
        ComponentFactory<T, T> componentFactory = (ComponentFactory<T, T>) providerFactory;
        T provider = componentFactory.create(this, componentModel);
        enlistForClose(provider);
        setAttribute(modelId, provider);
        return provider;
    }

    /**
     * Lists the ids of every provider factory registered for the given SPI class.
     *
     * @param clazz the provider SPI class
     * @return the provider ids known to the session factory
     */
    public <T extends Provider> Set<String> listProviderIds(Class<T> clazz) {
        return factory.getAllProviderIds(clazz);
    }

    /**
     * Returns one provider instance per registered provider id for the given
     * SPI class, creating any that are not yet cached.
     *
     * @param clazz the provider SPI class
     * @return a new set containing the providers
     */
    @Override
    public <T extends Provider> Set<T> getAllProviders(Class<T> clazz) {
        // Named "result" to avoid shadowing the providers cache field.
        Set<T> result = new HashSet<T>();
        for (String id : listProviderIds(clazz)) {
            result.add(getProvider(clazz, id));
        }
        return result;
    }

    @Override
    public Class<? extends Provider> getProviderClass(String providerClassName) {
        return factory.getProviderClass(providerClassName);
    }

    @Override
    public RealmProvider realms() {
        if (model == null) {
            model = getRealmProvider();
        }
        return model;
    }

    @Override
    public UserSessionProvider sessions() {
        if (sessionProvider == null) {
            sessionProvider = getProvider(UserSessionProvider.class);
        }
        return sessionProvider;
    }

    @Override
    public AuthenticationSessionProvider authenticationSessions() {
        if (authenticationSessionProvider == null) {
            authenticationSessionProvider = getProvider(AuthenticationSessionProvider.class);
        }
        return authenticationSessionProvider;
    }

    @Override
    public KeyManager keys() {
        if (keyManager == null) {
            keyManager = new DefaultKeyManager(this);
        }
        return keyManager;
    }

    /**
     * Closes every cached provider and then every enlisted provider. Failures
     * are ignored so that one misbehaving provider cannot prevent the others
     * from being closed.
     */
    public void close() {
        for (Provider p : providers.values()) {
            closeQuietly(p);
        }
        for (Provider p : closable) {
            closeQuietly(p);
        }
    }

    /** Closes a single provider, ignoring any exception it throws. */
    private static void closeQuietly(Provider provider) {
        try {
            provider.close();
        } catch (Exception ignored) {
            // Deliberate best-effort shutdown: keep closing the remaining providers.
        }
    }

    /**
     * Cache key for providers looked up by both SPI class and provider id.
     * Equality compares the actual class and id, so distinct (class, id) pairs
     * never share a cache slot.
     */
    private static final class ProviderKey {
        private final Class<?> clazz;
        private final String id;
        private final int hash;

        ProviderKey(Class<?> clazz, String id) {
            this.clazz = clazz;
            this.id = id;
            // id.hashCode() keeps the original NullPointerException for a null id.
            this.hash = 31 * clazz.hashCode() + id.hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ProviderKey)) {
                return false;
            }
            ProviderKey other = (ProviderKey) o;
            return clazz.equals(other.clazz) && id.equals(other.id);
        }

        @Override
        public int hashCode() {
            return hash;
        }
    }
}