package org.corfudb.infrastructure;
import lombok.Getter;
import org.assertj.core.api.Assertions;
import org.corfudb.AbstractCorfuTest;
import org.corfudb.protocols.wireprotocol.CorfuMsg;
import org.corfudb.protocols.wireprotocol.CorfuPayloadMsg;
import org.corfudb.runtime.CorfuRuntime;
import org.corfudb.runtime.clients.*;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Base class for server tests that drive an {@link AbstractServer} through an in-memory
 * {@link TestServerRouter} rather than a real transport.
 *
 * <p>Subclasses supply the server under test via {@link #getDefaultServer()}; it is
 * (re)registered with a freshly reset router before every test.
 *
 * <p>Created by mwei on 12/12/15.
 */
public abstract class AbstractServerTest extends AbstractCorfuTest {

    /** Client id stamped on messages sent through {@link #sendMessage(CorfuMsg)}. */
    public static final UUID testClientId =
            UUID.nameUUIDFromBytes("TEST_CLIENT".getBytes(StandardCharsets.UTF_8));

    /** In-memory router that delivers test messages to the registered server(s). */
    @Getter
    TestServerRouter router;

    /** Source of request ids for outgoing messages; reset to 0 before every test. */
    AtomicInteger requestCounter;

    /**
     * A map of maps to endpoint->routers, mapped for each runtime instance captured.
     */
    final Map<CorfuRuntime, Map<String, TestClientRouter>> runtimeRouterMap =
            new ConcurrentHashMap<>();

    public AbstractServerTest() {
        router = new TestServerRouter();
        requestCounter = new AtomicInteger();
        // Force all new CorfuRuntimes to override the getRouterFn, so routers are
        // obtained from getRouterFunction below (which wires them to this test's router).
        CorfuRuntime.overrideGetRouterFunction = this::getRouterFunction;
    }

    /**
     * Resets the router and registers {@code server} as the only server under test.
     *
     * @param server the server to route test messages to
     */
    public void setServer(AbstractServer server) {
        router.reset();
        router.addServer(server);
    }

    /**
     * Returns the server that {@link #resetTest()} registers before every test.
     *
     * @return the default server under test
     */
    public abstract AbstractServer getDefaultServer();

    /**
     * Resets the router, registers the default server and restarts request ids at 0.
     */
    @Before
    public void resetTest() {
        router.reset();
        router.addServer(getDefaultServer());
        requestCounter.set(0);
    }

    /**
     * Returns the response messages collected by the router.
     *
     * @return the router's response messages
     */
    public List<CorfuMsg> getResponseMessages() {
        return router.getResponseMessages();
    }

    /**
     * Returns the most recent response message.
     *
     * @return the last response message, or {@code null} if none has been recorded
     */
    public CorfuMsg getLastMessage() {
        List<CorfuMsg> responses = router.getResponseMessages();
        if (responses.isEmpty()) {
            return null;
        }
        return responses.get(responses.size() - 1);
    }

    /**
     * Returns the most recent response message cast to {@code type}.
     *
     * @param type the expected message class
     * @param <T>  the expected message type
     * @return the last response message, or {@code null} if none has been recorded
     * @throws ClassCastException if the last message is not an instance of {@code type}
     */
    public <T extends CorfuMsg> T getLastMessageAs(Class<T> type) {
        // Checked cast: fails here, with a clear message, instead of at an arbitrary call site.
        return type.cast(getLastMessage());
    }

    /**
     * Asserts that the most recent response is a {@link CorfuPayloadMsg} and returns its payload.
     *
     * @param type the expected payload class (used only for type inference)
     * @param <T>  the expected payload type
     * @return the payload of the last response message
     */
    @SuppressWarnings("unchecked")
    public <T> T getLastPayloadMessageAs(Class<T> type) {
        // Read once so the assertion and the cast inspect the same message.
        CorfuMsg last = getLastMessage();
        Assertions.assertThat(last)
                .isInstanceOf(CorfuPayloadMsg.class);
        // The payload's generic type is erased; correctness relies on the caller's choice of T.
        return ((CorfuPayloadMsg<T>) last).getPayload();
    }

    /**
     * Sends {@code message} to the server under test using {@link #testClientId}.
     *
     * @param message the message to send
     */
    public void sendMessage(CorfuMsg message) {
        sendMessage(testClientId, message);
    }

    /**
     * Stamps {@code message} with the given client id and the next request id, then
     * hands it to the server router.
     *
     * @param clientId the client id to stamp on the message
     * @param message  the message to send
     */
    public void sendMessage(UUID clientId, CorfuMsg message) {
        message.setClientID(clientId);
        message.setRequestID(requestCounter.getAndIncrement());
        router.sendServerMessage(message);
    }

    /**
     * Function for obtaining a router, given a runtime and an endpoint.
     *
     * <p>Routers are cached per (runtime, endpoint) pair; the first request for a pair
     * creates a {@link TestClientRouter} on top of this test's server router, with the
     * base, sequencer, layout, log unit and management clients attached.
     *
     * @param runtime  The CorfuRuntime to obtain a router for.
     * @param endpoint An endpoint string for the router; must start with {@code "test:"}.
     * @return the cached (or newly created) router for {@code runtime} and {@code endpoint}
     * @throws IllegalArgumentException if {@code endpoint} does not start with {@code "test:"}
     */
    private IClientRouter getRouterFunction(CorfuRuntime runtime, String endpoint) {
        // Validate before touching the cache so rejected endpoints leave no trace.
        if (!endpoint.startsWith("test:")) {
            throw new IllegalArgumentException("Unsupported endpoint in test: " + endpoint);
        }
        return runtimeRouterMap
                .computeIfAbsent(runtime, r -> new ConcurrentHashMap<>())
                .computeIfAbsent(endpoint, e -> {
                    TestClientRouter clientRouter = new TestClientRouter(router);
                    clientRouter.addClient(new BaseClient())
                            .addClient(new SequencerClient())
                            .addClient(new LayoutClient())
                            .addClient(new LogUnitClient())
                            .addClient(new ManagementClient());
                    return clientRouter;
                });
    }
}