/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.connections.mongo;
import com.mongodb.DB;
import com.mongodb.MongoClient;
import com.mongodb.MongoClientOptions;
import com.mongodb.MongoClientURI;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import org.jboss.logging.Logger;
import org.keycloak.Config;
import org.keycloak.connections.mongo.api.MongoStore;
import org.keycloak.connections.mongo.impl.MongoStoreImpl;
import org.keycloak.connections.mongo.impl.context.TransactionMongoStoreInvocationContext;
import org.keycloak.connections.mongo.updater.MongoUpdaterProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.models.dblock.DBLockManager;
import org.keycloak.models.dblock.DBLockProvider;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.provider.ServerInfoAwareProviderFactory;
import javax.net.ssl.SSLSocketFactory;
import java.lang.reflect.Method;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class DefaultMongoConnectionFactoryProvider implements MongoConnectionProviderFactory, ServerInfoAwareProviderFactory {
/**
 * How the database is brought in line with the current model version on startup.
 * Selected through the {@code migrationStrategy} (or legacy {@code databaseSchema}) config option.
 */
enum MigrationStrategy {
    /** Delegates to {@code MongoUpdaterProvider.update}. */
    UPDATE,
    /** Delegates to {@code MongoUpdaterProvider.validate}. */
    VALIDATE
}
// Fully qualified names of the entity classes handed to MongoStoreImpl. They are kept as
// strings and resolved through this class's class loader in getManagedEntities().
// TODO Make it dynamic
private String[] entities = {
        // Core model entities
        "org.keycloak.models.mongo.keycloak.entities.MongoRealmEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoUserEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoRoleEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoGroupEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoClientEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoClientTemplateEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoUserConsentEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoMigrationModelEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoOnlineUserSessionEntity",
        "org.keycloak.models.mongo.keycloak.entities.MongoOfflineUserSessionEntity",
        "org.keycloak.models.mongo.keycloak.entities.IdentityProviderEntity",
        "org.keycloak.models.mongo.keycloak.entities.ClientIdentityProviderMappingEntity",
        "org.keycloak.models.mongo.keycloak.entities.RequiredCredentialEntity",
        "org.keycloak.models.mongo.keycloak.entities.CredentialEntity",
        "org.keycloak.models.mongo.keycloak.entities.FederatedIdentityEntity",
        "org.keycloak.models.mongo.keycloak.entities.UserFederationProviderEntity",
        "org.keycloak.models.mongo.keycloak.entities.UserFederationMapperEntity",
        "org.keycloak.models.mongo.keycloak.entities.ProtocolMapperEntity",
        "org.keycloak.models.mongo.keycloak.entities.IdentityProviderMapperEntity",
        "org.keycloak.models.mongo.keycloak.entities.AuthenticationExecutionEntity",
        "org.keycloak.models.mongo.keycloak.entities.AuthenticationFlowEntity",
        "org.keycloak.models.mongo.keycloak.entities.AuthenticatorConfigEntity",
        "org.keycloak.models.mongo.keycloak.entities.RequiredActionProviderEntity",
        "org.keycloak.models.mongo.keycloak.entities.PersistentUserSessionEntity",
        "org.keycloak.models.mongo.keycloak.entities.PersistentClientSessionEntity",
        "org.keycloak.models.mongo.keycloak.entities.ComponentEntity",
        // User storage entities
        "org.keycloak.storage.mongo.entity.FederatedUser",
        // Authorization entities
        "org.keycloak.authorization.mongo.entities.PolicyEntity",
        "org.keycloak.authorization.mongo.entities.ResourceEntity",
        "org.keycloak.authorization.mongo.entities.ResourceServerEntity",
        "org.keycloak.authorization.mongo.entities.ScopeEntity"
};
private static final Logger logger = Logger.getLogger(DefaultMongoConnectionFactoryProvider.class);

/** Initial state: not even the MongoClient has been created. */
private static final int STATE_BEFORE_INIT = 0;
/** The Mongo client was created, but the DB is not yet updated to the last version. */
private static final int STATE_BEFORE_UPDATE = 1;
/** The Mongo client was created and the DB updated; the DB is fully initialized. */
private static final int STATE_AFTER_UPDATE = 2;

/**
 * Current initialization state (one of the STATE_* constants). Volatile because it is read
 * outside the synchronized blocks in lazyInitBeforeUpdate() and lazyInit().
 */
private volatile int state = STATE_BEFORE_INIT;

/** Configuration scope handed over in init(). */
protected Config.Scope config;

/** Client returned by createMongoClient(); closed in close(). */
private MongoClient client;
/** Database handle obtained from the client in lazyInitBeforeUpdate(). */
private DB db;
/** Store built over the managed entities once the update/validation step has run. */
private MongoStore mongoStore;

/** Connection details exposed through getOperationalInfo(); populated by createMongoClient(). */
private Map<String,String> operationalInfo;
/**
 * Remembers the configuration scope. No connection is opened here; the client is created
 * lazily on first use.
 *
 * @param config configuration scope later read for the connection and migration options
 */
@Override
public void init(Config.Scope config) {
    this.config = config;
}
/**
 * Intentionally does nothing.
 *
 * @param factory unused
 */
@Override
public void postInit(KeycloakSessionFactory factory) {
    // nothing to do
}
/**
 * Returns the database handle, creating the Mongo client first if that has not happened yet.
 * The update/validation step is not triggered by this method.
 *
 * @return the database handle obtained from the Mongo client
 */
@Override
public DB getDBBeforeUpdate() {
    lazyInitBeforeUpdate();
    return this.db;
}
/**
 * Creates the Mongo client and resolves the database handle on first use, moving the state
 * from STATE_BEFORE_INIT to STATE_BEFORE_UPDATE. The state is checked once without the lock
 * and again under it, so only one caller performs the initialization.
 *
 * @throws RuntimeException wrapping any exception raised while creating the client or
 *         obtaining the database
 */
private void lazyInitBeforeUpdate() {
    if (state != STATE_BEFORE_INIT) {
        return;
    }

    synchronized (this) {
        // Re-check under the lock: another thread may have finished in the meantime.
        if (state != STATE_BEFORE_INIT) {
            return;
        }

        try {
            this.client = createMongoClient();

            String databaseName = config.get("db", "keycloak");
            this.db = this.client.getDB(databaseName);

            state = STATE_BEFORE_UPDATE;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
/**
 * Builds a connection provider for the given session. Full initialization (client creation
 * plus DB update/validation) is performed first if still pending, then a new invocation
 * context is enlisted as a transaction with the session's transaction manager.
 *
 * @param session session the provider is created for
 * @return provider sharing this factory's database handle and store
 */
@Override
public MongoConnectionProvider create(KeycloakSession session) {
    lazyInit(session);

    TransactionMongoStoreInvocationContext context = new TransactionMongoStoreInvocationContext(mongoStore);
    MongoKeycloakTransaction transaction = new MongoKeycloakTransaction(context);
    session.getTransactionManager().enlist(transaction);

    return new DefaultMongoConnectionProvider(db, mongoStore, context);
}
/**
 * Completes initialization on first use: makes sure the client exists, then runs the DB
 * update/validation and builds the store, moving the state from STATE_BEFORE_UPDATE to
 * STATE_AFTER_UPDATE. The state is checked once without the lock and again under it.
 *
 * @param session session used to look up the updater and the DB lock
 * @throws RuntimeException wrapping any exception raised by the update or the store creation
 */
private void lazyInit(KeycloakSession session) {
    lazyInitBeforeUpdate();

    if (state != STATE_BEFORE_UPDATE) {
        return;
    }

    synchronized (this) {
        // Re-check under the lock: another thread may have finished in the meantime.
        if (state != STATE_BEFORE_UPDATE) {
            return;
        }

        try {
            update(session);

            Class[] managedEntities = getManagedEntities();
            this.mongoStore = new MongoStoreImpl(db, managedEntities);

            state = STATE_AFTER_UPDATE;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
/**
 * Runs the configured migration strategy against the database while holding the DB lock.
 * If the lock provider of the given session already reports the lock as held, the update
 * runs directly; otherwise the lock is acquired through a job started with
 * {@code KeycloakModelUtils.runJobInTransaction} and released once the update has finished.
 *
 * @param session session used to look up the updater and passed on to it
 * @throws RuntimeException if no {@code MongoUpdaterProvider} is available
 */
private void update(KeycloakSession session) {
    MigrationStrategy strategy = getMigrationStrategy();

    MongoUpdaterProvider mongoUpdater = session.getProvider(MongoUpdaterProvider.class);
    if (mongoUpdater == null) {
        throw new RuntimeException("Can't update database: Mongo updater provider not found");
    }

    DBLockProvider currentLock = new DBLockManager(session).getDBLock();
    if (currentLock.hasLock()) {
        updateOrValidateDB(strategy, session, mongoUpdater);
        return;
    }

    logger.trace("Don't have DBLock retrieved before upgrade. Needs to acquire lock first in separate transaction");

    KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
    KeycloakSessionTask lockedUpdate = new KeycloakSessionTask() {

        @Override
        public void run(KeycloakSession lockSession) {
            // The lock is taken through the job's own session; the update itself still
            // receives the outer session.
            DBLockProvider jobLock = new DBLockManager(lockSession).getDBLock();
            jobLock.waitForLock();
            try {
                updateOrValidateDB(strategy, session, mongoUpdater);
            } finally {
                jobLock.releaseLock();
            }
        }

    };
    KeycloakModelUtils.runJobInTransaction(sessionFactory, lockedUpdate);
}
/**
 * Resolves every class name listed in {@code entities} through this class's class loader.
 *
 * @return the loaded classes, in the same order as {@code entities}
 * @throws ClassNotFoundException if one of the listed classes cannot be loaded
 */
private Class[] getManagedEntities() throws ClassNotFoundException {
    ClassLoader loader = getClass().getClassLoader();
    Class[] resolved = new Class[entities.length];

    int position = 0;
    for (String className : entities) {
        resolved[position++] = loader.loadClass(className);
    }
    return resolved;
}
/**
 * Dispatches to the updater according to the migration strategy.
 *
 * @param strategy     UPDATE calls {@code mongoUpdater.update}, VALIDATE calls {@code mongoUpdater.validate}
 * @param session      session passed on to the updater
 * @param mongoUpdater updater performing the actual work against this factory's database
 */
protected void updateOrValidateDB(MigrationStrategy strategy, KeycloakSession session, MongoUpdaterProvider mongoUpdater) {
    switch (strategy) {
        case VALIDATE:
            mongoUpdater.validate(session, db);
            break;
        case UPDATE:
            mongoUpdater.update(session, db);
            break;
        default:
            // no other strategies are defined
            break;
    }
}
/**
 * Closes the Mongo client if one was created; otherwise does nothing.
 */
@Override
public void close() {
    if (client == null) {
        return;
    }
    client.close();
}
/**
 * @return the identifier of this provider factory, always {@code "default"}
 */
@Override
public String getId() {
    return "default";
}
/**
 * Creates the Mongo client and records the connection details that are later exposed through
 * {@link #getOperationalInfo()}.
 *
 * Override this method if you want more possibility to configure Mongo client. It can be also
 * used to inject mongo client from different source.
 *
 * This method can assume that "config" is already set and can use it. When the "uri" option is
 * set, the client is built from that URI alone; otherwise "host", "port", "user" and "password"
 * are read and combined with the options from {@link #getClientOptions()}. Credentials are only
 * used when both "user" and "password" are configured.
 *
 * @return mongoClient instance, which will be shared for whole Keycloak
 *
 * @throws UnknownHostException declared so that client construction and overriding
 *         implementations can report a host that cannot be resolved
 */
protected MongoClient createMongoClient() throws UnknownHostException {
    operationalInfo = new LinkedHashMap<>();
    String dbName = config.get("db", "keycloak");

    String uriString = config.get("uri");
    if (uriString != null) {
        MongoClientURI uri = new MongoClientURI(uriString);
        MongoClient client = new MongoClient(uri);

        operationalInfo.put("mongoHosts", String.join(", ", uri.getHosts()));
        operationalInfo.put("mongoDatabaseName", dbName);
        operationalInfo.put("mongoUser", uri.getUsername());

        // debugf, not debugv: the message uses printf-style placeholders, whereas debugv
        // expects MessageFormat-style ({0}) placeholders and would log the "%s" literally.
        logger.debugf("Initialized mongo model. host(s): %s, db: %s", uri.getHosts(), dbName);
        return client;
    }

    String host = config.get("host", ServerAddress.defaultHost());
    int port = config.getInt("port", ServerAddress.defaultPort());
    String user = config.get("user");
    String password = config.get("password");

    MongoClientOptions clientOptions = getClientOptions();
    ServerAddress serverAddress = new ServerAddress(host, port);

    MongoClient client;
    if (user != null && password != null) {
        MongoCredential credential = MongoCredential.createCredential(user, dbName, password.toCharArray());
        client = new MongoClient(serverAddress, Collections.singletonList(credential), clientOptions);
    } else {
        client = new MongoClient(serverAddress, clientOptions);
    }

    operationalInfo.put("mongoServerAddress", client.getAddress().toString());
    operationalInfo.put("mongoDatabaseName", dbName);
    operationalInfo.put("mongoUser", user);

    // debugf for the same reason as above: printf-style placeholders.
    logger.debugf("Initialized mongo model. host: %s, port: %d, db: %s", host, port, dbName);
    return client;
}
/**
 * Builds the client options from the configuration. Each supported int and boolean option is
 * applied only when it is configured; when "ssl" is true the default SSL socket factory is
 * installed.
 *
 * @return the options used for clients that are not created from a connection URI
 */
protected MongoClientOptions getClientOptions() {
    MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder();

    String[] intOptions = {
            "connectionsPerHost",
            "threadsAllowedToBlockForConnectionMultiplier",
            "maxWaitTime",
            "connectTimeout",
            "socketTimeout"
    };
    for (String intOption : intOptions) {
        checkIntOption(intOption, optionsBuilder);
    }

    String[] booleanOptions = {
            "socketKeepAlive",
            "autoConnectRetry"
    };
    for (String booleanOption : booleanOptions) {
        checkBooleanOption(booleanOption, optionsBuilder);
    }

    if (config.getBoolean("ssl", false)) {
        optionsBuilder.socketFactory(SSLSocketFactory.getDefault());
    }

    return optionsBuilder.build();
}
/**
 * Applies a boolean option to the builder when it is present in the configuration. The builder
 * method with the same name as the option and a single {@code boolean} parameter is looked up
 * and invoked reflectively.
 *
 * @param optionName name of both the config option and the builder method
 * @param builder    builder the option is applied to
 * @throws IllegalStateException if the builder method cannot be found or invoked
 */
protected void checkBooleanOption(String optionName, MongoClientOptions.Builder builder) {
    Boolean configured = config.getBoolean(optionName);
    if (configured == null) {
        // Option not configured: leave the builder untouched.
        return;
    }

    try {
        MongoClientOptions.Builder.class
                .getMethod(optionName, boolean.class)
                .invoke(builder, configured);
    } catch (Exception e) {
        throw new IllegalStateException("Problem configuring boolean option " + optionName + " for mongo client. Ensure you used correct value true or false and if this option is supported by mongo driver", e);
    }
}
/**
 * Applies an int option to the builder when it is present in the configuration. The builder
 * method with the same name as the option and a single {@code int} parameter is looked up and
 * invoked reflectively.
 *
 * @param optionName name of both the config option and the builder method
 * @param builder    builder the option is applied to
 * @throws IllegalStateException if the builder method cannot be found or invoked
 */
protected void checkIntOption(String optionName, MongoClientOptions.Builder builder) {
    Integer configured = config.getInt(optionName);
    if (configured == null) {
        // Option not configured: leave the builder untouched.
        return;
    }

    try {
        MongoClientOptions.Builder.class
                .getMethod(optionName, int.class)
                .invoke(builder, configured);
    } catch (Exception e) {
        throw new IllegalStateException("Problem configuring int option " + optionName + " for mongo client. Ensure you used correct value (number) and if this option is supported by mongo driver", e);
    }
}
/**
 * Returns the connection details recorded by createMongoClient().
 *
 * @return the recorded details, or {@code null} if createMongoClient() has not assigned them yet
 */
@Override
public Map<String,String> getOperationalInfo() {
    return this.operationalInfo;
}
/**
 * Resolves the migration strategy from the configuration.
 *
 * Reads "migrationStrategy" and falls back to the legacy "databaseSchema" option. The value is
 * matched case-insensitively against the {@link MigrationStrategy} constants.
 *
 * @return the configured strategy, or {@link MigrationStrategy#UPDATE} when neither option is set
 * @throws IllegalArgumentException if the configured value does not name a known strategy
 */
private MigrationStrategy getMigrationStrategy() {
    String migrationStrategy = config.get("migrationStrategy");
    if (migrationStrategy == null) {
        // Support 'databaseSchema' for backwards compatibility
        migrationStrategy = config.get("databaseSchema");
    }

    if (migrationStrategy == null) {
        return MigrationStrategy.UPDATE;
    }

    try {
        // Locale.ROOT keeps the mapping independent of the JVM default locale; e.g. under a
        // Turkish locale "validate" would otherwise upper-case to "VALİDATE" and not match.
        return MigrationStrategy.valueOf(migrationStrategy.toUpperCase(Locale.ROOT));
    } catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Unsupported value '" + migrationStrategy
                + "' for mongo option 'migrationStrategy' (or legacy 'databaseSchema')", e);
    }
}
}