package org.jgroups.tests;
import org.jgroups.*;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.protocols.pbcast.*;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;
import org.jgroups.util.ArrayIterator;
import org.jgroups.util.Util;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.*;
import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Tests correct state transfer while other members continue sending messages to the group
* @author Bela Ban
*/
@Test(groups={Global.STACK_DEPENDENT,Global.EAP_EXCLUDED},singleThreaded=true)
public class StateTransferTest extends ChannelTestBase {
static final int MSG_SEND_COUNT=1000;
static final String[] names= {"A", "B", "C", "D"};
static final int APP_COUNT=names.length;
static final Class<?>[] NAK_PROTS={NAKACK2.class};
static final short[] ids=new short[NAK_PROTS.length];
protected StateTransferApplication[] apps=new StateTransferApplication[APP_COUNT];
static {
for(int i=0; i < NAK_PROTS.length; i++)
ids[i]=ClassConfigurator.getProtocolId(NAK_PROTS[i]);
}
@DataProvider(name="createChannels")
protected Iterator<Object[]> createChannels() {
return new ArrayIterator<>(new Class[][]{{STATE_TRANSFER.class}, {STATE.class}, {STATE_SOCK.class}});
}
@AfterMethod protected void destroy() {
for(StateTransferApplication app: apps)
if(app != null) {
app.getChannel().setReceiver(null);
app.cleanup();
}
}
@Test(dataProvider="createChannels")
public void testStateTransferFromSelfWithRegularChannel(final Class<? extends Protocol> state_transfer_class) throws Exception {
JChannel ch=createChannel(true);
replaceStateTransferProtocolWith(ch, state_transfer_class);
ch.connect("StateTransferTest");
try {
Address self=ch.getAddress();
assert self != null;
ch.getState(self, 20000);
assert true : "getState() on self should return";
}
finally {
Util.close(ch);
}
}
// @Test(dataProvider="createChannels",invocationCount=10)
@Test(dataProvider="createChannels")
public void testStateTransferWhileSending(final Class<? extends Protocol> state_transfer_class) throws Exception {
Semaphore semaphore=new Semaphore(APP_COUNT, true); // fifo order
semaphore.acquire(APP_COUNT);
Thread[] threads=new Thread[APP_COUNT];
int from=0, to=MSG_SEND_COUNT;
for(int i=0;i < apps.length;i++) {
if(i == 0)
apps[i]=new StateTransferApplication(semaphore, names[i], from, to);
else
apps[i]=new StateTransferApplication(apps[0].getChannel(), semaphore, names[i], from, to);
replaceStateTransferProtocolWith(apps[i].getChannel(), state_transfer_class);
threads[i]=new Thread(apps[i], "thread-" + names[i]);
threads[i].start();
from+=MSG_SEND_COUNT;
to+=MSG_SEND_COUNT;
}
for(int i=0;i < threads.length; i++) {
semaphore.release();
Util.sleep(i == 0? 4000 : 100); // to reduce changes of a merge
}
// Make sure everyone is in sync
JChannel[] tmp=new JChannel[apps.length];
for(int i=0; i < apps.length; i++)
tmp[i]=apps[i].getChannel();
Util.waitUntilAllChannelsHaveSameView(20000, 1000, tmp);
for(Thread thread: threads)
thread.join(20000);
for(Thread thread: threads)
if(thread.isAlive())
throw new Exception("Thread " + thread.getName() + " is still alive");
// Sleep to ensure async messages arrive
System.out.println("Waiting for all channels to have " + MSG_SEND_COUNT * APP_COUNT + " elements:");
long end_time=System.currentTimeMillis() + 20000L;
while(System.currentTimeMillis() < end_time) {
boolean terminate=true;
for(StateTransferApplication app: apps) {
Map<String,List<Long>> map=app.getMap();
if(getSize(map) != MSG_SEND_COUNT * APP_COUNT) {
terminate=false;
break;
}
}
if(terminate)
break;
else {
resumeStableAndGC();
Util.sleep(500);
}
}
for(int i=0;i < apps.length;i++) {
StateTransferApplication w=apps[i];
ConcurrentMap<Address,AtomicInteger> map=w.getCount();
System.out.println("msgs for " + names[i] + ":");
for(Map.Entry<Address,AtomicInteger> entry: map.entrySet())
System.out.println("from " + entry.getKey() + " --> " + entry.getValue() + " msgs");
}
// have we received all and the correct messages?
System.out.println("++++++++++++++++++++++++++++++++++++++");
for(int i=0;i < apps.length;i++) {
StateTransferApplication w=apps[i];
Map<String,List<Long>> m=w.getMap();
System.out.println("\n" + names[i] + " (" + getSize(m) + "): digest=" + w.getChannel().down(Event.GET_DIGEST_EVT));
for(String name: names)
System.out.println("map " + name + ": " + print(m.get(name)));
}
System.out.println("++++++++++++++++++++++++++++++++++++++");
for(int i=0;i < apps.length;i++) {
StateTransferApplication w=apps[i];
Map<String,List<Long>> m=w.getMap();
assert getSize(m) == MSG_SEND_COUNT * APP_COUNT : "map " + names[i] + " has " + getSize(m) +
" elements (expected: " + MSG_SEND_COUNT * APP_COUNT + ")";
}
// compare the values
for(String name: names) {
List<Long> list=apps[0].getMap().get(name);
for(int i=1; i < apps.length; i++) {
StateTransferApplication app=apps[i];
List<Long> other_list=app.getMap().get(name);
assert list.equals(other_list);
}
}
}
protected void resumeStableAndGC() {
for(StateTransferApplication app: apps) {
STABLE stable=app.getChannel().getProtocolStack().findProtocol(STABLE.class);
stable.down(new Event(Event.RESUME_STABLE));
stable.gc();
}
}
protected String print(List<Long> list) {
if(list.isEmpty())
return "[] (0 elements)";
long first=list.get(0);
int size=list.size();
long last=list.get(size-1);
return "[" + first + " .. " + last + "] (" + size + " elements)";
}
protected int getSize(Map<String,List<Long>> map) {
int retval=0;
for(List<Long> list: map.values())
retval+=list.size();
return retval;
}
protected long getSeqno(Message msg) {
for(short id: ids) {
Header hdr=msg.getHeader(id);
if(hdr != null)
return getSeqnoFromHeader(hdr);
}
return -1;
}
protected long getSeqnoFromHeader(Header hdr) {
Field field=Util.getField(hdr.getClass(), "seqno");
return (Long)Util.getField(field, hdr);
}
protected void replaceStateTransferProtocolWith(JChannel ch, Class<? extends Protocol> state_transfer_class) throws Exception {
ProtocolStack stack=ch.getProtocolStack();
if(stack.findProtocol(state_transfer_class) != null)
return; // protocol of the right class is already in stack
Protocol prot=stack.findProtocol(STATE_TRANSFER.class, StreamingStateTransfer.class);
Protocol new_state_transfer_protcol=state_transfer_class.newInstance();
if(prot != null) {
stack.replaceProtocol(prot, new_state_transfer_protcol);
}
else { // no state transfer protocol found in stack
Protocol flush=stack.findProtocol(FLUSH.class);
if(flush != null)
stack.insertProtocol(new_state_transfer_protcol, ProtocolStack.Position.BELOW, FLUSH.class);
else
stack.insertProtocolAtTop(new_state_transfer_protcol);
}
}
protected class StateTransferApplication extends ReceiverAdapter implements Runnable {
protected final Map<String,List<Long>> map=new HashMap<>(MSG_SEND_COUNT * APP_COUNT);
protected final int from, to;
protected ConcurrentMap<Address,AtomicInteger> count=new ConcurrentHashMap<>();
protected final Semaphore semaphore;
protected final JChannel channel;
protected long start_time;
public StateTransferApplication(Semaphore semaphore, String name, int from, int to) throws Exception {
this.from=from;
this.to=to;
this.semaphore=semaphore;
init();
channel=createChannel(true, APP_COUNT, name);
channel.setReceiver(this);
}
public StateTransferApplication(JChannel copySource, Semaphore semaphore, String name, int from, int to) throws Exception {
this.from=from;
this.to=to;
this.semaphore=semaphore;
init();
this.channel=createChannel(copySource, name);
channel.setReceiver(this);
}
protected void init() {
for(String s: names)
map.put(s, new ArrayList<>(MSG_SEND_COUNT * APP_COUNT));
}
public JChannel getChannel() {
return channel;
}
public void cleanup() {Util.close(channel);}
public Map<String,List<Long>> getMap() {
synchronized(map) {
return map;
}
}
public ConcurrentMap<Address,AtomicInteger> getCount() {
return count;
}
public void receive(Message msg) {
String key=msg.getObject();
Address sender=msg.getSrc();
AtomicInteger cnt=count.get(sender);
if(cnt == null) {
cnt=new AtomicInteger(0);
AtomicInteger tmp=count.putIfAbsent(sender,cnt);
if(tmp != null)
cnt=tmp;
}
cnt.incrementAndGet();
long seqno=getSeqno(msg);
if(seqno == -1)
throw new IllegalArgumentException("NAKACK{2} seqno could not be fetched from message");
synchronized(map) {
List<Long> list=map.get(key);
// needed because we might get retransmissions of messages that are already in the state !
if(!list.contains(seqno))
list.add(seqno);
}
}
public void getState(OutputStream ostream) throws Exception {
OutputStream out=new BufferedOutputStream(ostream);
synchronized(map) {
Util.objectToStream(map, new DataOutputStream(out));
out.flush();
}
}
@SuppressWarnings("unchecked")
public void setState(InputStream istream) throws Exception {
Map<String,List<Long>> tmp=Util.objectFromStream(new DataInputStream(istream));
synchronized(map) {
map.clear();
map.putAll(tmp);
count.clear();
long time=System.currentTimeMillis() - start_time;
StringBuilder sb=new StringBuilder("\n++++++++++++++++++++++++++++++++++++++\n");
sb.append(channel.getAddress() + " <--- received state (in " + time + " ms), map has "
+ getSize(map) + " elements:\n");
for(Map.Entry<String,List<Long>> entry: map.entrySet())
sb.append(entry.getKey() + ": " + print(entry.getValue()) + "\n");
sb.append("++++++++++++++++++++++++++++++++++++++");
System.out.println(sb);
}
}
public void run() {
boolean acquired=false;
try {
acquired=semaphore.tryAcquire(60000L, TimeUnit.MILLISECONDS);
if(!acquired)
throw new Exception(channel.getAddress() + " cannot acquire semaphore");
useChannel();
}
catch(Exception e) {
log.error(channel.getAddress() + ": ", e);
}
}
protected void useChannel() throws Exception {
start_time=System.currentTimeMillis();
channel.connect("StateTransferTest", null, 20000);
int cnt=0;
for(int i=from; i < to; i++) {
try {
channel.send(null, channel.getName()); // the receiver uses name as key and the seqno of NAKACK{2} as value
cnt++;
if(cnt % 100 == 0)
Util.sleep(50);
}
catch(Exception e) {
e.printStackTrace();
break;
}
}
}
}
}