package org.jgroups.protocols; import org.jgroups.*; import org.jgroups.conf.ProtocolStackConfigurator; import org.jgroups.conf.ClassConfigurator; import org.jgroups.protocols.pbcast.GMS; import org.jgroups.protocols.pbcast.NAKACK2; import org.jgroups.protocols.pbcast.STABLE; import org.jgroups.stack.ProtocolStack; import org.jgroups.tests.ChannelTestBase; import org.jgroups.util.Digest; import org.jgroups.util.MergeId; import org.jgroups.util.Util; import org.testng.annotations.Test; import org.w3c.dom.Element; import java.io.File; import java.net.URL; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; /** * Tests the GMS protocol for merging functionality * @author Bela Ban */ @Test(groups={Global.STACK_INDEPENDENT}, sequential=true) public class GMS_MergeTest extends ChannelTestBase { static final String simple_props="SHARED_LOOPBACK:PING(timeout=1000):" + "pbcast.NAKACK2(use_mcast_xmit=false;log_discard_msgs=false;log_not_found_msgs=false)" + ":UNICAST:pbcast.STABLE(stability_delay=200):pbcast.GMS:FC:FRAG2"; static final String flush_props=simple_props + ":pbcast.FLUSH"; static final short GMS_ID=ClassConfigurator.getProtocolId(GMS.class); public static void testMergeRequestTimeout() throws Exception { _testMergeRequestTimeout(simple_props, "testMergeRequestTimeout"); } public static void testMergeRequestTimeoutWithFlush() throws Exception { _testMergeRequestTimeout(flush_props, "testMergeRequestTimeoutWithFlush"); } public static void testSimpleMerge() throws Exception { _testSimpleMerge(simple_props, "testSimpleMerge"); } public static void testSimpleMergeWithFlush() throws Exception { _testSimpleMerge(flush_props, "testSimpleMergeWithFlush"); } public static void testConcurrentMergeTwoPartitions() throws Exception { _testConcurrentMergeTwoPartitions(simple_props, "testConcurrentMergeTwoPartitions"); } public static void testConcurrentMergeTwoPartitionsWithFlush() throws Exception { _testConcurrentMergeTwoPartitions(flush_props, "testConcurrentMergeTwoPartitionsWithFlush"); } public static void testConcurrentMergeMultiplePartitions() throws Exception { _testConcurrentMergeMultiplePartitions(simple_props, "testConcurrentMergeMultiplePartitions"); } public static void testConcurrentMergeMultiplePartitionsWithFlush() throws Exception { _testConcurrentMergeMultiplePartitions(flush_props, "testConcurrentMergeMultiplePartitionsWithFlush"); } public static void testMergeAsymmetricPartitions() throws Exception { _testMergeAsymmetricPartitions(simple_props, "testMergeAsymmetricPartitions"); } public static void testMergeAsymmetricPartitionsWithFlush() throws Exception { _testMergeAsymmetricPartitions(flush_props, "testMergeAsymmetricPartitionsWithFlush"); } public static void testMergeAsymmetricPartitions2() throws Exception { _testMergeAsymmetricPartitions2(simple_props, "testMergeAsymmetricPartitions2"); } public static void testMergeAsymmetricPartitionsWithFlush2() throws Exception { _testMergeAsymmetricPartitions2(flush_props, "testMergeAsymmetricPartitionsWithFlush2"); } /** * Simulates the death of a merge leader after having sent a MERGE_REQ. Because there is no MergeView or CANCEL_MERGE * message, the MergeCanceller has to null merge_id after a timeout */ static void _testMergeRequestTimeout(String props, String cluster_name) throws Exception { JChannel c1=new JChannel(props); try { c1.connect(cluster_name); Message merge_request=new Message(); GMS.GmsHeader hdr=new GMS.GmsHeader(GMS.GmsHeader.MERGE_REQ); MergeId new_merge_id=MergeId.create(c1.getAddress()); hdr.setMergeId(new_merge_id); merge_request.putHeader(GMS_ID, hdr); GMS gms=(GMS)c1.getProtocolStack().findProtocol(GMS.class); gms.setMergeTimeout(2000); MergeId merge_id=gms._getMergeId(); assert merge_id == null; System.out.println("starting merge"); gms.up(new Event(Event.MSG, merge_request)); long timeout=gms.getMergeTimeout() * 10; System.out.println("sleeping for " + timeout + " ms, then fetching merge_id: should be null (cancelled by the MergeCanceller)"); long target_time=System.currentTimeMillis() + timeout; while(System.currentTimeMillis() < target_time) { merge_id=gms._getMergeId(); if(merge_id == null) break; Util.sleep(500); } merge_id=gms._getMergeId(); System.out.println("merge_id = " + merge_id); assert merge_id == null : "MergeCanceller didn't kick in"; } finally { Util.close(c1); } } static void _testSimpleMerge(String props, String cluster_name) throws Exception { JChannel[] channels=null; try { channels=create(props, true, cluster_name, "A", "B", "C", "D"); print(channels); View view=channels[channels.length -1].getView(); assert view.size() == channels.length : "view is " + view; System.out.println("\ncreating partitions: "); String[][] partitions=generate(new String[]{"A", "B"}, new String[]{"C", "D"}); createPartitions(channels, partitions); print(channels); checkViews(channels, "A", "A", "B"); checkViews(channels, "B", "A", "B"); checkViews(channels, "C", "C", "D"); checkViews(channels, "C", "C", "D"); System.out.println("\ndigests:"); printDigests(channels); Address leader=determineLeader(channels, "A", "C"); long end_time=System.currentTimeMillis() + 30000; do { System.out.println("\n==== injecting merge events into " + leader + " ===="); injectMergeEvent(channels, leader, "A", "C"); Util.sleep(1000); if(allChannelsHaveViewOf(channels, channels.length)) break; } while(end_time > System.currentTimeMillis()); System.out.println("\n"); print(channels); assertAllChannelsHaveViewOf(channels, channels.length); System.out.println("\ndigests:"); printDigests(channels); } finally { System.out.println("closing channels"); close(channels); System.out.println("done"); } } static void _testConcurrentMergeTwoPartitions(String props, String cluster_name) throws Exception { JChannel[] channels=null; try { channels=create(props, true, cluster_name, "A", "B", "C", "D"); print(channels); View view=channels[channels.length -1].getView(); assert view.size() == channels.length : "view is " + view; System.out.println("\ncreating partitions: "); String[][] partitions=generate(new String[]{"A", "B"}, new String[]{"C", "D"}); createPartitions(channels, partitions); print(channels); checkViews(channels, "A", "A", "B"); checkViews(channels, "B", "A", "B"); checkViews(channels, "C", "C", "D"); checkViews(channels, "D", "C", "D"); Address merge_leader=determineLeader(channels, "A", "C"); long end_time=System.currentTimeMillis() + 30000; do { System.out.println("\n==== injecting merge events into A and C concurrently ===="); injectMergeEvent(channels, merge_leader, "A", "C"); Util.sleep(2000); if(allChannelsHaveViewOf(channels, 4)) break; } while(end_time > System.currentTimeMillis()); System.out.println("\n"); print(channels); assertAllChannelsHaveViewOf(channels, 4); } finally { close(channels); } } static void _testConcurrentMergeMultiplePartitions(String props, String cluster_name) throws Exception { JChannel[] channels=null; try { channels=create(props, true, cluster_name, "A", "B", "C", "D", "E", "F", "G", "H"); print(channels); View view=channels[channels.length -1].getView(); assert view.size() == channels.length : "view is " + view; assertAllChannelsHaveViewOf(channels, 8); System.out.println("\ncreating partitions: "); String[][] partitions=generate(new String[]{"A", "B"}, new String[]{"C", "D"}, new String[]{"E", "F"}, new String[]{"G", "H"}); createPartitions(channels, partitions); print(channels); checkViews(channels, "A", "A", "B"); checkViews(channels, "B", "A", "B"); checkViews(channels, "C", "C", "D"); checkViews(channels, "D", "C", "D"); checkViews(channels, "E", "E", "F"); checkViews(channels, "F", "E", "F"); checkViews(channels, "G", "G", "H"); checkViews(channels, "H", "G", "H"); Address merge_leader=determineLeader(channels, "A", "C", "E", "G"); long end_time=System.currentTimeMillis() + 30000; do { System.out.println("\n==== injecting merge event into " + merge_leader + " ===="); injectMergeEvent(channels, merge_leader, "A", "C", "E", "G"); Util.sleep(2000); if(allChannelsHaveViewOf(channels, 8)) break; } while(end_time > System.currentTimeMillis()); print(channels); assertAllChannelsHaveViewOf(channels, 8); disableTracing(channels); } finally { close(channels); } } /** * Tests the merge of the following partitions: * <ul> * <li>A: {B, A, C} * <li>B: {B, C} * <li>C: {B, C} * </ol> * JIRA: https://jira.jboss.org/jira/browse/JGRP-1031 * @throws Exception */ static void _testMergeAsymmetricPartitions(String props, String cluster_name) throws Exception { JChannel[] channels=null; MyReceiver[] receivers; final int NUM=10; try { // use simple IDs for UUIDs, so sorting on merge will NOT change the view order channels=create(props, true, cluster_name, "B", "A", "C"); receivers=new MyReceiver[channels.length]; for(int i=0; i < channels.length; i++) { receivers[i]=new MyReceiver(channels[i].getName()); channels[i].setReceiver(receivers[i]); } JChannel a=findChannel("A", channels), b=findChannel("B", channels), c=findChannel("C", channels); print(channels); View view=channels[channels.length -1].getView(); assert view.size() == channels.length : "view is " + view; System.out.println("sending " + NUM + " msgs:"); for(int i=0; i < NUM; i++) for(JChannel ch: channels) ch.send(null, "Number #" + i + " from " + ch.getAddress()); waitForNumMessages(NUM * channels.length, 10000, 1000, receivers); checkMessages(NUM * channels.length, receivers); System.out.println("\ncreating partitions: "); applyView(channels, "A", "B", "A", "C"); applyView(channels, "B", "B", "C"); applyView(channels, "C", "B", "C"); print(channels); checkViews(channels, "A", "B", "A", "C"); checkViews(channels, "B", "B", "C"); checkViews(channels, "C", "B", "C"); for(MyReceiver receiver: receivers) receiver.clear(); DISCARD discard=new DISCARD(); discard.addIgnoreMember(b.getAddress()); discard.addIgnoreMember(c.getAddress()); // A should drop all traffic from B or C a.getProtocolStack().insertProtocol(discard, ProtocolStack.ABOVE, SHARED_LOOPBACK.class); System.out.println("B and C exchange " + NUM + " messages, A discards them"); for(int i=0; i < NUM; i++) b.send(null, "message #" + i +" from B"); for(int i=0; i < NUM; i++) c.send(null, "message #" + i +" from C"); waitForNumMessages(NUM * 2, 10000, 500, receivers[0], receivers[2]); // A *does* receiver B's and C's messages ! checkMessages(NUM * 2, receivers[0], receivers[2]); checkMessages(0, receivers[1]); Digest da=((NAKACK2)a.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(), db=((NAKACK2)b.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(), dc=((NAKACK2)c.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); System.out.println("Digest A: " + da + "\nDigest B: " + db + "\nDigest C: " + dc); System.out.println("Running stability protocol on B and C now"); // a.getProtocolStack().findProtocol(STABLE.class).setLevel("trace"); // b.getProtocolStack().findProtocol(STABLE.class).setLevel("trace"); // c.getProtocolStack().findProtocol(STABLE.class).setLevel("trace"); for(int i=0; i < 3; i++) { ((STABLE)b.getProtocolStack().findProtocol(STABLE.class)).runMessageGarbageCollection(); ((STABLE)c.getProtocolStack().findProtocol(STABLE.class)).runMessageGarbageCollection(); Util.sleep(300); } db=((NAKACK2)a.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); db=((NAKACK2)b.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); dc=((NAKACK2)c.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); System.out.println("(after purging)\nDigest A: " + da + "\nDigest B: " + db + "\nDigest C: " + dc); // now enable traffic reception of B and C on A: discard.removeIgnoredMember(b.getAddress()); discard.removeIgnoredMember(c.getAddress()); Address leader=b.getAddress(); long end_time=System.currentTimeMillis() + 12000; do { System.out.println("\n==== injecting merge event into " + leader + " ===="); injectMergeEvent(channels, leader, "B", "A", "C"); Util.sleep(3000); if(allChannelsHaveView(channels, b.getView())) break; } while(end_time > System.currentTimeMillis()); System.out.println("\n"); print(channels); assertAllChannelsHaveView(channels, b.getView()); } finally { close(channels); } } /** * Tests the merge of the following partitions: * <ul> * <li>A: {A,B} * <li>B: {B} * </ol> * JIRA: https://jira.jboss.org/jira/browse/JGRP-1031 * @throws Exception */ static void _testMergeAsymmetricPartitions2(String props, String cluster_name) throws Exception { JChannel[] channels=null; MyReceiver[] receivers; final int NUM=10; try { // use simple IDs for UUIDs, so sorting on merge will NOT change the view order channels=create(props, true, cluster_name, "A", "B"); receivers=new MyReceiver[channels.length]; for(int i=0; i < channels.length; i++) { receivers[i]=new MyReceiver(channels[i].getName()); channels[i].setReceiver(receivers[i]); } JChannel a=findChannel("A", channels), b=findChannel("B", channels); print(channels); View view=channels[channels.length -1].getView(); assert view.size() == channels.length : "view is " + view; System.out.println("sending " + NUM + " msgs:"); for(int i=0; i < NUM; i++) for(JChannel ch: channels) ch.send(null, "Number #" + i + " from " + ch.getAddress()); waitForNumMessages(NUM * channels.length, 10000, 1000, receivers); checkMessages(NUM * channels.length, receivers); System.out.println("\ncreating partitions: "); applyView(channels, "B", "B"); // B has view {B} print(channels); checkViews(channels, "A", "A", "B"); // A: {A,B} checkViews(channels, "B", "B"); // B: {B} for(MyReceiver receiver: receivers) receiver.clear(); Digest da=((NAKACK2)a.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); Digest db=((NAKACK2)b.getProtocolStack().findProtocol(NAKACK2.class)).getDigest(); System.out.println("(after purging)\nDigest A: " + da + "\nDigest B: " + db); Address merge_leader=determineLeader(channels, "A", "B"); long end_time=System.currentTimeMillis() + 12000; do { System.out.println("\n==== injecting merge event into " + merge_leader + " ===="); injectMergeEvent(channels, merge_leader, "A", "B"); Util.sleep(3000); if(allChannelsHaveView(channels, a.getView())) break; } while(end_time > System.currentTimeMillis()); System.out.println("\n"); print(channels); assertAllChannelsHaveView(channels, a.getView()); } finally { close(channels); } } /** * First name is the channel name, the rest is the view to be applied * @param members */ private static void applyView(JChannel[] channels, String member, String ... members) throws Exception { JChannel ch=findChannel(member, channels); View view=createView(members, channels); GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class); gms.installView(view); } private static boolean allChannelsHaveViewOf(JChannel[] channels, int count) { for(JChannel ch: channels) { if(ch.getView().size() != count) return false; } return true; } private static boolean allChannelsHaveView(JChannel[] channels, View view) { for(JChannel ch: channels) { if(!ch.getView().equals(view)) return false; } return true; } private static void assertAllChannelsHaveView(JChannel[] channels, View view) { for(JChannel ch: channels) { View v=ch.getView(); assert v.equals(view) : "expected view " + view + " but got " + v + " for channel " + ch.getName(); } } private static void assertAllChannelsHaveViewOf(JChannel[] channels, int count) { for(JChannel ch: channels) assert ch.getView().size() == count : ch.getName() + " has view " + ch.getView() + " (should have " + count + " mbrs)"; } private static void close(JChannel[] channels) { if(channels == null) return; for(int i=channels.length -1; i >= 0; i--) { JChannel ch=channels[i]; Util.close(ch); } } /*private static JChannel[] create(String cluster_name, String ... names) throws Exception { return create(false, cluster_name, names); }*/ private static JChannel[] create(String props, boolean simple_ids, String cluster_name, String ... names) throws Exception { JChannel[] retval=new JChannel[names.length]; for(int i=0; i < retval.length; i++) { JChannel ch; if(simple_ids) { ch=new MyChannel(props); ((MyChannel)ch).setId(i+1); } else ch=new JChannel(props); ch.setName(names[i]); retval[i]=ch; ch.connect(cluster_name); if(i == 0) Util.sleep(3000); } return retval; } private static void createPartitions(JChannel[] channels, String[]... partitions) throws Exception { checkUniqueness(partitions); List<View> views=new ArrayList<View>(partitions.length); for(String[] partition: partitions) { View view=createView(partition, channels); views.add(view); } applyViews(views, channels); } private static void injectMergeEvent(JChannel[] channels, String leader, String ... coordinators) { Address leader_addr=leader != null? findAddress(leader, channels) : determineLeader(channels); injectMergeEvent(channels, leader_addr, coordinators); } private static void injectMergeEvent(JChannel[] channels, Address leader_addr, String ... coordinators) { Map<Address,View> views=new HashMap<Address,View>(); for(String tmp: coordinators) { Address coord=findAddress(tmp, channels); views.put(coord, findView(tmp, channels)); } JChannel coord=findChannel(leader_addr, channels); GMS gms=(GMS)coord.getProtocolStack().findProtocol(GMS.class); gms.setLevel("trace"); gms.up(new Event(Event.MERGE, views)); } private static Address determineLeader(JChannel[] channels, String ... coords) { Membership membership=new Membership(); for(String coord: coords) membership.add(findAddress(coord, channels)); membership.sort(); return membership.elementAt(0); } private static String[][] generate(String[] ... partitions) { String[][] retval=new String[partitions.length][]; System.arraycopy(partitions, 0, retval, 0, partitions.length); return retval; } private static void checkUniqueness(String[] ... partitions) throws Exception { Set<String> set=new HashSet<String>(); for(String[] partition: partitions) { for(String tmp: partition) { if(!set.add(tmp)) throw new Exception("partitions are overlapping: element " + tmp + " is in multiple partitions"); } } } protected static void disableTracing(JChannel[] channels) { for(JChannel ch: channels) { GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class); gms.setLevel("warn"); } } private static View createView(String[] partition, JChannel[] channels) throws Exception { List<Address> members=new ArrayList<Address>(partition.length); for(String tmp: partition) { Address addr=findAddress(tmp, channels); if(addr == null) throw new Exception(tmp + " not associated with a channel"); members.add(addr); } return new View(members.get(0), 10, members); } private static void checkViews(JChannel[] channels, String channel_name, String ... members) { JChannel ch=findChannel(channel_name, channels); View view=ch.getView(); assert view.size() == members.length : "view is " + view + ", members: " + Arrays.toString(members); for(String member: members) { Address addr=findAddress(member, channels); assert view.getMembers().contains(addr) : "view " + view + " does not contain " + addr; } } private static JChannel findChannel(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch; } return null; } private static JChannel findChannel(Address addr, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getAddress().equals(addr)) return ch; } return null; } private static Address findAddress(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch.getAddress(); } return null; } private static View findView(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch.getView(); } return null; } private static void applyViews(List<View> views, JChannel[] channels) { for(View view: views) { Collection<Address> members=view.getMembers(); for(JChannel ch: channels) { Address addr=ch.getAddress(); if(members.contains(addr)) { GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class); gms.installView(view); } } } } private static void print(JChannel[] channels) { for(JChannel ch: channels) { System.out.println(ch.getName() + ": " + ch.getView()); } } private static void printDigests(JChannel[] channels) { for(JChannel ch: channels) { NAKACK2 nak=(NAKACK2)ch.getProtocolStack().findProtocol(NAKACK2.class); Digest digest=nak.getDigest(); System.out.println(ch.getName() + ": " + digest.toStringSorted()); } } static void waitForNumMessages(int num_msgs, long timeout, long interval, MyReceiver ... receivers) { long target_time=System.currentTimeMillis() + timeout; while(System.currentTimeMillis() < target_time) { boolean all_received=true; for(MyReceiver receiver: receivers) { if(receiver.getNumMsgs() < num_msgs) { all_received=false; break; } } if(all_received) break; Util.sleep(interval); } } static void checkMessages(int expected, MyReceiver ... receivers) { for(MyReceiver receiver: receivers) System.out.println(receiver.name + ": " + receiver.getNumMsgs()); for(MyReceiver receiver: receivers) assert receiver.getNumMsgs() == expected : "[" + receiver.name + "] expected " + expected + " msgs, but received " + receiver.getNumMsgs(); } private static class MyChannel extends JChannel { protected int id=0; private MyChannel() throws Exception { super(); } private MyChannel(File properties) throws Exception { super(properties); } private MyChannel(Element properties) throws Exception { super(properties); } private MyChannel(URL properties) throws Exception { super(properties); } private MyChannel(String properties) throws Exception { super(properties); } private MyChannel(ProtocolStackConfigurator configurator) throws Exception { super(configurator); } private MyChannel(JChannel ch) throws Exception { super(ch); } public void setId(int id) { this.id=id; } protected void setAddress() { Address old_addr=local_addr; local_addr=new org.jgroups.util.UUID(id, id); if(old_addr != null) down(new Event(Event.REMOVE_ADDRESS, old_addr)); if(name == null || name.length() == 0) // generate a logical name if not set name=Util.generateLocalName(); if(name != null && name.length() > 0) org.jgroups.util.UUID.add(local_addr, name); Event evt=new Event(Event.SET_LOCAL_ADDRESS, local_addr); down(evt); if(up_handler != null) up_handler.up(evt); } } private static class MyReceiver extends ReceiverAdapter { private final String name; private AtomicInteger num_msgs=new AtomicInteger(0); private MyReceiver(String name) { this.name=name; } public int getNumMsgs() {return num_msgs.get();} public void clear() {num_msgs.set(0);} public void receive(Message msg) { num_msgs.incrementAndGet(); } public void viewAccepted(View new_view) { System.out.println("[" + name + "] view=" + new_view); } } }