package org.jgroups.tests; import org.jgroups.*; import org.jgroups.protocols.MERGE2; import org.jgroups.protocols.pbcast.GMS; import org.jgroups.protocols.pbcast.NAKACK2; import org.jgroups.stack.ProtocolStack; import org.jgroups.util.Util; import org.testng.annotations.Test; import java.util.*; /** * Tests merging on all stacks * * @author vlada */ @Test(groups=Global.STACK_DEPENDENT,sequential=true) public class MergeTest extends ChannelTestBase { @Test public void testMerging2Members() throws Exception { mergeHelper("MergeTest.testMerging2Members", "A", "B"); } @Test public void testMerging4Members() throws Exception { mergeHelper("MergeTest.testMerging4Members", "A", "B", "C", "D"); } protected void mergeHelper(String cluster_name, String ... members) throws Exception { JChannel[] channels=null; try { channels=createChannels(cluster_name, members); print(channels); System.out.println("\ncreating partitions: "); createPartitions(channels, members); print(channels); for(String member: members) { JChannel ch=findChannel(member, channels); assert ch.getView().size() == 1 : "view of " + ch.getAddress() + ": " + ch.getView(); } Address merge_leader=determineLeader(channels, members); System.out.println("\n==== injecting merge event into merge leader : " + merge_leader + " ===="); injectMergeEvent(channels, merge_leader, members); for(int i=0; i < 40; i++) { System.out.print("."); if(allChannelsHaveViewOf(channels, members.length)) break; Util.sleep(1000); } System.out.println("\n"); print(channels); assertAllChannelsHaveViewOf(channels, members.length); } finally { if(channels != null) close(channels); } } private JChannel[] createChannels(String cluster_name, String[] members) throws Exception { JChannel[] retval=new JChannel[members.length]; JChannel ch=null; for(int i=0; i < retval.length; i++) { JChannel tmp; if(ch == null) { ch=createChannel(true, members.length); tmp=ch; } else { tmp=createChannel(ch); } tmp.setName(members[i]); ProtocolStack stack=tmp.getProtocolStack(); NAKACK2 nakack=(NAKACK2)stack.findProtocol(NAKACK2.class); if(nakack != null) nakack.setLogDiscardMessages(false); stack.removeProtocol(MERGE2.class); tmp.connect(cluster_name); retval[i]=tmp; } return retval; } private static void close(JChannel[] channels) { if(channels == null) return; for(int i=channels.length -1; i <= 0; i--) { JChannel ch=channels[i]; Util.close(ch); } } private static void createPartitions(JChannel[] channels, String ... partitions) throws Exception { checkUniqueness(partitions); List<View> views=new ArrayList<View>(partitions.length); for(String partition: partitions) { View view=createView(partition, channels); views.add(view); } applyViews(views, channels); } private static void checkUniqueness(String[] ... partitions) throws Exception { Set<String> set=new HashSet<String>(); for(String[] partition: partitions) { for(String tmp: partition) { if(!set.add(tmp)) throw new Exception("partitions are overlapping: element " + tmp + " is in multiple partitions"); } } } private static void injectMergeEvent(JChannel[] channels, String leader, String ... coordinators) { Address leader_addr=leader != null? findAddress(leader, channels) : determineLeader(channels); injectMergeEvent(channels, leader_addr, coordinators); } private static void injectMergeEvent(JChannel[] channels, Address leader_addr, String ... coordinators) { Map<Address,View> views=new HashMap<Address,View>(); for(String tmp: coordinators) { Address coord=findAddress(tmp, channels); views.put(coord, findView(tmp, channels)); } JChannel coord=findChannel(leader_addr, channels); GMS gms=(GMS)coord.getProtocolStack().findProtocol(GMS.class); gms.setLevel("trace"); gms.up(new Event(Event.MERGE, views)); } private static View createView(String partition, JChannel[] channels) throws Exception { List<Address> members=new ArrayList<Address>(); Address addr=findAddress(partition, channels); if(addr == null) throw new Exception(partition + " not associated with a channel"); members.add(addr); return new View(members.get(0), 10, members); } private static JChannel findChannel(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch; } return null; } private static JChannel findChannel(Address addr, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getAddress().equals(addr)) return ch; } return null; } private static View findView(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch.getView(); } return null; } private static boolean allChannelsHaveViewOf(JChannel[] channels, int count) { for(JChannel ch: channels) { if(ch.getView().size() != count) return false; } return true; } private static void assertAllChannelsHaveViewOf(JChannel[] channels, int count) { for(JChannel ch: channels) assert ch.getView().size() == count : ch.getName() + " has view " + ch.getView(); } private static Address determineLeader(JChannel[] channels, String ... coords) { Membership membership=new Membership(); for(String coord: coords) membership.add(findAddress(coord, channels)); membership.sort(); return membership.elementAt(0); } private static Address findAddress(String tmp, JChannel[] channels) { for(JChannel ch: channels) { if(ch.getName().equals(tmp)) return ch.getAddress(); } return null; } private static void applyViews(List<View> views, JChannel[] channels) { for(View view: views) { Collection<Address> members=view.getMembers(); for(JChannel ch: channels) { Address addr=ch.getAddress(); if(members.contains(addr)) { GMS gms=(GMS)ch.getProtocolStack().findProtocol(GMS.class); gms.installView(view); } } } } private static void print(JChannel[] channels) { for(JChannel ch: channels) { System.out.println(ch.getName() + ": " + ch.getView()); } } }