package com.spotify.heroic;
import static org.junit.Assert.assertEquals;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.spotify.heroic.cluster.ClusterManagerModule;
import com.spotify.heroic.cluster.NodeMetadataFactory;
import com.spotify.heroic.cluster.RpcProtocolModule;
import com.spotify.heroic.cluster.discovery.simple.StaticListDiscoveryModule;
import com.spotify.heroic.dagger.CoreComponent;
import com.spotify.heroic.profile.MemoryProfile;
import com.spotify.heroic.querylogging.QueryLogger;
import com.spotify.heroic.querylogging.QueryLoggerFactory;
import com.spotify.heroic.querylogging.QueryLoggingComponent;
import com.spotify.heroic.querylogging.QueryLoggingModule;
import com.spotify.heroic.rpc.grpc.GrpcRpcProtocolModule;
import com.spotify.heroic.rpc.jvm.JvmRpcContext;
import com.spotify.heroic.rpc.jvm.JvmRpcProtocolModule;
import eu.toolchain.async.AsyncFuture;
import eu.toolchain.async.TinyAsync;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.mockito.Mockito;
/**
 * Base class for integration tests that run against a small, local cluster of Heroic instances.
 * <p>
 * Before each test, one {@link HeroicCoreInstance} is configured and started for every URI
 * returned by {@link #instanceUris()}, the cluster managers are refreshed, and the number of
 * nodes visible from each instance is verified. After each test, every instance is shut down
 * and the shared executor is released.
 * <p>
 * Subclasses customize the cluster by overriding {@link #protocol()}, {@link #instanceUris()},
 * {@link #metadataFactories()} and {@link #prepareEnvironment()}.
 */
public abstract class AbstractLocalClusterIT {
    /** Executor handed to every instance and used to back {@link #async}. */
    protected final ExecutorService executor = Executors.newSingleThreadExecutor();

    /** Async framework used to build and combine the futures in this class. */
    protected final TinyAsync async = TinyAsync.builder().executor(executor).build();

    /** The configured instances; assigned in {@link #abstractSetup()}, null before that. */
    protected List<HeroicCoreInstance> instances;

    /** Mock query loggers keyed by component name. All access is guarded by its own monitor. */
    private final Map<String, QueryLogger> loggers = new HashMap<>();

    /**
     * Query logging module whose factory hands out one Mockito mock per component name.
     * <p>
     * The same module is given to every instance, so a given component name always maps to the
     * same mock, which tests can look up through {@link #getQueryLogger(String)}.
     */
    private final QueryLoggingModule mockQueryLoggingModule = early -> new QueryLoggingComponent() {
        @Override
        public QueryLoggerFactory queryLoggerFactory() {
            return new QueryLoggerFactory() {
                @Override
                public QueryLogger create(final String component) {
                    synchronized (loggers) {
                        // Create the mock lazily the first time a component asks for it.
                        return loggers.computeIfAbsent(component,
                            ignored -> Mockito.mock(QueryLogger.class));
                    }
                }
            };
        }
    };

    /**
     * The RPC protocol scheme used by the default {@link #instanceUris()}.
     * <p>
     * Only {@code "jvm"} and {@code "grpc"} URIs are supported when instances are configured.
     *
     * @return the URI scheme, {@code "jvm"} by default.
     */
    protected String protocol() {
        return "jvm";
    }

    /**
     * Override to configure more than one instance.
     * <p>
     * These will be available in the {@link #instances} field.
     *
     * @return a list of URIs.
     */
    protected List<URI> instanceUris() {
        final String scheme = protocol();
        return ImmutableList.of(URI.create(scheme + "://a"), URI.create(scheme + "://b"));
    }

    /**
     * Override the default metadata behavior.
     * <p>
     * If empty, default factories will be used.
     *
     * @return a list of metadata factories, that will be round-robined over.
     */
    protected List<NodeMetadataFactory> metadataFactories() {
        return ImmutableList.of();
    }

