package restx.tests;
import com.google.common.base.Optional;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import restx.factory.Factory;
import restx.specs.*;
import java.io.IOException;
import java.util.*;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Iterables.transform;
import static com.google.common.collect.Lists.newArrayList;
import static restx.RestxMainRouterFactory.Blade;
import static restx.factory.Factory.LocalMachines.contextLocal;
import static restx.factory.Factory.LocalMachines.overrideComponents;
/**
 * Runs {@link RestxSpec} specifications against a server identified by a server id and a base URL.
 *
 * <p>For each spec, every {@link Given} is handed to the matching {@link GivenRunner}s, every
 * {@link When} is verified by a matching {@link WhenChecker}, and the {@link GivenCleaner}s returned
 * by the runners are invoked once the spec has finished (successfully or not).
 *
 * <p>{@link GivenSpecRule}s are instantiated and set up lazily, once per blade (as reported by
 * {@code Blade.current()}), and cached until {@link #dispose()} is called. Access to that cache is
 * guarded by synchronizing on this instance.
 *
 * <p>User: xavierhanin
 * Date: 3/17/13
 * Time: 1:58 PM
 */
public class RestxSpecRunner {
    private final String serverId;
    private final String baseUrl;
    private final String routerPath;
    private final RestxSpecLoader specLoader;
    private final Iterable<GivenSpecRuleSupplier> givenSpecRuleSuppliers;
    private final Iterable<GivenRunner> givenRunners;
    private final Iterable<WhenChecker> whenCheckers;

    /** Rules already set up, keyed by the blade id they were set up for. Guarded by {@code this}. */
    private final Map<String, Iterable<GivenSpecRule>> givenRulesPerBlade = new LinkedHashMap<>();

    /**
     * Creates a runner.
     *
     * @param specLoader loader used by {@link #loadSpec(String)}; must not be null
     * @param routerPath path appended to {@code baseUrl} to build the base URL given to specs; must not be null
     * @param serverId   id of the server, passed to specs as the context name; must not be null
     * @param baseUrl    base URL of the server; must not be null
     * @param factory    factory queried for the {@link GivenSpecRuleSupplier}, {@link GivenRunner}
     *                   and {@link WhenChecker} components; must not be null
     * @throws NullPointerException if any argument is null
     */
    public RestxSpecRunner(RestxSpecLoader specLoader, String routerPath, String serverId, String baseUrl, Factory factory) {
        this.specLoader = checkNotNull(specLoader);
        this.routerPath = checkNotNull(routerPath);
        this.serverId = checkNotNull(serverId);
        this.baseUrl = checkNotNull(baseUrl);
        // Fail fast with the same exception type a null factory would have produced on first use.
        checkNotNull(factory);

        givenSpecRuleSuppliers = factory.queryByClass(GivenSpecRuleSupplier.class).findAsComponents();
        givenRunners = factory.queryByClass(GivenRunner.class).findAsComponents();
        whenCheckers = factory.queryByClass(WhenChecker.class).findAsComponents();
    }

    /**
     * Loads the spec designated by {@code spec} and runs it.
     *
     * @param spec spec identifier, as understood by the {@link RestxSpecLoader}
     * @throws IOException if the spec cannot be loaded
     */
    public void runTest(String spec) throws IOException {
        runTest(loadSpec(spec));
    }

    /**
     * Loads a spec through the configured {@link RestxSpecLoader}.
     *
     * @param spec spec identifier, as understood by the loader
     * @return the loaded spec
     * @throws IOException if the spec cannot be loaded
     */
    public RestxSpec loadSpec(String spec) throws IOException {
        return specLoader.load(spec);
    }

    /**
     * Runs the given spec.
     *
     * <p>The run parameters are the union of the parameters contributed by the spec rules of the
     * current blade, plus the context name and base URL, which are added last and therefore
     * take precedence over rule-provided values using the same keys.
     *
     * @param restxSpec the spec to run
     */
    public void runTest(RestxSpec restxSpec) {
        Map<String, String> params = Maps.newLinkedHashMap();
        for (GivenSpecRule givenSpecRule : getSpecRulesForCurrentBlade()) {
            params.putAll(givenSpecRule.getRunParams());
        }
        params.put(WhenHttpRequest.CONTEXT_NAME, serverId);
        params.put(WhenHttpRequest.BASE_URL, baseUrl + routerPath);

        runSpec(restxSpec, ImmutableMap.copyOf(params));
    }

    /**
     * Returns the spec rules of the current blade, creating them and calling
     * {@code onSetup} on each the first time they are requested for that blade.
     */
    private synchronized Iterable<GivenSpecRule> getSpecRulesForCurrentBlade() {
        String currentBlade = Blade.current();
        Iterable<GivenSpecRule> givenSpecRules = givenRulesPerBlade.get(currentBlade);
        if (givenSpecRules == null) {
            // Materialize into a list so that each supplier is invoked exactly once.
            givenSpecRules = newArrayList(transform(givenSpecRuleSuppliers,
                    Suppliers.<GivenSpecRule>supplierFunction()));
            for (GivenSpecRule givenSpecRule : givenSpecRules) {
                givenSpecRule.onSetup(contextLocal(bladeContextId(currentBlade)));
            }
            givenRulesPerBlade.put(currentBlade, givenSpecRules);
        }
        return givenSpecRules;
    }

