package org.multiverse.commitbarriers; import org.junit.Before; import org.junit.Test; import org.multiverse.api.Txn; import org.multiverse.api.callables.TxnVoidCallable; import org.multiverse.stms.gamma.GammaStm; import org.multiverse.stms.gamma.transactionalobjects.GammaTxnInteger; import java.util.Vector; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import static org.multiverse.TestUtils.*; import static org.multiverse.api.TxnThreadLocal.clearThreadLocalTxn; public class CountDownCommitBarrier_StressTest { private AtomicLong totalInc; private AtomicLong commitInc; private int oneOfFails = 4; private int refCount = 50; private int maxPartiesCount = 5; private int spawnCountPerThread = 2 * 1000; private int spawnCount = 5; private GammaTxnInteger[] refs; private ThreadPoolExecutor executor = new ThreadPoolExecutor(50, 50, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>()); private ThreadPoolExecutor spawnExecutor = new ThreadPoolExecutor(spawnCount, spawnCount, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>()); private GammaStm stm; @Before public void setUp() { clearThreadLocalTxn(); stm = new GammaStm(); commitInc = new AtomicLong(); totalInc = new AtomicLong(); refs = new GammaTxnInteger[refCount]; for (int k = 0; k < refCount; k++) { refs[k] = stm.getDefaultRefFactory().newTxnInteger(0); } } @Test public void test() throws InterruptedException, TimeoutException { for (int k = 0; k < spawnCount; k++) { spawnExecutor.execute(new SpawnTask("SpawnTask-" + k)); } Runnable shutdownTask = new Runnable() { @Override public void run() { spawnExecutor.shutdown(); } }; spawnExecutor.execute(shutdownTask); if (!spawnExecutor.awaitTermination(5, TimeUnit.MINUTES)) { fail("failed to complete test, it took too long"); } System.out.printf("commitInc %s totalInc %s\n", commitInc.get(), totalInc.get()); assertEquals(commitInc.get(), sum()); } public long sum() { long sum = 0; for (int k = 0; k < refCount; k++) { sum += refs[k].atomicGet(); } return sum; } public class SpawnTask implements Runnable { private String name; public SpawnTask(String name) { this.name = name; } @Override public void run() { for (int k = 0; k < spawnCountPerThread; k++) { runOnce(); if (k % 100 == 0) { System.out.println(name + " is at " + k); } } } public void runOnce() { int partyCount = randomInt(maxPartiesCount) + 1; totalInc.addAndGet(partyCount); CountDownCommitBarrier countDownCommitBarrier = new CountDownCommitBarrier(partyCount); Vector<Txn> txns = new Vector<Txn>(); for (int k = 0; k < partyCount; k++) { executor.execute(new WorkerTask(k == 0, countDownCommitBarrier, txns)); } countDownCommitBarrier.awaitOpenUninterruptibly(); if (countDownCommitBarrier.isCommitted()) { commitInc.getAndAdd(partyCount); } } } class WorkerTask implements Runnable { final CountDownCommitBarrier countDownCommitBarrier; final boolean first; private Vector<Txn> txns; WorkerTask(boolean first, CountDownCommitBarrier countDownCommitBarrier, Vector<Txn> txns) { this.countDownCommitBarrier = countDownCommitBarrier; this.txns = txns; this.first = first; } @Override public void run() { try { clearThreadLocalTxn(); doRun(); } catch (IllegalStateException ignore) { } } public void doRun() { stm.getDefaultTxnExecutor().execute(new TxnVoidCallable() { @Override public void call(Txn tx) throws Exception { sleepRandomMs(10); refs[randomInt(refs.length)].getAndIncrement(tx, 1); sleepRandomMs(10); txns.add(tx); if (first && randomOneOf(oneOfFails)) { countDownCommitBarrier.abort(); } countDownCommitBarrier.joinCommitUninterruptibly(tx); } }); } } }