package org.infinispan.api;
import static org.testng.AssertJUnit.assertEquals;
import static org.testng.AssertJUnit.assertFalse;
import static org.testng.AssertJUnit.assertTrue;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.infinispan.AdvancedCache;
import org.infinispan.Cache;
import org.infinispan.commons.CacheException;
import org.infinispan.configuration.cache.CacheMode;
import org.infinispan.configuration.cache.ConfigurationBuilder;
import org.infinispan.remoting.transport.Address;
import org.infinispan.test.MultipleCacheManagersTest;
import org.infinispan.transaction.LockingMode;
import org.infinispan.util.concurrent.IsolationLevel;
import org.infinispan.util.concurrent.locks.LockManager;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
/**
* Verifies the atomic semantics of Infinispan's implementations of java.util.concurrent.ConcurrentMap's
* conditional operations.
*
* @author Sanne Grinovero <sanne@infinispan.org> (C) 2012 Red Hat Inc.
* @see java.util.concurrent.ConcurrentMap#replace(Object, Object, Object)
* @since 5.2
*/
@Test(groups = "functional", testName = "api.ConditionalOperationsConcurrentTest")
public class ConditionalOperationsConcurrentTest extends MultipleCacheManagersTest {
private final Log log = LogFactory.getLog(getClass());
public ConditionalOperationsConcurrentTest() {
this(2, 10, 2);
}
public ConditionalOperationsConcurrentTest(int nodes, int operations, int threads) {
this.nodes = nodes;
this.operations = operations;
this.threads = threads;
this.validMoves = generateValidMoves();
}
protected final int nodes;
protected final int operations;
protected final int threads;
private static final String SHARED_KEY = "thisIsTheKeyForConcurrentAccess";
private final String[] validMoves;
private final AtomicBoolean failed = new AtomicBoolean(false);
private final AtomicBoolean quit = new AtomicBoolean(false);
private final AtomicInteger liveWorkers = new AtomicInteger();
private volatile String failureMessage = "";
protected boolean transactional = false;
private final CacheMode mode = CacheMode.DIST_SYNC;
protected LockingMode lockingMode = LockingMode.OPTIMISTIC;
protected boolean writeSkewCheck = false;
@BeforeMethod
public void init() {
failed.set(false);
quit.set(false);
liveWorkers.set(0);
failureMessage = "";
assertEquals(operations, validMoves.length);
}
@Override
protected void createCacheManagers() throws Throwable {
ConfigurationBuilder dcc = getDefaultClusteredCacheConfig(mode, transactional);
dcc.transaction().lockingMode(lockingMode);
if (writeSkewCheck) {
dcc.transaction().locking().isolationLevel(IsolationLevel.REPEATABLE_READ);
}
createCluster(dcc, nodes);
waitForClusterToForm();
}
public void testReplace() throws Exception {
List caches = caches(null);
testOnCaches(caches, new ReplaceOperation(true));
}
public void testConditionalRemove() throws Exception {
List caches = caches(null);
testOnCaches(caches, new ConditionalRemoveOperation(true));
}
public void testPutIfAbsent() throws Exception {
List caches = caches(null);
testOnCaches(caches, new PutIfAbsentOperation(true));
}
protected void testOnCaches(List<Cache> caches, CacheOperation operation) {
failed.set(false);
quit.set(false);
caches.get(0).put(SHARED_KEY, "initialValue");
final SharedState state = new SharedState(threads);
final PostOperationStateCheck stateCheck = new PostOperationStateCheck(caches, state, operation);
final CyclicBarrier barrier = new CyclicBarrier(threads, stateCheck);
final String className = getClass().getSimpleName();//in order to be able filter this test's log file correctly
ExecutorService exec = Executors.newFixedThreadPool(threads, getTestThreadFactory("Mover"));
for (int threadIndex = 0; threadIndex < threads; threadIndex++) {
Runnable validMover = new ValidMover(caches, barrier, threadIndex, state, operation);
exec.execute(validMover);
}
exec.shutdown();
try {
boolean finished = exec.awaitTermination(5, TimeUnit.MINUTES);
assertTrue("Test took too long", finished);
} catch (InterruptedException e) {
fail("Thread interrupted!");
} finally {
// Stop the worker threads so that they don't affect the following tests
exec.shutdownNow();
}
assertFalse(failureMessage, failed.get());
}
private String[] generateValidMoves() {
String[] validMoves = new String[operations];
for (int i = 0; i < operations; i++) {
validMoves[i] = "v_" + i;
}
print("Valid moves ready");
return validMoves;
}
private void fail(final String message) {
boolean firstFailure = failed.compareAndSet(false, true);
if (firstFailure) {
failureMessage = message;
}
}
private void fail(final Exception e) {
log.error("Failing because of exception", e);
fail(e.toString());
}
final class ValidMover implements Runnable {
private final List<Cache> caches;
private final int threadIndex;
private final CyclicBarrier barrier;
private final SharedState state;
private final CacheOperation operation;
public ValidMover(List<Cache> caches, CyclicBarrier barrier, int threadIndex, SharedState state, CacheOperation operation) {
this.caches = caches;
this.barrier = barrier;
this.threadIndex = threadIndex;
this.state = state;
this.operation = operation;
}
@Override
public void run() {
int cachePickIndex = threadIndex;
liveWorkers.incrementAndGet();
try {
for (int moveToIndex = threadIndex;
moveToIndex < validMoves.length && !barrier.isBroken() && !failed.get() && !quit.get();
moveToIndex += threads) {
operation.beforeOperation(caches.get(0));
cachePickIndex = ++cachePickIndex % caches.size();
Cache cache = caches.get(cachePickIndex);
Object existing = cache.get(SHARED_KEY);
String targetValue = validMoves[moveToIndex];
state.beforeOperation(threadIndex, existing, targetValue);
blockAtTheBarrier();
boolean successful = operation.execute(cache, SHARED_KEY, existing, targetValue);
state.afterOperation(threadIndex, existing, targetValue, successful);
blockAtTheBarrier();
}
//not all threads might finish at the same block, so make sure none stays waiting for us when we exit
quit.set(true);
barrier.reset();
} catch (InterruptedException | RuntimeException e) {
log.error("Caught exception", e);
fail(e);
} catch (BrokenBarrierException e) {
log.error("Caught exception", e);
//just quit
print("Broken barrier!");
} finally {
int andGet = liveWorkers.decrementAndGet();
barrier.reset();
print("Thread #" + threadIndex + " terminating. Still " + andGet + " threads alive");
}
}
private void blockAtTheBarrier() throws InterruptedException, BrokenBarrierException {
try {
barrier.await(10000, TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
if (!quit.get()) {
throw new RuntimeException(e);
}
}
}
}
static final class SharedState {
private final SharedThreadState[] threadStates;
private volatile boolean after = false;
public SharedState(final int threads) {
threadStates = new SharedThreadState[threads];
for (int i = 0; i < threads; i++) {
threadStates[i] = new SharedThreadState();
}
}
synchronized void beforeOperation(int threadIndex, Object expected, String targetValue) {
threadStates[threadIndex].beforeReplace(expected, targetValue);
after = false;
}
synchronized void afterOperation(int threadIndex, Object expected, String targetValue, boolean successful) {
threadStates[threadIndex].afterReplace(expected, targetValue, successful);
after = true;
}
public boolean isAfter() {
return after;
}
}
static final class SharedThreadState {
Object beforeExpected;
Object beforeTargetValue;
Object afterExpected;
Object afterTargetValue;
boolean successfulOperation;
public void beforeReplace(Object expected, Object targetValue) {
this.beforeExpected = expected;
this.beforeTargetValue = targetValue;
}
public void afterReplace(Object expected, Object targetValue, boolean replaced) {
this.afterExpected = expected;
this.afterTargetValue = targetValue;
this.successfulOperation = replaced;
}
public boolean sameBeforeValue(Object currentStored) {
return currentStored == null ? beforeExpected == null : currentStored.equals(beforeExpected);
}
}
final class PostOperationStateCheck implements Runnable {
private final List<Cache> caches;
private final SharedState state;
private final CacheOperation operation;
private volatile int cycle = 0;
public PostOperationStateCheck(final List<Cache> caches, final SharedState state, CacheOperation operation) {
this.caches = caches;
this.state = state;
this.operation = operation;
}
@Override
public void run() {
if (state.isAfter()) {
cycle++;
log.tracef("Starting cycle %d", cycle);
if (cycle % Math.max(operations / 100, 1) == 0) {
print((cycle * 100 * threads / operations) + "%");
}
checkAfterState();
} else {
checkBeforeState();
}
}
private void checkSameValueOnAllCaches() {
final Object currentStored = caches.get(0).get(SHARED_KEY);
log.tracef("Value seen by (first) cache %s is %s ", caches.get(0).getAdvancedCache().getRpcManager().getAddress(),
currentStored);
for (Cache c : caches) {
Object v = c.get(SHARED_KEY);
Address currentCache = c.getAdvancedCache().getRpcManager().getAddress();
log.tracef("Value seen by cache %s is %s", currentCache, v);
boolean sameValue = v == null ? currentStored == null : v.equals(currentStored);
if (!sameValue) {
fail("Not all the caches see the same value. first cache: " + currentStored + " cache " + currentCache +" saw " + v);
}
}
}
private void checkBeforeState() {
final Object currentStored = caches.get(0).get(SHARED_KEY);
for (SharedThreadState threadState : state.threadStates) {
if ( !threadState.sameBeforeValue(currentStored)) {
fail("Some cache expected a different value than what is stored");
}
}
}
private void checkAfterState() {
final Object currentStored = assertTestCorrectness();
checkSameValueOnAllCaches();
if (operation.isCas()) {
checkSingleSuccessfulThread();
checkSuccessfulOperation(currentStored);
}
checkNoLocks();
}
private Object assertTestCorrectness() {
AdvancedCache someCache = caches.get(0).getAdvancedCache();
final Object currentStored = someCache.get(SHARED_KEY);
HashSet uniqueValueVerify = new HashSet();
for (SharedThreadState threadState : state.threadStates) {
uniqueValueVerify.add(threadState.afterTargetValue);
}
if (uniqueValueVerify.size() != threads) {
fail("test bug");
}
return currentStored;
}
private void checkNoLocks() {
for (Cache c : caches) {
LockManager lockManager = c.getAdvancedCache().getComponentRegistry().getComponent(LockManager.class);
//locks might be released async, so give it some time
boolean isLocked = true;
for (int i = 0; i < 30; i++) {
if (!lockManager.isLocked(SHARED_KEY)) {
isLocked = false;
break;
}
try {
Thread.sleep(500);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
if (isLocked) {
fail("lock on the entry wasn't cleaned up");
}
}
}
private void checkSuccessfulOperation(Object currentStored) {
for (SharedThreadState threadState : state.threadStates) {
if (threadState.successfulOperation) {
if (!operation.validateTargetValueForSuccess(threadState.afterTargetValue, currentStored)) {
fail("operation successful but the current stored value doesn't match the write operation of the successful thread");
}
} else {
if (threadState.afterTargetValue.equals(currentStored)) {
fail("operation not successful (which is fine) but the current stored value matches the write attempt");
}
}
}
}
private void checkSingleSuccessfulThread() {
//for CAS operations there's only one successful thread
int successfulThreads = 0;
for (SharedThreadState threadState : state.threadStates) {
if (threadState.successfulOperation) {
successfulThreads++;
}
}
if (successfulThreads != 1) {
fail(successfulThreads + " threads assume a successful replacement! (CAS should succeed on a single thread only)");
}
}
}
public static abstract class CacheOperation {
private final boolean isCas;
protected CacheOperation(boolean cas) {
isCas = cas;
}
public final boolean isCas() {
return isCas;
}
abstract boolean execute(Cache cache, String sharedKey, Object existing, String targetValue);
abstract void beforeOperation(Cache cache);
boolean validateTargetValueForSuccess(Object afterTargetValue, Object currentStored) {
return afterTargetValue.equals(currentStored);
}
}
static class ReplaceOperation extends CacheOperation {
ReplaceOperation(boolean cas) {
super(cas);
}
@Override
public boolean execute(Cache cache, String sharedKey, Object existing, String targetValue) {
try {
return cache.replace(SHARED_KEY, existing, targetValue);
} catch (CacheException e) {
return false;
}
}
@Override
public void beforeOperation(Cache cache) {
}
}
class PutIfAbsentOperation extends CacheOperation {
PutIfAbsentOperation(boolean cas) {
super(cas);
}
@Override
public boolean execute(Cache cache, String sharedKey, Object existing, String targetValue) {
try {
Object o = cache.putIfAbsent(SHARED_KEY, targetValue);
return o == null;
} catch (CacheException e) {
return false;
}
}
@Override
public void beforeOperation(Cache cache) {
try {
cache.remove(SHARED_KEY);
} catch (CacheException e) {
log.debug("Write skew check error while removing the key", e);
}
}
}
class ConditionalRemoveOperation extends CacheOperation {
ConditionalRemoveOperation(boolean cas) {
super(cas);
}
@Override
public boolean execute(Cache cache, String sharedKey, Object existing, String targetValue) {
try {
return cache.remove(SHARED_KEY, existing);
} catch (CacheException e) {
return false;
}
}
@Override
public void beforeOperation(Cache cache) {
try {
cache.put(SHARED_KEY, "someValue");
} catch (CacheException e) {
log.warn("Write skew check error while inserting the key", e);
}
}
@Override
boolean validateTargetValueForSuccess(Object afterTargetValue, Object currentStored) {
return currentStored == null;
}
}
private void print(String s) {
log.debug(s);
}
}