    /**
     * Tears down every spec rule that has been set up by this runner, for all blades.
     *
     * <p>All rules are torn down even if some of them fail; the first failure is rethrown once
     * every rule has been processed, with later failures attached as suppressed exceptions.
     * The rule cache is emptied, so calling this method again is a no-op and a subsequent
     * {@code runTest} sets up fresh rules.
     */
    public synchronized void dispose() {
        RuntimeException firstFailure = null;
        for (Map.Entry<String, Iterable<GivenSpecRule>> blade : givenRulesPerBlade.entrySet()) {
            for (GivenSpecRule givenSpecRule : blade.getValue()) {
                try {
                    givenSpecRule.onTearDown(contextLocal(bladeContextId(blade.getKey())));
                } catch (RuntimeException e) {
                    firstFailure = recordFailure(firstFailure, e);
                }
            }
        }
        // Forget torn-down rules so they are neither torn down twice nor reused without setup.
        givenRulesPerBlade.clear();

        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Executes the givens and whens of a spec, then runs the collected cleaners and clears the
     * overridden components, whatever the outcome.
     *
     * @throws IllegalStateException if a given has no runner or a when has no checker
     */
    private void runSpec(RestxSpec restxSpec, ImmutableMap<String, String> params) {
        List<GivenCleaner> givenCleaners = newArrayList();
        try {
            for (Given given : restxSpec.getGiven()) {
                Set<GivenRunner<Given>> runnersFor = findRunnersFor(given);
                if (runnersFor.isEmpty()) {
                    throw new IllegalStateException(
                            "no runner found for given " + given + ". double check your classpath and factory settings.");
                }
                for (GivenRunner<Given> runner : runnersFor) {
                    givenCleaners.add(runner.run(given, params));
                }
            }

            for (When when : restxSpec.getWhens()) {
                Optional<WhenChecker<When>> checkerFor = findCheckerFor(when);
                if (!checkerFor.isPresent()) {
                    throw new IllegalStateException("no checker found for when " + when + "." +
                            " double check your classpath and factory settings.");
                }
                checkerFor.get().check(when, params);
            }
        } finally {
            try {
                cleanUpAll(givenCleaners);
            } finally {
                // Must happen even when a cleaner fails, otherwise overrides leak into later specs.
                overrideComponents().clear();
            }
        }
    }

    /**
     * Runs every cleaner, in order, even if some of them fail. The first failure is rethrown
     * after all cleaners have run, with later failures attached as suppressed exceptions.
     */
    private static void cleanUpAll(List<GivenCleaner> givenCleaners) {
        RuntimeException firstFailure = null;
        for (GivenCleaner givenCleaner : givenCleaners) {
            try {
                givenCleaner.cleanUp();
            } catch (RuntimeException e) {
                firstFailure = recordFailure(firstFailure, e);
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Accumulates failures: returns {@code failure} if it is the first one, otherwise attaches it
     * as suppressed to {@code firstFailure} and returns {@code firstFailure}.
     */
    private static RuntimeException recordFailure(RuntimeException firstFailure, RuntimeException failure) {
        if (firstFailure == null) {
            return failure;
        }
        // addSuppressed rejects self-suppression, so skip a re-thrown identical instance.
        if (failure != firstFailure) {
            firstFailure.addSuppressed(failure);
        }
        return firstFailure;
    }

    /**
     * Finds the checker for a when: the when itself if it implements {@link WhenChecker},
     * otherwise the first registered checker whose when class is assignable from the when's class.
     *
     * @return the checker, or absent if none matches
     */
    @SuppressWarnings("unchecked")
    private <T extends When> Optional<WhenChecker<T>> findCheckerFor(T when) {
        if (when instanceof WhenChecker) {
            return Optional.of((WhenChecker<T>) when);
        }

        for (WhenChecker<?> whenChecker : whenCheckers) {
            if (whenChecker.getWhenClass().isAssignableFrom(when.getClass())) {
                return Optional.of((WhenChecker<T>) whenChecker);
            }
        }
        return Optional.absent();
    }

    /**
     * Finds the runners for a given: the given itself if it implements {@link GivenRunner},
     * otherwise every registered runner whose given class is assignable from the given's class.
     *
     * @return the compatible runners, possibly empty
     */
    @SuppressWarnings("unchecked")
    private <T extends Given> Set<GivenRunner<T>> findRunnersFor(T given) {
        if (given instanceof GivenRunner) {
            return Sets.newHashSet((GivenRunner<T>) given);
        }

        Set<GivenRunner<T>> compatibleRunners = new HashSet<>();
        for (GivenRunner<?> givenRunner : givenRunners) {
            if (givenRunner.getGivenClass().isAssignableFrom(given.getClass())) {
                compatibleRunners.add((GivenRunner<T>) givenRunner);
            }
        }
        return compatibleRunners;
    }

    /**
     * Returns the blade local machines obtained from {@code Blade.bladeLocalMachines} for this
     * runner's server id.
     */
    protected Factory.LocalMachines bladeLocalMachines() {
        return Blade.bladeLocalMachines(serverId);
    }

    /** Returns the context id of the current blade for this runner's server id. */
    private String bladeContextId() {
        return bladeContextId(Blade.current());
    }

    /** Returns the context id of the given blade for this runner's server id. */
    private String bladeContextId(String bladeId) {
        return Blade.contextId(serverId, bladeId);
    }
}