    /**
     * Prepare the environment before the test.
     * <p>
     * {@link #instances} have been configured before this is called and can be safely used.
     *
     * @return a future that indicates that the environment is ready.
     */
    protected AsyncFuture<Void> prepareEnvironment() {
        return async.resolved(null);
    }

    /**
     * Access to locally pre-configured query loggers which have been provided to all instances.
     *
     * @param component Component to get logger for
     * @return the mock QueryLogger created for the component, or empty if none has been created
     */
    protected Optional<QueryLogger> getQueryLogger(final String component) {
        synchronized (loggers) {
            return Optional.ofNullable(loggers.get(component));
        }
    }

    /**
     * Configure and start all instances, let them discover each other, verify the cluster size
     * and finally run {@link #prepareEnvironment()}.
     *
     * @throws Exception if any step fails or the whole sequence takes longer than 10 seconds.
     */
    @Before
    public final void abstractSetup() throws Exception {
        final JvmRpcContext context = new JvmRpcContext();
        final List<URI> uris = instanceUris();

        // Every instance is expected to report as many nodes as there are configured URIs.
        final List<Integer> expectedNumberOfNodes =
            uris.stream().map(u -> uris.size()).collect(Collectors.toList());

        final List<HeroicCoreInstance> instances = new ArrayList<>(uris.size());
        int instanceIndex = 0;

        for (final URI uri : uris) {
            instances.add(setupCore(uri, uris, context, instanceIndex++));
        }

        this.instances = instances;

        final AsyncFuture<Void> startup = async.collectAndDiscard(
            instances.stream().map(HeroicCoreInstance::start).collect(Collectors.toList()));

        // Refresh all cores, allowing them to discover each other.
        final AsyncFuture<Void> refresh = startup.lazyTransform(
            v -> setupStaticNodes().lazyTransform(ignore -> async.collectAndDiscard(instances
                .stream()
                .map(c -> c.inject(inj -> inj.clusterManager().refresh()))
                .collect(Collectors.toList()))));

        refresh.lazyTransform(v -> {
            // verify that the correct number of nodes are visible from all instances.
            final List<Integer> actualNumberOfNodes = instances
                .stream()
                .map(c -> c.inject(inj -> inj.clusterManager().getNodes().size()))
                .collect(Collectors.toList());

            assertEquals(expectedNumberOfNodes, actualNumberOfNodes);
            return prepareEnvironment();
        }).get(10, TimeUnit.SECONDS);
    }

    /**
     * Register every instance's listen address as a static node on every instance.
     * <p>
     * This is a no-op for the {@code "jvm"} protocol.
     *
     * @return a future that is resolved when all static nodes have been added.
     */
    private AsyncFuture<Void> setupStaticNodes() {
        // JVM protocol does not require static nodes since addresses are known beforehand.
        if ("jvm".equals(protocol())) {
            return async.resolved(null);
        }

        // Collect remote URIs from instances to get ahold of generated port.
        final List<AsyncFuture<String>> remotes = instances
            .stream()
            .map(c -> c.inject(CoreComponent::clusterManager))
            .map(c -> c.protocols().iterator().next().getListenURI())
            .collect(Collectors.toList());

        return async.collect(remotes).lazyTransform(uris -> {
            final List<AsyncFuture<Void>> addNodes = new ArrayList<>();

            for (final String stringUri : uris) {
                final URI remoteUri = URI.create(stringUri);

                // Keep the scheme and port of the remote, but always address it via loopback.
                // Built once per remote since it is the same for every instance.
                final URI uri =
                    URI.create(remoteUri.getScheme() + "://127.0.0.1:" + remoteUri.getPort());

                for (final HeroicCoreInstance instance : instances) {
                    addNodes.add(
                        instance.inject(CoreComponent::clusterManager).addStaticNode(uri));
                }
            }

            return async.collectAndDiscard(addNodes);
        });
    }

