/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.testsuite.model;
import org.jboss.logging.Logger;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.models.dblock.DBLockManager;
import org.keycloak.models.dblock.DBLockProvider;
import org.keycloak.models.dblock.DBLockProviderFactory;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class DBLockTest extends AbstractModelTest {
private static final Logger log = Logger.getLogger(DBLockTest.class);
private static final int SLEEP_TIME_MILLIS = 10;
private static final int THREADS_COUNT = 20;
private static final int ITERATIONS_PER_THREAD = 2;
private static final int LOCK_TIMEOUT_MILLIS = 240000; // Rather bigger to handle slow DB connections in testing env
private static final int LOCK_RECHECK_MILLIS = 10;
@Before
@Override
public void before() throws Exception {
super.before();
// Set timeouts for testing
DBLockManager lockManager = new DBLockManager(session);
DBLockProviderFactory lockFactory = lockManager.getDBLockFactory();
lockFactory.setTimeouts(LOCK_RECHECK_MILLIS, LOCK_TIMEOUT_MILLIS);
// Drop lock table, just to simulate racing threads for create lock table and insert lock record into it.
lockManager.getDBLock().destroyLockInfo();
commit();
}
@Test
public void testLockConcurrently() throws Exception {
long startupTime = System.currentTimeMillis();
final Semaphore semaphore = new Semaphore();
final KeycloakSessionFactory sessionFactory = realmManager.getSession().getKeycloakSessionFactory();
List<Thread> threads = new LinkedList<>();
for (int i=0 ; i<THREADS_COUNT ; i++) {
Thread thread = new Thread() {
@Override
public void run() {
for (int i=0 ; i<ITERATIONS_PER_THREAD ; i++) {
try {
KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {
@Override
public void run(KeycloakSession session) {
lock(session, semaphore);
}
});
} catch (RuntimeException e) {
semaphore.setException(e);
throw e;
}
}
}
};
threads.add(thread);
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
thread.join();
}
long took = (System.currentTimeMillis() - startupTime);
log.infof("DBLockTest executed in %d ms with total counter %d. THREADS_COUNT=%d, ITERATIONS_PER_THREAD=%d", took, semaphore.getTotal(), THREADS_COUNT, ITERATIONS_PER_THREAD);
Assert.assertEquals(semaphore.getTotal(), THREADS_COUNT * ITERATIONS_PER_THREAD);
Assert.assertNull(semaphore.getException());
}
private void lock(KeycloakSession session, Semaphore semaphore) {
DBLockProvider dbLock = new DBLockManager(session).getDBLock();
dbLock.waitForLock();
try {
semaphore.increase();
Thread.sleep(SLEEP_TIME_MILLIS);
semaphore.decrease();
} catch (InterruptedException ie) {
throw new RuntimeException(ie);
} finally {
dbLock.releaseLock();
}
}
// Ensure just one thread is allowed to run at the same time
private class Semaphore {
private AtomicInteger counter = new AtomicInteger(0);
private AtomicInteger totalIncreases = new AtomicInteger(0);
private volatile Exception exception = null;
private void increase() {
int current = counter.incrementAndGet();
if (current != 1) {
IllegalStateException ex = new IllegalStateException("Counter has illegal value: " + current);
setException(ex);
throw ex;
}
totalIncreases.incrementAndGet();
}
private void decrease() {
int current = counter.decrementAndGet();
if (current != 0) {
IllegalStateException ex = new IllegalStateException("Counter has illegal value: " + current);
setException(ex);
throw ex;
}
}
private synchronized void setException(Exception exception) {
if (this.exception == null) {
this.exception = exception;
}
}
private synchronized Exception getException() {
return exception;
}
private int getTotal() {
return totalIncreases.get();
}
}
}