/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.testsuite.model;
import org.jboss.logging.Logger;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RealmProvider;
import org.keycloak.models.UserModel;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Model tests that run two {@link KeycloakSession} transactions concurrently, each on its own thread via
 * {@link KeycloakModelUtils#runJobInTransaction}, and then check the resulting state on the JUnit thread.
 * <p>
 * Worker threads never swallow failures. The first {@link Throwable} raised by any worker is recorded
 * and re-thrown on the JUnit thread after {@code join()}, with the original exception as its cause.
 * Every {@link CountDownLatch} a peer thread may be waiting on is also counted down in a
 * {@code finally} block, so one failing worker cannot leave the other blocked forever.
 *
 * @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
 */
public class ConcurrentTransactionsTest extends AbstractModelTest {

    private static final Logger logger = Logger.getLogger(ConcurrentTransactionsTest.class);

    /**
     * Checks that a client secret committed by one transaction is visible afterwards through both lookups.
     * <p>
     * The two lookups are {@code session.realms()} (held in {@code clientFromCache}) and the
     * {@link RealmProvider} from {@code session.getProvider(RealmProvider.class)} (held in
     * {@code clientFromDB}). Judging by those names, these are presumably the cached and the underlying
     * store views — confirm against the provider configuration.
     * <p>
     * The sequence is:
     * <ol>
     *   <li>Both transactions start.</li>
     *   <li>Transaction 1 reads the client.</li>
     *   <li>Transaction 2 sets the secret to {@code "new"} and its job returns.</li>
     *   <li>Transaction 1 reads the client again; the value is only logged, not asserted.</li>
     *   <li>The JUnit thread asserts that both lookups return {@code "new"}.</li>
     * </ol>
     *
     * @throws Exception if setup or joining fails
     * @throws AssertionError if a worker transaction failed or a lookup does not return {@code "new"}
     */
    @Test
    public void persistClient() throws Exception {
        RealmModel realm = realmManager.createRealm("original");
        KeycloakSession session = realmManager.getSession();

        ClientModel client = session.realms().addClient(realm, "client");
        client.setSecret("old");
        String clientDBId = client.getId();

        final KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();

        commit();

        final CountDownLatch transactionsCounter = new CountDownLatch(2);
        final CountDownLatch readLatch = new CountDownLatch(1);
        final CountDownLatch updateLatch = new CountDownLatch(1);
        // First failure raised by either worker; re-thrown on this thread after join().
        final AtomicReference<Throwable> failure = new AtomicReference<>();

        Thread thread1 = new Thread() {
            @Override
            public void run() {
                try {
                    KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {
                        @Override
                        public void run(KeycloakSession session) {
                            try {
                                // Wait until transaction in both threads started
                                transactionsCounter.countDown();
                                logger.info("transaction1 started");
                                transactionsCounter.await();

                                // Read client
                                RealmModel realm = session.realms().getRealmByName("original");
                                ClientModel client = session.realms().getClientByClientId("client", realm);
                                logger.info("transaction1: Read client finished");
                                readLatch.countDown();

                                // Wait until thread2 updates client and commits
                                updateLatch.await();

                                logger.info("transaction1: Going to read client again");
                                client = session.realms().getClientByClientId("client", realm);
                                logger.info("transaction1: secret: " + client.getSecret());
                            } catch (InterruptedException e) {
                                // Preserve the interrupt status before escaping as an unchecked exception.
                                Thread.currentThread().interrupt();
                                throw new RuntimeException(e);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        }
                    });
                } catch (Throwable t) {
                    recordFailure(failure, "transaction1", t);
                } finally {
                    // Release thread2 even if this transaction failed before reaching these points.
                    // On the normal path both latches are already at zero, so these calls are no-ops.
                    transactionsCounter.countDown();
                    readLatch.countDown();
                }
            }
        };

        Thread thread2 = new Thread() {
            @Override
            public void run() {
                try {
                    KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {
                        @Override
                        public void run(KeycloakSession session) {
                            try {
                                // Wait until transaction in both threads started
                                transactionsCounter.countDown();
                                logger.info("transaction2 started");
                                transactionsCounter.await();

                                // Do not update until transaction1 has read the client
                                readLatch.await();
                                logger.info("transaction2: Going to update client secret");

                                RealmModel realm = session.realms().getRealmByName("original");
                                ClientModel client = session.realms().getClientByClientId("client", realm);
                                client.setSecret("new");
                            } catch (InterruptedException e) {
                                // Preserve the interrupt status before escaping as an unchecked exception.
                                Thread.currentThread().interrupt();
                                throw new RuntimeException(e);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        }
                    });
                    logger.info("transaction2: commited");
                } catch (Throwable t) {
                    recordFailure(failure, "transaction2", t);
                } finally {
                    // Always release thread1. Otherwise a failed update would leave it waiting on
                    // updateLatch forever, and thread1.join() below would never return.
                    transactionsCounter.countDown();
                    updateLatch.countDown();
                }
            }
        };

        thread1.start();
        thread2.start();

        thread1.join();
        thread2.join();

        logger.info("after thread join");

        commit();
        assertNoWorkerFailure(failure);

        session = realmManager.getSession();

        realm = session.realms().getRealmByName("original");
        ClientModel clientFromCache = session.realms().getClientById(clientDBId, realm);
        ClientModel clientFromDB = session.getProvider(RealmProvider.class).getClientById(clientDBId, realm);

        logger.info("SECRET FROM DB : " + clientFromDB.getSecret());
        logger.info("SECRET FROM CACHE : " + clientFromCache.getSecret());

        Assert.assertEquals("new", clientFromDB.getSecret());
        Assert.assertEquals("new", clientFromCache.getSecret());
    }

    /**
     * Regression test for KEYCLOAK-3296 and KEYCLOAK-3494.
     * <p>
     * Two transactions each read the attribute {@code "foo"} of users {@code john} and {@code john2}.
     * They wait until both have read, and then both modify the same attributes:
     * <ul>
     *   <li>{@code john}: the attribute is removed (KEYCLOAK-3296).</li>
     *   <li>{@code john2}: a single value is set (KEYCLOAK-3494).</li>
     * </ul>
     * The test fails if either worker thread raised an exception.
     *
     * @throws Exception if setup or joining fails
     * @throws AssertionError if a worker transaction failed; the first failure is attached as the cause
     */
    // KEYCLOAK-3296 , KEYCLOAK-3494
    @Test
    public void removeUserAttribute() throws Exception {
        RealmModel realm = realmManager.createRealm("original");
        KeycloakSession session = realmManager.getSession();

        UserModel john = session.users().addUser(realm, "john");
        john.setSingleAttribute("foo", "val1");

        UserModel john2 = session.users().addUser(realm, "john2");
        john2.setAttribute("foo", Arrays.asList("val1", "val2"));

        final KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
        commit();

        // First failure raised by either worker; re-thrown on this thread after join().
        final AtomicReference<Throwable> failure = new AtomicReference<>();
        final CountDownLatch readAttrLatch = new CountDownLatch(2);

        Runnable runnable = new Runnable() {
            @Override
            public void run() {
                try {
                    KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {
                        @Override
                        public void run(KeycloakSession session) {
                            try {
                                // Read user attributes. The values are unused; what matters is that both
                                // transactions have read them before either one modifies them.
                                RealmModel realm = session.realms().getRealmByName("original");
                                UserModel john = session.users().getUserByUsername("john", realm);
                                john.getFirstAttribute("foo");
                                UserModel john2 = session.users().getUserByUsername("john2", realm);
                                john2.getFirstAttribute("foo");

                                // Wait until it's read in both threads
                                readAttrLatch.countDown();
                                readAttrLatch.await();

                                // KEYCLOAK-3296 : Remove user attribute in both threads
                                john.removeAttribute("foo");

                                // KEYCLOAK-3494 : Set single attribute in both threads
                                john2.setSingleAttribute("foo", "bar");
                            } catch (InterruptedException e) {
                                // Preserve the interrupt status before escaping as an unchecked exception.
                                Thread.currentThread().interrupt();
                                throw new RuntimeException(e);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        }
                    });
                } catch (Throwable t) {
                    recordFailure(failure, Thread.currentThread().getName(), t);
                } finally {
                    // Release the peer even if this thread failed before reaching the latch.
                    // On the normal path the count is already zero, so this call is a no-op.
                    readAttrLatch.countDown();
                }
            }
        };

        Thread thread1 = new Thread(runnable);
        Thread thread2 = new Thread(runnable);

        thread1.start();
        thread2.start();

        thread1.join();
        thread2.join();

        logger.info("removeUserAttribute: after thread join");

        commit();

        assertNoWorkerFailure(failure);
    }

    /**
     * Logs a worker failure and stores it, unless an earlier failure has already been stored.
     *
     * @param failure holder shared by the workers of one test
     * @param worker  label used in the log message
     * @param t       the failure raised by the worker
     */
    private static void recordFailure(AtomicReference<Throwable> failure, String worker, Throwable t) {
        logger.error(worker + " failed", t);
        failure.compareAndSet(null, t);
    }

    /**
     * Fails the current test if any worker recorded a failure, keeping that failure as the cause.
     *
     * @param failure holder shared by the workers of one test
     * @throws AssertionError if a failure was recorded
     */
    private static void assertNoWorkerFailure(AtomicReference<Throwable> failure) {
        Throwable t = failure.get();
        if (t != null) {
            throw new AssertionError("Exception happened in some of threads. Details: " + t.getMessage(), t);
        }
    }
}