package org.jgroups.protocols;
import org.jgroups.*;
import org.jgroups.stack.Protocol;
import org.jgroups.util.Util;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Tests for contention on UNICAST, measured by the number of retransmissions in UNICAST
* @author Bela Ban
*/
@Test(groups=Global.FUNCTIONAL, singleThreaded=true)
public class UNICAST_ContentionTest {
// Channels under test; assigned by each test method via create() and closed in tearDown()
JChannel a, b;
// Number of sender threads started per channel in testMessageReceptionUnderHighLoad()
static final int NUM_THREADS = 100;
// Number of messages sent by each MySender thread
static final int NUM_MSGS = 100;
static final int SIZE = 1000; // size in bytes of the payload buffer allocated by each MySender
/**
 * Runs after every test method and closes both channels, {@code b} first and then {@code a}.
 */
@AfterMethod
protected void tearDown() throws Exception {
    Util.close(b, a);
}
/**
 * Supplies the unicast protocol classes the tests are run with: one row per class,
 * currently only {@link UNICAST3}.
 */
@DataProvider
static Object[][] provider() {
    Object[][] unicastClasses={
      {UNICAST3.class}
    };
    return unicastClasses;
}
/**
 * Each channel sends 100 messages to itself and 100 to the other channel; both receivers
 * are then expected to have counted 200 messages.
 * @param unicast_class the unicast protocol class to place in both channels' stacks
 */
@Test(dataProvider="provider")
public void testSimpleMessageReception(Class<? extends Protocol> unicast_class) throws Exception {
    final String cluster="testSimpleMessageReception";
    final int rounds=100;
    final int expected=rounds * 2; // every receiver gets one message per round from each channel

    a=create(unicast_class, "A");
    b=create(unicast_class, "B");

    MyReceiver receiverA=new MyReceiver("A");
    MyReceiver receiverB=new MyReceiver("B");
    a.setReceiver(receiverA);
    b.setReceiver(receiverB);

    a.connect(cluster);
    b.connect(cluster);

    Address addrA=a.getAddress();
    Address addrB=b.getAddress();

    for(int round=0; round < rounds; round++) {
        a.send(addrA, "bla");
        a.send(addrB, "bla");
        b.send(addrB, "bla");
        b.send(addrA, "bla");
    }

    // Check up to 10 times, sleeping 500 ms after every check that finds messages still missing
    for(int attempt=0;
        attempt < 10 && (receiverA.getNum() != expected || receiverB.getNum() != expected);
        attempt++) {
        Util.sleep(500);
    }

    System.out.println("c1 received " + receiverA.getNum() + " msgs, " + getNumberOfRetransmissions(a) + " retransmissions");
    System.out.println("c2 received " + receiverB.getNum() + " msgs, " + getNumberOfRetransmissions(b) + " retransmissions");

    assert receiverA.getNum() == expected : "expected " + expected + ", but got " + receiverA.getNum();
    assert receiverB.getNum() == expected : "expected " + expected + ", but got " + receiverB.getNum();
}
/**
 * Starts NUM_THREADS sender threads on each channel, every thread sending NUM_MSGS messages to the
 * other channel, and verifies that each receiver ends up with NUM_THREADS * NUM_MSGS messages.
 * All senders block on a shared latch, which is released once every thread has been started.
 * @param unicast_class the unicast protocol class to place in both channels' stacks
 * @throws Exception if creating or connecting a channel fails, or joining the senders is interrupted
 */
@Test(dataProvider="provider")
public void testMessageReceptionUnderHighLoad(Class<? extends Protocol> unicast_class) throws Exception {
    // Use this test's own name as cluster name (it used to be copy-pasted from testSimpleMessageReception)
    final String cluster="testMessageReceptionUnderHighLoad";
    final long expected=NUM_THREADS * NUM_MSGS;
    CountDownLatch latch=new CountDownLatch(1);

    a=create(unicast_class, "A");
    b=create(unicast_class, "B");
    MyReceiver r1=new MyReceiver("A"), r2=new MyReceiver("B");
    a.setReceiver(r1);
    b.setReceiver(r2);
    a.connect(cluster);
    b.connect(cluster);

    Address c1_addr=a.getAddress(), c2_addr=b.getAddress();
    MySender[] c1_senders=startSenders(a, c2_addr, latch);
    MySender[] c2_senders=startSenders(b, c1_addr, latch);

    latch.countDown(); // releases all sender threads

    for(MySender sender: c1_senders)
        sender.join();
    for(MySender sender: c2_senders)
        sender.join();

    System.out.println("Senders are done, waiting for all messages to be received");

    // Check up to 20 times, sleeping 2 s after every check that finds messages still missing
    for(int i=0; i < 20; i++) {
        if(r1.getNum() == expected && r2.getNum() == expected)
            break;
        Util.sleep(2000);
    }

    System.out.println("c1 received " + r1.getNum() + " msgs, " + getNumberOfRetransmissions(a) + " retransmissions");
    System.out.println("c2 received " + r2.getNum() + " msgs, " + getNumberOfRetransmissions(b) + " retransmissions");

    assert r1.getNum() == expected : "expected " + expected + ", but got " + r1.getNum();
    assert r2.getNum() == expected : "expected " + expected + ", but got " + r2.getNum();
}

/** Creates and starts NUM_THREADS senders which send on {@code ch} to {@code dest} once {@code latch} is released. */
private static MySender[] startSenders(JChannel ch, Address dest, CountDownLatch latch) {
    MySender[] senders=new MySender[NUM_THREADS];
    for(int i=0; i < senders.length; i++) {
        senders[i]=new MySender(ch, dest, latch);
        senders[i].start();
    }
    return senders;
}
/**
 * Creates a channel whose stack consists of a SHARED_LOOPBACK transport and a new instance of
 * the given unicast protocol with {@code xmit_interval} set to 500.
 * @param unicast_class the unicast protocol class to instantiate; needs a no-arg constructor
 * @param name the logical name given to the channel
 * @return the new channel (not connected by this method)
 * @throws Exception if the protocol cannot be instantiated or the channel cannot be created
 */
protected JChannel create(Class<? extends Protocol> unicast_class, String name) throws Exception {
    // Class.newInstance() is deprecated (it propagates checked constructor exceptions undeclared);
    // go through the no-arg constructor instead
    Protocol unicast=unicast_class.getDeclaredConstructor().newInstance().setValue("xmit_interval", 500);
    return new JChannel(new SHARED_LOOPBACK(), unicast).name(name);
}
/**
 * Returns the retransmission count reported by the unicast protocol in the channel's stack.
 * @param ch the channel whose protocol stack is searched
 * @return {@code getNumXmits()} of the protocol found if it is a {@link UNICAST3}, otherwise 0
 */
private static long getNumberOfRetransmissions(JChannel ch) {
    Protocol unicast=ch.getProtocolStack().findProtocol(Util.getUnicastProtocols());
    return unicast instanceof UNICAST3? ((UNICAST3)unicast).getNumXmits() : 0;
}
/**
 * Thread which waits for a latch to be released and then sends NUM_MSGS messages, each carrying
 * the same SIZE-byte buffer, to a fixed destination.
 */
private static class MySender extends Thread {
    private final JChannel ch;
    private final Address dest;
    private final CountDownLatch latch;
    private final byte[] buf=new byte[SIZE];

