package org.multiverse.stms.gamma.integration.blocking;
import org.junit.Before;
import org.junit.Test;
import org.multiverse.TestThread;
import org.multiverse.api.Txn;
import org.multiverse.api.TxnExecutor;
import org.multiverse.api.callables.TxnCallable;
import org.multiverse.api.callables.TxnIntCallable;
import org.multiverse.api.callables.TxnVoidCallable;
import org.multiverse.stms.gamma.GammaConstants;
import org.multiverse.stms.gamma.GammaStm;
import org.multiverse.stms.gamma.transactionalobjects.GammaTxnInteger;
import org.multiverse.stms.gamma.transactionalobjects.GammaTxnRef;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
import static org.multiverse.TestUtils.*;
import static org.multiverse.api.GlobalStmInstance.getGlobalStmInstance;
import static org.multiverse.api.TxnThreadLocal.clearThreadLocalTxn;
/**
* A StressTest that simulates a database connection pool. The code is quite ugly, but that is because
* no instrumentation is used here.
*
* @author Peter Veentjer.
*/
public abstract class ConnectionPool_AbstractTest implements GammaConstants {
private int poolsize = processorCount();
private int threadCount = processorCount() * 2;
private volatile boolean stop;
private ConnectionPool pool;
protected GammaStm stm;
@Before
public void setUp() {
clearThreadLocalTxn();
stm = (GammaStm) getGlobalStmInstance();
stop = false;
}
protected abstract TxnExecutor newReturnBlock();
protected abstract TxnExecutor newTakeBlock();
@Test
public void sanityTest() {
ConnectionPool pool = new ConnectionPool(2);
Connection c1 = pool.takeConnection();
assertEquals(1, pool.size());
Connection c2 = pool.takeConnection();
assertEquals(0, pool.size());
pool.returnConnection(c1);
assertEquals(1, pool.size());
pool.returnConnection(c2);
assertEquals(2, pool.size());
}
public void run() {
pool = new ConnectionPool(poolsize);
WorkerThread[] threads = createThreads();
startAll(threads);
sleepMs(30 * 1000);
stop = true;
joinAll(threads);
assertEquals(poolsize, pool.size());
}
class ConnectionPool {
final TxnExecutor takeConnectionBlock = newTakeBlock();
final TxnExecutor returnConnectionBlock = newReturnBlock();
final TxnExecutor sizeBlock = stm.newTxnFactoryBuilder().newTxnExecutor();
final GammaTxnInteger size = new GammaTxnInteger(stm);
final GammaTxnRef<Node<Connection>> head = new GammaTxnRef<Node<Connection>>(stm);
ConnectionPool(final int poolsize) {
stm.getDefaultTxnExecutor().execute(new TxnVoidCallable() {
@Override
public void call(Txn tx) {
size.set(poolsize);
Node<Connection> h = null;
for (int k = 0; k < poolsize; k++) {
h = new Node<Connection>(h, new Connection());
}
head.set(h);
}
});
}
Connection takeConnection() {
return takeConnectionBlock.execute(new TxnCallable<Connection>() {
@Override
public Connection call(Txn tx) {
if (size.get() == 0) {
tx.retry();
}
size.decrement();
Node<Connection> current = head.get();
head.set(current.next);
return current.item;
}
});
}
void returnConnection(final Connection c) {
returnConnectionBlock.execute(new TxnVoidCallable() {
@Override
public void call(Txn tx) throws Exception {
size.incrementAndGet(1);
Node<Connection> oldHead = head.get();
head.set(new Node<Connection>(oldHead, c));
}
});
}
int size() {
return sizeBlock.execute(new TxnIntCallable() {
@Override
public int call(Txn tx) throws Exception {
return size.get();
}
});
}
}
static class Node<E> {
final Node<E> next;
final E item;
Node(Node<E> next, E item) {
this.next = next;
this.item = item;
}
}
static class Connection {
AtomicInteger users = new AtomicInteger();
void startUsing() {
if (!users.compareAndSet(0, 1)) {
fail();
}
}
void stopUsing() {
if (!users.compareAndSet(1, 0)) {
fail();
}
}
}
private WorkerThread[] createThreads() {
WorkerThread[] threads = new WorkerThread[threadCount];
for (int k = 0; k < threads.length; k++) {
threads[k] = new WorkerThread(k);
}
return threads;
}
class WorkerThread extends TestThread {
public WorkerThread(int id) {
super("WorkerThread-" + id);
}
@Override
public void doRun() throws Exception {
int k = 0;
while (!stop) {
if (k % 100 == 0) {
System.out.printf("%s is at %s\n", getName(), k);
}
Connection c = pool.takeConnection();
assertNotNull(c);
c.startUsing();
try {
sleepRandomMs(50);
} finally {
c.stopUsing();
pool.returnConnection(c);
}
k++;
}
}
}
}