package org.jgroups.tests;
import org.jgroups.Global;
import org.jgroups.JChannel;
import org.jgroups.Message;
import org.jgroups.ReceiverAdapter;
import org.jgroups.protocols.TP;
import org.jgroups.util.Util;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import java.util.Collection;
import java.util.concurrent.*;
/**
* @author Bela Ban
*/
/**
 * Checks that replacing the transport's default thread pool on a running cluster
 * ({@code TP.setDefaultThreadPool}) does not break message delivery. After both members
 * swap in a fresh fixed-size pool, each member sends two messages with a {@code null}
 * destination, and both receivers must end up with all four messages.
 *
 * @author Bela Ban
 */
@Test(groups=Global.STACK_DEPENDENT,sequential=true)
public class TransportThreadPoolTest extends ChannelTestBase {
    /** Total number of messages sent by the test (two per member); every receiver expects all of them. */
    private static final int NUM_MSGS=4;

    /** Maximum time to wait for each receiver to get all {@link #NUM_MSGS} messages. */
    private static final long TIMEOUT_MS=3000;

    JChannel c1, c2;

    /** Pools installed into the transports by the test; released in {@link #tearDown()}. */
    private ExecutorService pool1, pool2;

    @BeforeMethod
    protected void setUp() throws Exception {
        c1=createChannel(true, 2, "A");
        c2=createChannel(c1, "B");
    }

    /**
     * Closes both channels, then shuts down the thread pools this test created.
     * The pools are shut down in a {@code finally} block so they are released even if closing fails.
     */
    @AfterMethod
    protected void tearDown() throws Exception {
        try {
            Util.close(c2, c1);
        }
        finally {
            // The transport may already have shut these pools down itself; ExecutorService.shutdown()
            // is a no-op on an already shut-down executor, so doing it here is safe either way.
            shutdown(pool2);
            shutdown(pool1);
            pool1=pool2=null;
        }
    }

    /** Shuts down {@code pool} if it is non-null. */
    private static void shutdown(ExecutorService pool) {
        if(pool != null)
            pool.shutdown();
    }

    @Test
    public void testThreadPoolReplacement() throws Exception {
        Receiver r1=new Receiver(), r2=new Receiver();
        c1.setReceiver(r1);
        c2.setReceiver(r2);

        c1.connect("TransportThreadPoolTest");
        c2.connect("TransportThreadPoolTest");

        Util.waitUntilAllChannelsHaveSameSize(10000, 1000, c1, c2);
        assert c2.getView().size() == 2 : "view is " + c2.getView() + ", but should have had a size of 2";

        // Swap in fresh pools while the cluster is running; keep references so tearDown() can release them
        TP transport=c1.getProtocolStack().getTransport();
        pool1=Executors.newFixedThreadPool(2);
        transport.setDefaultThreadPool(pool1);

        transport=c2.getProtocolStack().getTransport();
        pool2=Executors.newFixedThreadPool(2);
        transport.setDefaultThreadPool(pool2);

        // NUM_MSGS messages in total, two from each member
        c1.send(null, "hello world");
        c2.send(null, "bela");
        c1.send(null, "message 3");
        c2.send(null, "message 4");

        long start=System.currentTimeMillis();
        boolean r1_done=r1.getLatch().await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
        boolean r2_done=r2.getLatch().await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
        long diff=System.currentTimeMillis() - start;
        System.out.println("messages c1: " + print(r1.getMsgs()) + "\nmessages c2: " + print(r2.getMsgs())
                             + "\ntook " + diff + " ms");

        assert r1.getMsgs().size() == NUM_MSGS
          : "c1 received " + r1.getMsgs().size() + " of " + NUM_MSGS + " messages (latch released: "
          + r1_done + "): " + print(r1.getMsgs());
        assert r2.getMsgs().size() == NUM_MSGS
          : "c2 received " + r2.getMsgs().size() + " of " + NUM_MSGS + " messages (latch released: "
          + r2_done + "): " + print(r2.getMsgs());
    }

    /** Renders the payloads of {@code msgs} as a space-separated list of quoted strings. */
    private static String print(Collection<Message> msgs) {
        StringBuilder sb=new StringBuilder();
        for(Message msg: msgs)
            sb.append('"').append(msg.getObject()).append("\" ");
        return sb.toString();
    }

    /**
     * Collects every received message and counts down a latch sized to {@link #NUM_MSGS},
     * so the test can wait until all expected messages have arrived.
     */
    private static class Receiver extends ReceiverAdapter {
        // Thread-safe queue; receive() is presumably called from the transport's pool threads
        final Collection<Message> msgs=new ConcurrentLinkedQueue<Message>();
        final CountDownLatch latch=new CountDownLatch(NUM_MSGS);

        public Collection<Message> getMsgs() {
            return msgs;
        }

        public CountDownLatch getLatch() {
            return latch;
        }

        /** Records {@code msg} and counts down the latch by one. */
        public void receive(Message msg) {
            msgs.add(msg);
            latch.countDown();
        }
    }
}