    /**
     * Shut down all instances and release the executor.
     * <p>
     * JUnit runs this even when {@link #abstractSetup()} failed, in which case
     * {@link #instances} may never have been assigned; that case is skipped instead of failing
     * with a secondary {@link NullPointerException}.
     *
     * @throws Exception if shutdown fails or takes longer than 10 seconds.
     */
    @After
    public final void abstractTeardown() throws Exception {
        try {
            final List<HeroicCoreInstance> running = this.instances;

            if (running != null) {
                async
                    .collectAndDiscard(running
                        .stream()
                        .map(HeroicCoreInstance::shutdown)
                        .collect(Collectors.toList()))
                    .get(10, TimeUnit.SECONDS);
            }
        } finally {
            // A new executor is created for every test instance; release its worker thread
            // even if shutting down the instances failed.
            executor.shutdown();
        }
    }

    /**
     * Same as {@link #setupCoreThrowing(URI, List, JvmRpcContext, int)}, but with any exception
     * propagated through {@link Throwables#propagate(Throwable)} so no checked exception is
     * declared.
     */
    private HeroicCoreInstance setupCore(
        final URI uri, final List<URI> uris, final JvmRpcContext context, final int index
    ) {
        try {
            return setupCoreThrowing(uri, uris, context, index);
        } catch (Exception e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Build a single, not yet started, instance for the given URI.
     *
     * @param uri URI of the instance; its scheme selects the RPC protocol and its host is used
     * as bind name (jvm) and shard tag.
     * @param uris all instance URIs, used as the discovery list for the jvm protocol.
     * @param context shared context for the jvm protocol.
     * @param index position of the instance, used to pick a metadata factory.
     * @return the configured instance.
     * @throws IllegalArgumentException if the URI scheme is neither {@code jvm} nor {@code grpc}.
     * @throws Exception if the instance could not be created.
     */
    private HeroicCoreInstance setupCoreThrowing(
        final URI uri, final List<URI> uris, final JvmRpcContext context, final int index
    ) throws Exception {
        final RpcProtocolModule protocol;
        final StaticListDiscoveryModule discovery;

        switch (uri.getScheme()) {
            case "jvm":
                protocol =
                    JvmRpcProtocolModule.builder().context(context).bindName(uri.getHost()).build();
                discovery = new StaticListDiscoveryModule(uris);
                break;
            case "grpc":
                // Nodes are registered afterwards through setupStaticNodes(), once the
                // listen URIs are known, so discovery starts out empty.
                protocol = GrpcRpcProtocolModule.builder().port(0).build();
                discovery = new StaticListDiscoveryModule(ImmutableList.of());
                break;
            default:
                throw new IllegalArgumentException("Unsupported URI: " + uri);
        }

        return HeroicCore
            .builder()
            .setupShellServer(false)
            .setupService(false)
            .oneshot(true)
            .executor(executor)
            .configFragment(HeroicConfig
                .builder()
                .cluster(buildClusterConfig(uri, protocol, discovery, index)))
            .configFragment(HeroicConfig.builder().queryLogging(mockQueryLoggingModule))
            .profile(new MemoryProfile())
            .modules(HeroicModules.ALL_MODULES)
            .build()
            .newInstance();
    }

    /**
     * Build the cluster configuration for a single instance.
     * <p>
     * The instance is tagged with {@code shard=<uri host>}. If {@link #metadataFactories()} is
     * non-empty, the factory at {@code index % size} is used; otherwise none is set.
     */
    private ClusterManagerModule.Builder buildClusterConfig(
        final URI uri, final RpcProtocolModule protocol, final StaticListDiscoveryModule discovery,
        int index
    ) {
        final ClusterManagerModule.Builder cluster = ClusterManagerModule
            .builder()
            .tags(ImmutableMap.of("shard", uri.getHost()))
            .protocols(ImmutableList.of(protocol))
            .discovery(discovery);

        final List<NodeMetadataFactory> factories = metadataFactories();

        if (!factories.isEmpty()) {
            cluster.metadataFactory(factories.get(index % factories.size()));
        }

        return cluster;
    }
}