package com.jfastnet.processors;
import com.jfastnet.AbstractTest;
import com.jfastnet.Config;
import com.jfastnet.MessageKey;
import com.jfastnet.MessageLog;
import com.jfastnet.idprovider.ReliableModeIdProvider;
import com.jfastnet.messages.Message;
import com.jfastnet.util.NullsafeHashMap;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
/** @author Klaus Pfeiffer - klaus@allpiper.com */
@Slf4j
public class StackedMessageProcessorTest extends AbstractTest {
private static final AtomicInteger receivedCounter = new AtomicInteger();
// private static ThreadLocal<AtomicInteger> stackableReceived;
// private static ThreadLocal<AtomicInteger> unstackableReceived;
private static Map<Integer, AtomicInteger> expectedId = new HashMap<>();
private static final AtomicInteger closeMsgReceived = new AtomicInteger();
private static final Map<Integer, List<Message>> receivedMessages = new HashMap<>();
private static Map<Integer, Set<Long>> stackableIds = new NullsafeHashMap<Integer, Set<Long>>() {
@Override protected Set<Long> newInstance() {return new HashSet<>();}
};
private static Map<Integer, Set<Long>> unstackableIds = new NullsafeHashMap<Integer, Set<Long>>() {
@Override protected Set<Long> newInstance() {return new HashSet<>();}
};
static boolean fail = false;
static class StackableMsg1 extends Message {
@Override
public boolean stackable() {
return true;
}
@Override
public void process(Object context) {
receivedCounter.incrementAndGet();
expectedId.get(getConfig().senderId).incrementAndGet();
writeLog("Stackable", this);
// int stackableReceivedCounter = stackableReceived.get().incrementAndGet();
// log.info("########### STACKABLE ### ClientID: {} ### MsgID: {} ### Number: {} ### ThreadID: {}",
// new Object[]{getConfig().senderId, getMsgId(), stackableReceivedCounter, Thread.currentThread().getId()});
addReceived(this);
printMsg(this);
long expectedIdValue = expectedId.get(getConfig().senderId).get();
if (getMsgId() != expectedIdValue) {
log.error("Wrong id found! Expected: {}, Actual: {}", new Object[]{expectedIdValue, getMsgId()});
fail = true;
}
// if (stackableReceivedCounter <= unstackableReceivedCounter) {
// log.error("Stackable must have a greater id! stackableReceived: {}, unstackableReceived: {}", stackableReceivedCounter, unstackableReceivedCounter);
// fail = true;
// }
if (getConfig() != null && stackableIds.containsKey(getConfig().senderId)) {
if (stackableIds.get(getConfig().senderId).contains(getMsgId())) {
log.error("Stackables already contained. senderId: {}, msgId: {}", getConfig().senderId, getMsgId());
fail = true;
}
}
}
}
static class StackableMsg2 extends StackableMsg1 {
@Override
public void process(Object context) {
addReceived(this);
closeMsgReceived.incrementAndGet();
log.info("Close msg #" + closeMsgReceived.get());
}
}
static class UnStackableMsg1 extends Message {
@Override
public void process(Object context) {
receivedCounter.incrementAndGet();
expectedId.get(getConfig().senderId).incrementAndGet();
writeLog("Unstackable", this);
// int unstackableReceivedCounter = unstackableReceived.get().incrementAndGet();
// log.info("########### UNSTACKABLE ### ClientID: {} ### MsgID: {} ### Number: {} ### ThreadID: {}",
// new Object[]{getConfig().senderId, getMsgId(), unstackableReceivedCounter, Thread.currentThread().getId()});
addReceived(this);
printMsg(this);
if (getConfig() != null && unstackableIds.containsKey(getConfig().senderId)) {
if (unstackableIds.get(getConfig().senderId).contains(getMsgId())) {
log.error("Stackables already contained {}, {}", getConfig().senderId, getMsgId());
fail = true;
}
}
}
}
private static void writeLog(final String type, Message msg) {
log.info("########### " + type + " ### ClientID: {} ### MsgID: {} ### ThreadID: {}",
new Object[]{msg.getConfig().senderId, msg.getMsgId(), Thread.currentThread().getId()});
}
private synchronized static void addReceived(Message message) {
List<Message> messages = receivedMessages.getOrDefault(message.getConfig().senderId, new ArrayList<>());
messages.add(message);
receivedMessages.put(message.getConfig().senderId, messages);
}
private static void printMsg(Message msg) {
// log.info("+++++++++++++ msg-id: " + msg.getMsgId());
// try {
// throw new Exception();
// } catch (Exception e) {
// e.printStackTrace();
// }
}
@Test
public void testStacking() {
reset();
start(8,
() -> {
Config config = newClientConfig().setStackKeepAliveMessages(true);
config.debug.enabled = true;
config.debug.lostPacketsPercentage = 5;
config.setIdProviderClass(ReliableModeIdProvider.class);
return config;
});
logBig("Send broadcast messages to clients");
int messageCount = 40;
for (int i = 0; i < messageCount; i++) {
server.send(new StackableMsg1());
}
server.send(new StackableMsg2());
int timeoutInSeconds = 15;
waitForCondition("Not all messages received.", timeoutInSeconds,
() -> closeMsgReceived.get() == clients.size(),
() -> "Received close messages: " + closeMsgReceived);
assertThat(receivedCounter.get(), is(messageCount * clients.size()));
assertThat(fail, is(false));
}
@Test
public void testStackingWithUnstackables() {
reset();
start(4,
() -> {
Config config = newClientConfig();
config.debug.enabled = true;
config.debug.lostPacketsPercentage = 5;
config.setIdProviderClass(ReliableModeIdProvider.class);
return config;
});
logBig("Send broadcast messages to clients");
int messageCount = 100;
for (int i = 0; i < messageCount; i++) {
server.send(new StackableMsg1());
server.send(new UnStackableMsg1());
}
server.send(new StackableMsg2());
int timeoutInSeconds = 15;
waitForCondition("Not all messages received.", timeoutInSeconds,
() -> closeMsgReceived.get() == clients.size(),
() -> "Received close messages: " + closeMsgReceived);
assertThat(receivedCounter.get(), is(messageCount * 2 * clients.size()));
assertThat(fail, is(false));
log.info("Check order of received messages");
for (int i = 1; i <= clients.size(); i++) {
List<Message> messages = receivedMessages.get(i);
assertThat(messages, is(notNullValue()));
assertThat(messages.size(), greaterThan(0));
long lastId = messages.get(0).getMsgId();
for (Message message : messages) {
assertThat(message.getMsgId(), is(lastId));
lastId++;
}
}
log.info("Check ids in message log");
MessageLog messageLog = server.getState().getProcessorOf(MessageLogProcessor.class).getMessageLog();
for (long i = 1; i <= messageCount * 2; i++) {
Message message = messageLog.getSent(MessageKey.newKey(Message.ReliableMode.SEQUENCE_NUMBER, 0, i));
assertThat("Message was null, id=" + i, message, is(notNullValue()));
assertThat(message.getMsgId(), is(i));
}
}
@Test
public void testLostPacketCorrectReceiveOrder() {
reset();
start(1,
() -> {
Config config = newClientConfig();
config.debug.enabled = true;
config.debug.lostPacketsPercentage = 0;
config.setIdProviderClass(ReliableModeIdProvider.class);
return config;
});
logBig("Send broadcast messages to clients");
discardNextPacket();
server.send(new StackableMsg1());
server.send(new UnStackableMsg1());
server.send(new StackableMsg1());
discardNextPacket();
server.send(new UnStackableMsg1());
discardNextPacket();
server.send(new StackableMsg1());
discardNextPacket();
server.send(new UnStackableMsg1());
server.send(new StackableMsg1());
server.send(new UnStackableMsg1());
// Send close message
server.send(new StackableMsg2());
int timeoutInSeconds = 5;
waitForCondition("Not all messages received.", timeoutInSeconds,
() -> closeMsgReceived.get() == clients.size(),
() -> "Received close messages: " + closeMsgReceived);
assertThat(fail, is(false));
}
private void discardNextPacket() {
clients.forEach(client -> client.getConfig().debug.discardNextPacket = true);
}
private void reset() {
receivedCounter.set(0);
closeMsgReceived.set(0);
stackableIds.clear();
unstackableIds.clear();
receivedMessages.clear();
fail = false;
// stackableReceived = new ThreadLocal<AtomicInteger>() {
// @Override
// protected AtomicInteger initialValue() {
// return new AtomicInteger();
// }
// };
// unstackableReceived = new ThreadLocal<AtomicInteger>() {
// @Override
// protected AtomicInteger initialValue() {
// return new AtomicInteger();
// }
// };
expectedId = new ConcurrentHashMap<>();
for (int i = 0; i < 16; i++) {
expectedId.put(i, new AtomicInteger(1));
}
}
}