    /**
     * @param ch the channel to send on
     * @param dest the address all messages are sent to
     * @param latch the start signal; sending begins once it has been counted down
     */
    public MySender(JChannel ch, Address dest, CountDownLatch latch) {
        this.ch=ch;
        this.dest=dest;
        this.latch=latch;
    }

    @Override
    public void run() {
        try {
            latch.await();
        }
        catch(InterruptedException e) {
            // Interrupted before the start signal: restore the interrupt status and don't send,
            // rather than swallowing the interrupt and sending ahead of the other threads
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return;
        }

        for(int i=0; i < NUM_MSGS; i++) {
            try {
                Message msg=new Message(dest, buf);
                ch.send(msg);
            }
            catch(Exception e) {
                // Best effort: a failed send is reported and shows up as a missing message in the test
                e.printStackTrace();
            }
        }
    }
}
private static class MyReceiver extends ReceiverAdapter {
final String name;
final AtomicInteger num=new AtomicInteger(0);
static final long MOD=NUM_MSGS * NUM_THREADS / 10;
public MyReceiver(String name) {
this.name=name;
}
public void receive(Message msg) {
if(num.incrementAndGet() % MOD == 0) {
System.out.println("[" + name + "] received " + getNum() + " msgs");
}
}
public int getNum() {
return num.get();
}
}
}