package org.radargun.stages.test; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.radargun.DistStageAck; import org.radargun.Operation; import org.radargun.StageResult; import org.radargun.Version; import org.radargun.config.Init; import org.radargun.config.Property; import org.radargun.config.Stage; import org.radargun.reporting.Report; import org.radargun.state.SlaveState; import org.radargun.stats.Statistics; import org.radargun.traits.InjectTrait; import org.radargun.traits.Transactional; import org.radargun.utils.TimeConverter; import org.radargun.utils.TimeService; import org.radargun.utils.Utils; /** * @author Radim Vansa <rvansa@redhat.com> */ @Stage(doc = "Base for test spawning several threads and benchmark of operations executed in those.") public abstract class TestStage extends BaseTestStage { public static final String NAMESPACE = "urn:radargun:stages:cache:" + Version.SCHEMA_VERSION; public static final String DEPRECATED_NAMESPACE = "urn:radargun:stages:legacy:" + Version.SCHEMA_VERSION; @Property(doc = "The number of threads executing on each node. You have to set either this or 'total-threads'. No default.") public int numThreadsPerNode = 0; @Property(doc = "Total number of threads across whole cluster. You have to set either this or 'num-threads-per-node'. No default.") public int totalThreads = 0; @Property(doc = "Specifies if the requests should be explicitly wrapped in transactions. " + "Options are NEVER, ALWAYS and IF_TRANSACTIONAL: transactions are used only if " + "the cache configuration is transactional and transactionSize > 0. Default is IF_TRANSACTIONAL.") public TransactionMode useTransactions = TransactionMode.IF_TRANSACTIONAL; @Property(doc = "Specifies whether the transactions should be committed (true) or rolled back (false). " + "Default is true") public boolean commitTransactions = true; @Property(doc = "Number of requests in one transaction. Default is 1.") public int transactionSize = 1; @Property(doc = "Local threads synchronize on starting each round of requests. Note that with requestPeriod > 0, " + "there is still the random ramp-up delay. Default is false.") public boolean synchronousRequests = false; @Property(doc = "Max duration of the test. Default is infinite.", converter = TimeConverter.class) public long timeout = 0; @Property(doc = "Delay to let all threads start executing operations. Default is 0.", converter = TimeConverter.class) public long rampUp = 0; @Property(converter = TimeConverter.class, doc = "Time between consecutive requests of one stressor thread. Default is 0.") protected long delayBetweenRequests = 0; @Property(doc = "Whether an error from transaction commit/rollback should be logged as error. Default is true.") public boolean logTransactionExceptions = true; @InjectTrait protected Transactional transactional; private Completion completion; private OperationSelector operationSelector; protected volatile boolean started = false; protected volatile boolean finished = false; protected volatile boolean terminated = false; protected StressorsManager stressorsManager; public StressorsManager getStressorsManager() { return stressorsManager; } @Init public void init() { if (totalThreads <= 0 && numThreadsPerNode <= 0) throw new IllegalStateException("You have to set either total-threads or num-threads-per-node."); if (totalThreads > 0 && numThreadsPerNode > 0) throw new IllegalStateException("You have to set only one ot total-threads, num-threads-per-node"); if (totalThreads < 0 || numThreadsPerNode < 0) throw new IllegalStateException("Number of threads can't be < 0"); } public DistStageAck executeOnSlave() { if (!isServiceRunning()) { log.info("Not running test on this slave as service is not running."); return successfulResponse(); } prepare(); try { long startNanos = TimeService.nanoTime(); log.info("Starting test " + testName); stressorsManager = setUpAndStartStressors(); waitForStressorsToFinish(stressorsManager); destroy(); log.info("Finished test. Test duration is: " + Utils.getNanosDurationString(TimeService.nanoTime() - startNanos)); return newStatisticsAck(stressorsManager.getStressors()); } catch (Exception e) { return errorResponse("Exception while initializing the test", e); } } /** * To be overridden in inheritors. */ protected void prepare() { } /** * To be overridden in inheritors. */ protected void destroy() { } public StageResult processAckOnMaster(List<DistStageAck> acks) { return processAckOnMaster(acks, testName); } protected StageResult processAckOnMaster(List<DistStageAck> acks, String testNameOverride) { StageResult result = super.processAckOnMaster(acks); if (result.isError()) return result; Report.Test test = getTest(amendTest, testNameOverride); testIteration = test == null ? 0 : test.getIterations().size(); // we cannot use aggregated = createStatistics() since with PeriodicStatistics the merge would fail List<StatisticsAck> statisticsAcks = instancesOf(acks, StatisticsAck.class); Statistics aggregated = statisticsAcks.stream().flatMap(ack -> ack.statistics.stream()).reduce(null, Statistics.MERGE); for (StatisticsAck ack : statisticsAcks) { if (ack.statistics != null) { if (test != null) { int testIteration = getTestIteration(); String iterationValue = resolveIterationValue(); if (iterationValue != null) { test.setIterationValue(testIteration, iterationValue); } if (test.getGroupOperationsMap() == null) { test.setGroupOperationsMap(ack.getGroupOperationsMap()); } test.addStatistics(testIteration, ack.getSlaveIndex(), ack.statistics); } } else { log.trace("No statistics received from slave: " + ack.getSlaveIndex()); } } if (checkRepeatCondition(aggregated)) { return StageResult.SUCCESS; } else { return StageResult.BREAK; } } protected StressorsManager setUpAndStartStressors() { long startTime = TimeService.currentTimeMillis(); completion = createCompletion(); CountDownLatch finishCountDown = new CountDownLatch(1); completion.setCompletionHandler(new Runnable() { @Override public void run() { //Stop collecting statistics for duration-based tests if (duration > 0) { finished = true; } finishCountDown.countDown(); } }); operationSelector = wrapOperationSelector(createOperationSelector()); List<Stressor> stressors = startStressors(); started = true; if (rampUp > 0) { try { Thread.sleep(rampUp); } catch (InterruptedException e) { throw new IllegalStateException("Interrupted during ramp-up.", e); } } return new StressorsManager(stressors, startTime, finishCountDown); } protected void waitForStressorsToFinish(StressorsManager manager) { try { if (timeout > 0) { long waitTime = getWaitTime(manager.getStartTime()); if (waitTime <= 0) { throw new TestTimeoutException(); } else { if (!manager.getFinishCountDown().await(waitTime, TimeUnit.MILLISECONDS)) { throw new TestTimeoutException(); } } } else { manager.getFinishCountDown().await(); } } catch (InterruptedException e) { throw new IllegalStateException("Unexpected interruption", e); } for (Thread stressorThread : manager.getStressors()) { try { if (timeout > 0) { long waitTime = getWaitTime(manager.getStartTime()); if (waitTime <= 0) throw new TestTimeoutException(); stressorThread.join(waitTime); } else { stressorThread.join(); } } catch (InterruptedException e) { throw new TestTimeoutException(e); } } } protected Completion createCompletion() { if (numOperations > 0) { long countPerNode = numOperations / getExecutingSlaves().size(); long modCountPerNode = numOperations % getExecutingSlaves().size(); if (getExecutingSlaveIndex() + 1 <= modCountPerNode) { countPerNode++; } return new CountStressorCompletion(countPerNode); } else { return new TimeStressorCompletion(duration); } } protected OperationSelector createOperationSelector() { return OperationSelector.DUMMY; } protected OperationSelector wrapOperationSelector(OperationSelector operationSelector) { if (synchronousRequests) { operationSelector = new SynchronousOperationSelector(operationSelector); } return operationSelector; } protected List<Stressor> startStressors() { int myFirstThread = getFirstThreadOn(slaveState.getSlaveIndex()); int myNumThreads = getNumThreadsOn(slaveState.getSlaveIndex()); CountDownLatch threadCountDown = new CountDownLatch(myNumThreads); List<Stressor> stressors = new ArrayList<>(); for (int threadIndex = stressors.size(); threadIndex < myNumThreads; threadIndex++) { Stressor stressor = new Stressor(this, getLogic(), myFirstThread + threadIndex, threadIndex, logTransactionExceptions, threadCountDown, delayBetweenRequests); stressors.add(stressor); stressor.start(); } try { threadCountDown.await(); } catch (InterruptedException e) { //FIXME implement me } log.info("Started " + stressors.size() + " stressor threads."); return stressors; } protected DistStageAck newStatisticsAck(List<Stressor> stressors) { List<Statistics> results = gatherResults(stressors, new StatisticsResultRetriever()); return new StatisticsAck(slaveState, results, statisticsPrototype.getGroupOperationsMap()); } protected <T> List<T> gatherResults(List<Stressor> stressors, ResultRetriever<T> retriever) { if (mergeThreadStats) { return stressors.stream() .map(retriever::getResult) .reduce(retriever::merge) .map(Collections::singletonList).orElse(Collections.emptyList()); } else { return stressors.stream() .map(retriever::getResult) .filter(r -> r != null) .collect(Collectors.toList()); } } protected long getWaitTime(long startTime) { return startTime + timeout - TimeService.currentTimeMillis(); } public int getTotalThreads() { if (totalThreads > 0) { return totalThreads; } else if (numThreadsPerNode > 0) { return getExecutingSlaves().size() * numThreadsPerNode; } else throw new IllegalStateException(); } public int getFirstThreadOn(int slave) { List<Integer> executingSlaves = getExecutingSlaves(); int execId = executingSlaves.indexOf(slave); if (numThreadsPerNode > 0) { return execId * numThreadsPerNode; } else if (totalThreads > 0) { return execId * totalThreads / executingSlaves.size(); } else { throw new IllegalStateException(); } } public int getNumThreadsOn(int slave) { List<Integer> executingSlaves = getExecutingSlaves(); if (numThreadsPerNode > 0) { return executingSlaves.contains(slaveState.getSlaveIndex()) ? numThreadsPerNode : 0; } else if (totalThreads > 0) { int execId = executingSlaves.indexOf(slave); return (execId + 1) * totalThreads / executingSlaves.size() - execId * totalThreads / executingSlaves.size(); } else { throw new IllegalStateException(); } } protected Statistics createStatistics() { return statisticsPrototype.copy(); } public boolean isStarted() { return started; } public boolean isFinished() { return finished; } public boolean isTerminated() { return terminated; } public void setTerminated() { terminated = true; stressorsManager.getFinishCountDown().countDown(); } public Completion getCompletion() { return completion; } public OperationSelector getOperationSelector() { return operationSelector; } public boolean useTransactions(String resourceName) { return useTransactions.use(transactional, resourceName, transactionSize); } public abstract OperationLogic getLogic(); public boolean isSingleTxType() { return transactionSize == 1; } protected interface ResultRetriever<T> { T getResult(Stressor stressor); T merge(T stats1, T stats2); } protected static class StatisticsResultRetriever implements ResultRetriever<Statistics> { public StatisticsResultRetriever() {} @Override public Statistics getResult(Stressor stressor) { return stressor.getStats(); } @Override public Statistics merge(Statistics stats1, Statistics stats2) { return Statistics.MERGE.apply(stats1, stats2); } } protected class TestTimeoutException extends RuntimeException { public TestTimeoutException() { } public TestTimeoutException(Throwable cause) { super(cause); } } protected static class StatisticsAck extends DistStageAck { public final List<Statistics> statistics; private final Map<String, Set<Operation>> groupOperationsMap; public StatisticsAck(SlaveState slaveState, List<Statistics> statistics, Map<String, Set<Operation>> groupOperationsMap) { super(slaveState); this.statistics = statistics; this.groupOperationsMap = groupOperationsMap; } public Map<String, Set<Operation>> getGroupOperationsMap() { return groupOperationsMap; } } }