package org.radargun.util; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.stream.IntStream; import org.radargun.DistStageAck; import org.radargun.StageResult; import org.radargun.config.Cluster; import org.radargun.config.Configuration; import org.radargun.config.InitHelper; import org.radargun.config.MasterConfig; import org.radargun.reporting.Report; import org.radargun.reporting.Timeline; import org.radargun.stages.AbstractDistStage; import org.radargun.stages.AbstractMasterStage; import org.radargun.state.MasterState; import org.radargun.state.SlaveState; import org.radargun.traits.TraitHelper; /** * @author Matej Cimbora */ public class CoreStageRunner { private final Cluster cluster; private final List<PerSlaveConfiguration> perSlaveConfigurations = new ArrayList<>(2); private static final String DEFAULT_PLUGIN = "plugin"; private static final String DEFAULT_SERVICE = "service"; private static final String DEFAULT_HOST = "localhost"; private static final String DEFAULT_CONFIGURATION = "configuration"; private static final int DEFAULT_PORT = 2103; public CoreStageRunner(int clusterSize) { if (clusterSize <= 0) { throw new IllegalArgumentException("Cluster size needs to be greater than 0, was " + clusterSize); } Cluster cluster = new Cluster(); cluster.setSize(clusterSize); this.cluster = cluster; for (int i = 0; i < clusterSize; i++) { Map<Class<?>, Object> traitMap = getDefaultTraitMap(); SlaveState slaveState = new SlaveState(); slaveState.setSlaveIndex(i); slaveState.setCluster(cluster); slaveState.setPlugin(DEFAULT_PLUGIN); slaveState.setService(DEFAULT_SERVICE); slaveState.setTimeline(new Timeline(i)); slaveState.setTraits(traitMap); MasterState masterState = new MasterState(new MasterConfig(DEFAULT_PORT, DEFAULT_HOST)); masterState.setCluster(cluster); masterState.setReport(new Report(new Configuration(DEFAULT_CONFIGURATION), cluster)); perSlaveConfigurations.add(new PerSlaveConfiguration(traitMap, slaveState, masterState)); } } public DistStageAck executeOnSlave(AbstractDistStage stage) throws Exception { return executeOnSlave(stage, 0); } public DistStageAck executeOnSlave(AbstractDistStage stage, int slaveIndex) throws Exception { checkSlaveIndex(slaveIndex); PerSlaveConfiguration perSlaveConfiguration = perSlaveConfigurations.get(slaveIndex); TraitHelper.inject(stage, perSlaveConfiguration.traitMap); InitHelper.init(stage); stage.initOnSlave(perSlaveConfiguration.slaveState); // TODO move elsewhere? stage.initOnMaster(perSlaveConfiguration.masterState); return stage.executeOnSlave(); } public List<DistStageAck> executeOnSlave(AbstractDistStage[] stages, int[] slaveIndices) throws Exception { ExecutorService executor = Executors.newFixedThreadPool(stages.length); List<Callable<DistStageAck>> callables = new ArrayList<>(stages.length); IntStream.range(0, stages.length).forEach(i -> { callables.add(() -> executeOnSlave(stages[i], slaveIndices[i])); }); List<Future<DistStageAck>> futures = executor.invokeAll(callables); List<DistStageAck> acks = new ArrayList<>(stages.length); futures.stream().forEach(f -> { try { acks.add(f.get()); } catch (InterruptedException e) { throw new IllegalStateException(e); } catch (ExecutionException e) { throw new IllegalStateException(e); } }); return acks; } public StageResult processAckOnMaster(AbstractDistStage stage, List<DistStageAck> acks) { return stage.processAckOnMaster(acks); } public StageResult executeMasterStage(AbstractMasterStage stage) throws Exception { // TODO more initialization InitHelper.init(stage); MasterState masterState = new MasterState(new MasterConfig(DEFAULT_PORT, DEFAULT_HOST)); masterState.setCluster(cluster); masterState.setReport(new Report(new Configuration(DEFAULT_CONFIGURATION), cluster)); stage.init(masterState); return stage.execute(); } protected Map<Class<?>, Object> getDefaultTraitMap() { return CoreTraitRepository.getAllTraits(); } public <T> T getTraitImpl(Class<T> clazz) { return getTraitImpl(clazz, 0); } public <T> T getTraitImpl(Class<T> clazz, int slaveIndex) { checkSlaveIndex(slaveIndex); return (T) perSlaveConfigurations.get(slaveIndex).traitMap.get(clazz); } public void replaceTraitImpl(Class clazz, Object traitImpl) { replaceTraitImpl(clazz, 0); } public void replaceTraitImpl(Class clazz, Object traitImpl, int slaveIndex) { checkSlaveIndex(slaveIndex); if (!perSlaveConfigurations.get(slaveIndex).traitMap.containsKey(clazz)) { throw new IllegalArgumentException("Trait implementation for class " + clazz + " not found"); } perSlaveConfigurations.get(slaveIndex).traitMap.put(clazz, traitImpl); } public SlaveState getSlaveState() { return getSlaveState(0); } public SlaveState getSlaveState(int slaveIndex) { checkSlaveIndex(slaveIndex); return perSlaveConfigurations.get(slaveIndex).slaveState; } private static class PerSlaveConfiguration { private final Map<Class<?>, Object> traitMap; private final SlaveState slaveState; private final MasterState masterState; public PerSlaveConfiguration(Map<Class<?>, Object> traitMap, SlaveState slaveState, MasterState masterState) { this.traitMap = traitMap; this.slaveState = slaveState; this.masterState = masterState; } } private void checkSlaveIndex(int slaveIndex) { if (slaveIndex >= cluster.getSize()) { throw new IllegalArgumentException("Illegal slave index provided, expected value from range (0 - " + (cluster.getSize() - 1) + "), was " + slaveIndex); } } }