package org.jgroups.tests;
import org.jgroups.*;
import org.jgroups.protocols.DISCARD;
import org.jgroups.protocols.MERGE3;
import org.jgroups.protocols.TP;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.NAKACK2;
import org.jgroups.stack.ProtocolStack;
import org.jgroups.util.Util;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.Test;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Tests merging on all stacks
*
* @author vlada
*/
@Test(groups=Global.STACK_DEPENDENT,singleThreaded=true)
public class MergeTest extends ChannelTestBase {
protected JChannel[] channels=null;
@AfterMethod protected void destroy() {
level("warn", channels);
for(JChannel ch: channels)
try {
Util.shutdown(ch);
}
catch(Exception e) {
e.printStackTrace();
}
}
public void testMerging2Members() throws Exception {
mergeHelper("MergeTest.testMerging2Members", "A", "B");
}
public void testMerging4Members() throws Exception {
mergeHelper("MergeTest.testMerging4Members", "A", "B", "C", "D");
}
protected void mergeHelper(String cluster_name, String ... members) throws Exception {
channels=createChannels(cluster_name, members);
Util.waitUntilAllChannelsHaveSameView(10000, 500, channels);
print(channels);
System.out.println("\ncreating partitions: ");
createPartitions(channels);
print(channels);
for(JChannel ch: channels)
assert ch.getView().size() == 1 : "view is " + ch.getView();
Address merge_leader=determineLeader(channels, members);
System.out.println("\n==== injecting merge event into merge leader : " + merge_leader + " ====");
for(JChannel ch: channels)
ch.getProtocolStack().removeProtocol(DISCARD.class);
injectMergeEvent(channels, merge_leader, members);
for(int i=0; i < 40; i++) {
System.out.print(".");
if(allChannelsHaveViewOf(channels, members.length))
break;
Util.sleep(1000);
}
System.out.println("\n");
print(channels);
assertAllChannelsHaveViewOf(channels, members.length);
}
protected static void level(String level, JChannel ... channels) {
for(JChannel ch: channels) {
GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class);
gms.setLevel(level);
}
}
protected JChannel[] createChannels(String cluster_name, String[] members) throws Exception {
JChannel[] retval=new JChannel[members.length];
JChannel ch=null;
for(int i=0; i < retval.length; i++) {
JChannel tmp;
if(ch == null) {
ch=createChannel(true, members.length);
tmp=ch;
}
else {
tmp=createChannel(ch);
}
tmp.setName(members[i]);
ProtocolStack stack=tmp.getProtocolStack();
NAKACK2 nakack=(NAKACK2)stack.findProtocol(NAKACK2.class);
if(nakack != null)
nakack.setLogDiscardMessages(false);
stack.removeProtocol(MERGE3.class);
tmp.connect(cluster_name);
retval[i]=tmp;
}
return retval;
}
private static void close(JChannel[] channels) {
Util.close(channels);
}
private static void createPartitions(JChannel[] channels) throws Exception {
long view_id=1; // find the highest view-id +1
for(JChannel ch: channels)
view_id=Math.max(ch.getView().getViewId().getId(), view_id);
view_id++;
for(JChannel ch: channels) {
DISCARD discard=new DISCARD();
discard.setDiscardAll(true);
ch.getProtocolStack().insertProtocol(discard, ProtocolStack.Position.ABOVE,TP.class);
}
for(JChannel ch: channels) {
View view=View.create(ch.getAddress(), view_id, ch.getAddress());
GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class);
gms.installView(view);
}
}
private static void injectMergeEvent(JChannel[] channels, String leader, String ... coordinators) {
Address leader_addr=leader != null? findAddress(leader, channels) : determineLeader(channels);
injectMergeEvent(channels, leader_addr, coordinators);
}
private static void injectMergeEvent(JChannel[] channels, Address leader_addr, String ... coordinators) {
Map<Address,View> views=new HashMap<>();
for(String tmp: coordinators) {
Address coord=findAddress(tmp, channels);
views.put(coord, findView(tmp, channels));
}
JChannel coord=findChannel(leader_addr, channels);
GMS gms=(GMS)coord.getProtocolStack().findProtocol(GMS.class);
gms.setLevel("trace");
gms.up(new Event(Event.MERGE, views));
}
private static JChannel findChannel(Address addr, JChannel[] channels) {
for(JChannel ch: channels) {
if(ch.getAddress().equals(addr))
return ch;
}
return null;
}
private static View findView(String tmp, JChannel[] channels) {
for(JChannel ch: channels) {
if(ch.getName().equals(tmp))
return ch.getView();
}
return null;
}
private static boolean allChannelsHaveViewOf(JChannel[] channels, int count) {
for(JChannel ch: channels) {
if(ch.getView().size() != count)
return false;
}
return true;
}
private static void assertAllChannelsHaveViewOf(JChannel[] channels, int count) {
for(JChannel ch: channels)
assert ch.getView().size() == count : ch.getName() + " has view " + ch.getView();
}
private static Address determineLeader(JChannel[] channels, String ... coords) {
Membership membership=new Membership();
for(String coord: coords)
membership.add(findAddress(coord, channels));
return membership.sort().elementAt(0);
}
private static Address findAddress(String tmp, JChannel[] channels) {
for(JChannel ch: channels) {
if(ch.getName().equals(tmp))
return ch.getAddress();
}
return null;
}
private static void applyViews(List<View> views, JChannel[] channels) {
for(View view: views) {
Collection<Address> members=view.getMembers();
for(JChannel ch: channels) {
Address addr=ch.getAddress();
if(members.contains(addr)) {
GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class);
gms.installView(view);
}
}
}
}
private static void print(JChannel[] channels) {
for(JChannel ch: channels) {
System.out.println(ch.getName() + ": " + ch.getView());
}
}
}