package org.multiverse;
import org.multiverse.api.Txn;
import org.multiverse.api.TxnStatus;
import org.multiverse.api.blocking.RetryLatch;
import org.multiverse.stms.gamma.GammaConstants;
import org.multiverse.stms.gamma.transactionalobjects.AbstractGammaObject;
import org.multiverse.utils.Bugshaker;
import org.multiverse.utils.ThreadLocalRandom;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.Writer;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import static java.lang.String.format;
import static org.junit.Assert.*;
/**
* @author Peter Veentjer
*/
public class TestUtils implements MultiverseConstants {
public static void assertOrecValue(AbstractGammaObject object, long expected) {
assertEquals(expected, object.orec);
}
public static void assertFailure(int value) {
assertEquals(GammaConstants.FAILURE, value);
}
public static void assertHasMasks(int value, int... masks) {
for (int mask : masks) {
assertTrue((value & mask) > 0);
}
}
public static void assertNotHasMasks(int value, int... masks) {
for (int mask : masks) {
assertTrue((value & mask) == 0);
}
}
public static void clearCurrentThreadInterruptedStatus() {
Thread.interrupted();
}
public static void assertEqualsDouble(String msg, double expected, double found) {
assertEquals(msg, Double.doubleToLongBits(expected), Double.doubleToLongBits(found));
}
public static void assertEqualsDouble(double expected, double found) {
assertEqualsDouble(format("expected %s found %s", expected, found), expected, found);
}
public static int processorCount() {
return Runtime.getRuntime().availableProcessors();
}
public static void assertEra(RetryLatch latch, long era) {
assertEquals(era, latch.getEra());
}
public static void assertOpen(RetryLatch latch) {
assertTrue(latch.isOpen());
}
public static void assertClosed(RetryLatch latch) {
assertFalse(latch.isOpen());
}
public static void assertEqualByteArray(byte[] array1, byte[] array2) {
if (array1 == array2) {
return;
}
if (array1 == null) {
fail();
}
int length = array1.length;
assertEquals(length, array2.length);
for (int k = 0; k < array1.length; k++) {
assertEquals(array1[k], array2[k]);
}
}
public static Object getField(Object o, String fieldname) {
if (o == null || fieldname == null) {
throw new NullPointerException();
}
try {
Field field = findField(o.getClass(), fieldname);
if (field == null) {
fail(format("field '%s' is not found on class '%s' or on one of its super classes", fieldname, o.getClass().getName()));
}
field.setAccessible(true);
return field.get(o);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
public static Field findField(Class clazz, String fieldname) {
try {
return clazz.getDeclaredField(fieldname);
} catch (NoSuchFieldException e) {
if (clazz.equals(Object.class)) {
return null;
}
return findField(clazz.getSuperclass(), fieldname);
}
}
public static void assertNotEquals(long l1, long l2) {
assertFalse(format("both values are %s, but should not be equal", l2), l1 == l2);
}
public static void assertIsPrepared(Txn... txns) {
for (Txn tx : txns) {
assertEquals(TxnStatus.Prepared, tx.getStatus());
}
}
public static void assertIsAborted(Txn... txns) {
for (Txn tx : txns) {
assertEquals(TxnStatus.Aborted, tx.getStatus());
}
}
public static void assertIsCommitted(Txn... txns) {
for (Txn tx : txns) {
assertEquals(TxnStatus.Committed, tx.getStatus());
}
}
public static void assertIsActive(Txn... txns) {
for (Txn tx : txns) {
assertEquals(TxnStatus.Active, tx.getStatus());
}
}
public static int randomInt(int max) {
if (max <= 0) {
return 0;
}
return ThreadLocalRandom.current().nextInt(max);
}
public static void sleepRandomMs(int maxMs) {
Bugshaker.sleepUs((long) randomInt((int) TimeUnit.MILLISECONDS.toMicros(maxMs)));
}
public static void sleepMs(long ms) {
long us = TimeUnit.MILLISECONDS.toMicros(ms);
Bugshaker.sleepUs(us);
}
public static boolean randomBoolean() {
return randomInt(10) % 2 == 0;
}
public static boolean randomOneOf(int chance) {
return randomInt(Integer.MAX_VALUE) % chance == 0;
}
public static long getStressTestDurationMs(long defaultDuration) {
String value = System.getProperty("org.multiverse.integrationtest.durationMs", String.valueOf(defaultDuration));
return Long.parseLong(value);
}
public static void assertIsInterrupted(Thread t) {
assertTrue(t.isInterrupted());
}
public static void assertAlive(Thread... threads) {
for (Thread thread : threads) {
assertTrue(thread.getName(), thread.isAlive());
}
}
public static boolean isAlive(Thread... threads) {
for (Thread thread : threads) {
if (!thread.isAlive()) {
return false;
}
}
return true;
}
public static void assertNothingThrown(TestThread... threads){
for(TestThread t: threads){
Throwable throwable = t.getThrowable();
if(throwable != null){
fail(String.format("TestThread [%s] failed with the following exception\n%s",
t.getName(),getStackTrace(throwable)));
}
}
}
public static String getStackTrace(Throwable aThrowable) {
final Writer result = new StringWriter();
final PrintWriter printWriter = new PrintWriter(result);
aThrowable.printStackTrace(printWriter);
return result.toString();
}
public static void assertNotAlive(Thread... threads) {
for (Thread thread : threads) {
assertFalse(thread.isAlive());
}
}
public static void assertEventuallyNotAlive(Thread... threads){
assertEventuallyNotAlive(60 * 1000, threads);
}
public static void assertEventuallyNotAlive(long timeoutMs, Thread... threads) {
for (Thread thread : threads) {
if(timeoutMs <=0){
fail("There is no remaining timeout");
}
long startMs = System.currentTimeMillis();
try {
thread.join(timeoutMs);
} catch (InterruptedException e) {
fail("Failed to join thread: " + thread.getName());
}
long elapsed = System.currentTimeMillis() - startMs;
if (thread.isAlive()) {
fail(format("Thread [%s] is still alive after a timeout of [%s] ms", thread, timeoutMs));
}
assertFalse(thread.isAlive());
timeoutMs -= elapsed;
}
}
public static void startAll(TestThread... threads) {
for (Thread thread : threads) {
thread.start();
}
}
public static void assertEventuallyFalse(Callable<Boolean> f) {
assertEventually(f, false);
}
public static void assertEventuallyFalse(Callable<Boolean> f, long timeoutMs) {
assertEventually(f, false, timeoutMs);
}
public static void assertEventually(Callable<Boolean> f, boolean value) {
assertEventually(f, value, 60 * 1000);
}
public static void assertEventuallyTrue(Callable<Boolean> f) {
assertEventually(f, true);
}
public static void assertEventuallyTrue(Callable<Boolean> f, long timeoutMs) {
assertEventually(f, true, timeoutMs);
}
public static void assertEventually(Callable<Boolean> f, boolean value, long timeoutMs) {
long endTime = System.currentTimeMillis() + timeoutMs;
for (; ; ) {
try {
if (f.call() == value) {
return;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
sleepMs(100);
if (endTime > System.currentTimeMillis()) {
fail("Failed to become true in the given timeout");
}
}
}
public static void sleepRandomUs(int maxUs) {
Bugshaker.sleepUs((long) randomInt(maxUs));
}
public static void assertInstanceof(Class expected, Object o) {
assertTrue(o.getClass().getName(), expected.isAssignableFrom(o.getClass()));
}
/**
* Joins all threads. If this can't be done within 5 minutes, an assertion failure is thrown.
*
* @param threads the threads to join.
* @return the total duration of all threads (so the sum of the time each thread has been running.
* @see #joinAll(long, TestThread...) for more specifics.
*/
public static long joinAll(TestThread... threads) {
return joinAll(5 * 60, threads);
}
/**
* Joins all threads. If one of the thread throws a throwable, the join will fail as well.
*
* @param timeoutSec the timeout in seconds. If the join doesn't complete within that time, the
* join fails.
* @param threads the threads to join.
* @return the total duration of all threads (so the sum of the time each thread has been running.
*/
@SuppressWarnings({"ThrowableResultOfMethodCallIgnored"})
public static long joinAll(long timeoutSec, TestThread... threads) {
if (timeoutSec < 0) {
throw new IllegalArgumentException();
}
List<TestThread> uncompleted = new LinkedList<TestThread>(Arrays.asList(threads));
long maxTimeMs = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(timeoutSec);
long durationMs = 0;
while (!uncompleted.isEmpty()) {
for (Iterator<TestThread> it = uncompleted.iterator(); it.hasNext(); ) {
TestThread thread = it.next();
try {
if (System.currentTimeMillis() > maxTimeMs) {
fail(String.format(
"Failed to join all threads in %s seconds, remaining threads %s.\n%s",
timeoutSec, uncompleted, getStacks(uncompleted)));
}
thread.join(100);
if (!thread.isAlive()) {
it.remove();
durationMs += thread.getDurationMs();
if (thread.getThrowable() == null) {
System.out.printf("Multiverse > %s completed successfully\n", thread.getName());
} else {
System.out.printf("Multiverse > %s encountered the following error\n", thread.getName());
thread.getThrowable().printStackTrace();
fail(String.format("Multiverse > %s completed with failure", thread.getName()));
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(String.format("Joining %s was interrupted", thread), e);
}
}
}
return durationMs;
}
private static String getStacks(List<TestThread> uncompleted) {
StringBuffer sb = new StringBuffer();
sb.append("Uncompleted threads:\n");
for (TestThread thread : uncompleted) {
sb.append("-------------------------------------------------------------------\n");
sb.append(thread.getName() + "\n");
for (StackTraceElement element : thread.getStackTrace()) {
sb.append("\tat " + element + "\n");
}
}
sb.append("-------------------------------------------------------------------\n");
return sb.toString();
}
}