package org.jgroups.tests; import org.jgroups.Address; import org.jgroups.Global; import org.jgroups.Message; import org.jgroups.conf.ClassConfigurator; import org.jgroups.protocols.*; import org.jgroups.util.MessageBatch; import org.jgroups.util.Util; import org.testng.annotations.Test; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.util.*; import java.util.concurrent.LinkedBlockingQueue; import java.util.function.BiFunction; import java.util.function.Predicate; import java.util.stream.Collectors; /** * Tests {@link org.jgroups.util.MessageBatch} * @author Bela Ban * @since 3.3 */ @Test(groups=Global.FUNCTIONAL,singleThreaded=true) public class MessageBatchTest { protected static final short UNICAST3_ID=ClassConfigurator.getProtocolId(UNICAST3.class), PING_ID=ClassConfigurator.getProtocolId(PING.class), FD_ID=ClassConfigurator.getProtocolId(FD.class), MERGE_ID=ClassConfigurator.getProtocolId(MERGE3.class), UDP_ID=ClassConfigurator.getProtocolId(UDP.class); protected final Address a=Util.createRandomAddress("A"), b=Util.createRandomAddress("B"); protected static final BiFunction<Message,MessageBatch,Integer> print_numbers=(msg, batch) -> msg != null? (Integer)msg.getObject() : null; public void testCopyConstructor() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); System.out.println("batch = " + batch); assert batch.size() == msgs.size() : "batch: " + batch; remove(batch, 3, 6, 10); System.out.println("batch = " + batch); assert batch.size() == msgs.size() -3 : "batch: " + batch; } public void testCapacityConstructor() { MessageBatch batch=new MessageBatch(3); assert batch.isEmpty(); } public void testCreationWithFilter() { List<Message> msgs=new ArrayList<>(10); for(int i=1; i <= 10; i++) msgs.add(new Message(null, i)); MessageBatch batch=new MessageBatch(null, null, null, true, msgs, msg -> msg != null && ((Integer)msg.getObject()) % 2 == 0); System.out.println(batch.map(print_numbers)); assert batch.size() == 5; for(Message msg: batch) assert ((Integer)msg.getObject()) % 2 == 0; } public void testCreationWithFilter2() { List<Message> msgs=new ArrayList<>(20); for(int i=1; i <= 20; i++) { Message msg=new Message(null, i); if(i <= 10) { msg.setFlag(Message.Flag.OOB); if(i % 2 == 0) msg.setTransientFlag(Message.TransientFlag.OOB_DELIVERED); } msgs.add(msg); } Predicate<Message> filter=msg -> msg != null && (!msg.isFlagSet(Message.Flag.OOB) || msg.setTransientFlagIfAbsent(Message.TransientFlag.OOB_DELIVERED)); MessageBatch batch=new MessageBatch(null, null, null, true, msgs, filter); System.out.println("batch = " + batch.map(print_numbers)); assert batch.size() == 15; for(Message msg: batch) { int num=msg.getObject(); if(num <= 10) assert msg.isTransientFlagSet(Message.TransientFlag.OOB_DELIVERED); } } public void testIsEmpty() { MessageBatch batch=new MessageBatch(3).add(new Message()).add(new Message()).add(new Message()); assert !batch.isEmpty(); for(Iterator<Message> it=batch.iterator(); it.hasNext();) { it.next(); it.remove(); } set(batch, 2, new Message()); assert !batch.isEmpty(); } public void testIsEmpty2() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); batch.add(new Message()); assert !batch.isEmpty(); batch.clear(); assert batch.isEmpty(); msgs.forEach(batch::add); System.out.println("batch = " + batch); for(Iterator<Message> it=batch.iterator(); it.hasNext();) { it.next(); it.remove(); } System.out.println("batch = " + batch); assert batch.isEmpty(); } public void testSet() { List<Message> msgs=createMessages(); Message msg=msgs.get(5); MessageBatch batch=new MessageBatch(msgs); assert get(batch, 5) == msg; set(batch, 4,msg); assert get(batch, 4) == msg; } public void testReplace() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); final Message MSG=new Message(); int index=0; for(Message msg: batch) { if(index % 2 == 0) batch.replace(msg,MSG); index++; } index=0; for(Message msg: batch) { if(index % 2 == 0) assert msg == MSG; // every even index has MSG index++; } } public void testReplace2() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); final Message MSG=new Message(); batch.replace(MSG, null); assert batch.size() == msgs.size(); // MSG was *not* found and therefore *not* nulled batch.replace(get(batch, 5), null); assert batch.size() == msgs.size() -1; } public void testReplace3() { MessageBatch batch=new MessageBatch(1).add(new Message(null, "Bela")).add(new Message(null, "Michi")) .add(new Message(null, "Nicole")); System.out.println("batch = " + batch); for(Message msg: batch) { if("Michi".equals(msg.getObject())) { msg.setObject("Michelle"); batch.replace(msg, msg); // tests replacing the message with itself (with changed buffer though) } } Queue<String> names=new LinkedBlockingQueue<>(Arrays.asList("Bela", "Michelle", "Nicole")); for(Message msg: batch) { String expected=names.poll(); String name=msg.getObject(); System.out.println("found=" + name + ", expected=" + expected); assert name.equals(expected) : "found=" + name + ", expected=" + expected; } } public void testReplaceIf() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); System.out.println("batch = " + batch); int size=batch.size(); int removed=batch.replaceIf(msg -> msg.getHeader(UNICAST3_ID) != null, null, true); System.out.println("batch = " + batch); assert batch.size() == size - removed; } public void testReplaceDuplicates() { final Set<Integer> dupes=new HashSet<>(5); Predicate<Message> filter=(msg) -> { Integer num=msg.getObject(); return dupes.add(num) == false; }; MessageBatch batch=new MessageBatch(10); for(int j=0; j < 2; j++) for(int i=1; i <= 5; i++) batch.add(new Message(null, i)); System.out.println(batch.map(print_numbers)); assert batch.size() == 10; batch.replace(filter, null, true); assert batch.size() == 5; System.out.println(batch.map(print_numbers)); } public void testRemove() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int prev_size=batch.size(); remove(batch, 1, 4); System.out.println("batch = " + batch); assert batch.size() == prev_size - 2; batch.clear(); System.out.println("batch = " + batch); assert batch.isEmpty(); msgs.forEach(batch::add); System.out.println("batch = " + batch); assert batch.size() == prev_size; assert batch.capacity() == prev_size; } public void testRemoveWithFilter() { Predicate<Message> filter=msg -> msg != null && msg.isTransientFlagSet(Message.TransientFlag.OOB_DELIVERED); MessageBatch batch=new MessageBatch(10); for(int i=1; i <= 10; i++) { Message msg=new Message(null, i); if(i % 2 == 0) msg.setTransientFlag(Message.TransientFlag.OOB_DELIVERED); batch.add(msg); } System.out.println("batch = " + batch); assert batch.size() == 10; batch.remove(filter); System.out.println("batch = " + batch); assert batch.size() == 5; for(int i=0; i < 5; i++) batch.add(new Message(null, i).setTransientFlag(Message.TransientFlag.OOB_DELIVERED)); System.out.println("batch = " + batch); batch.replace(filter, null, false); assert batch.size() == 9; } public void testTransfer() { MessageBatch other=new MessageBatch(3); List<Message> msgs=createMessages(); msgs.forEach(other::add); int other_size=other.size(); MessageBatch batch=new MessageBatch(5); int num=batch.transferFrom(other, true); assert num == other_size; assert batch.size() == other_size; assert other.isEmpty(); } public void testTransfer2() { MessageBatch other=new MessageBatch(3); List<Message> msgs=createMessages(); msgs.forEach(other::add); int other_size=other.size(); MessageBatch batch=new MessageBatch(5); msgs.forEach(batch::add); msgs.forEach(batch::add); System.out.println("batch = " + batch); int num=batch.transferFrom(other, true); assert num == other_size; assert batch.size() == other_size; assert other.isEmpty(); } public void testTransfer3() { MessageBatch other=new MessageBatch(30); MessageBatch batch=new MessageBatch(10); int num=batch.transferFrom(other, true); assert num == 0; assert batch.capacity() == 10; } public void testAdd() { MessageBatch batch=new MessageBatch(3); List<Message> msgs=createMessages(); msgs.forEach(batch::add); System.out.println("batch = " + batch); assert batch.size() == msgs.size() : "batch: " + batch; } public void testAddBatch() { MessageBatch batch=new MessageBatch(3), other=new MessageBatch(3); List<Message> msgs=createMessages(); msgs.forEach(other::add); assert other.size() == msgs.size(); batch.add(other); assert batch.size() == msgs.size() : "batch: " + batch; assert batch.size() == other.size(); } public void testAddNoResize() { MessageBatch batch=new MessageBatch(3); List<Message> msgs=createMessages(); for(int i=0; i < 3; i++) batch.add(msgs.get(i)); assert batch.size() == 3; assert batch.capacity() == 3; int added=batch.add(msgs.get(3), false); assert added == 0 && batch.size() == 3 && batch.capacity() == 3; } public void testAddBatchNoResizeOK() { MessageBatch batch=new MessageBatch(16); List<Message> msgs=createMessages(); MessageBatch other=new MessageBatch(3); msgs.forEach(other::add); assert other.size() == msgs.size(); assert batch.isEmpty(); int added=batch.add(other, false); assert added == other.size(); assert batch.size() == msgs.size() && batch.capacity() == 16; assert other.size() == msgs.size(); } public void testAddBatchNoResizeFail() { MessageBatch batch=new MessageBatch(3); List<Message> msgs=createMessages(); MessageBatch other=new MessageBatch(3); msgs.forEach(other::add); assert other.size() == msgs.size(); assert batch.isEmpty(); int added=batch.add(other, false); assert added == batch.size(); assert batch.size() == 3 && batch.capacity() == 3; assert other.size() == msgs.size(); } public void testAddBatch2() { MessageBatch other=new MessageBatch(3); List<Message> msgs=createMessages(); msgs.forEach(other::add); int other_size=other.size(); MessageBatch batch=new MessageBatch(5); batch.add(other); System.out.println("batch = " + batch); assert batch.size() == other_size; assert batch.capacity() >= other.capacity(); } public void testAddBatchToItself() { MessageBatch batch=new MessageBatch(16); for(Message msg: createMessages()) batch.add(msg); try { batch.add(batch); assert false: "should throw IllegalArumentException as a batch cannot be added to itself"; } catch(IllegalArgumentException ex) { System.out.printf("caught %s as expected: %s\n", ex.getClass().getSimpleName(), ex.getCause()); } } public void testGetMatchingMessages() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); Collection<Message> matching=batch.getMatchingMessages(UDP_ID, false); assert matching.size() == batch.size(); assert batch.size() == msgs.size(); matching=batch.getMatchingMessages(FD_ID, true); assert matching.size() == 1; assert batch.size() == msgs.size() -1; int size=batch.size(); matching=batch.getMatchingMessages(UDP_ID, true); assert matching.size() == size; assert batch.isEmpty(); } public void testTotalSize() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); long total_size=0; for(Message msg: msgs) total_size+=msg.size(); System.out.println("total size=" + batch.totalSize()); assert batch.totalSize() == total_size; } public void testSize() throws Exception { List<Message> msgs=createMessages(); ByteArrayOutputStream output=new ByteArrayOutputStream(); DataOutputStream out=new DataOutputStream(output); Util.writeMessageList(b, a, "cluster".getBytes(), msgs, out, false, UDP_ID); out.flush(); byte[] buf=output.toByteArray(); System.out.println("size=" + buf.length + " bytes, " + msgs.size() + " messages"); DataInputStream in=new DataInputStream(new ByteArrayInputStream(buf)); in.readShort(); // version in.readByte(); // flags List<Message> list=Util.readMessageList(in, UDP_ID); assert msgs.size() == list.size(); } public void testSize2() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); assert batch.size() == msgs.size(); remove(batch, 2, 3, 10); assert batch.size() == msgs.size() - 3; } public void testIterator() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int index=0, count=0; for(Message msg: batch) { Message tmp=msgs.get(index++); count++; assert msg == tmp; } assert count == msgs.size(); } public void testStream() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); long num_msgs=batch.stream() .filter(msg -> msg.getHeader(UNICAST3_ID) != null) .peek(msg -> System.out.printf("msg = %s, hdrs=%s\n", msg, msg.printHeaders())) .count(); System.out.println("num_msgs = " + num_msgs); assert num_msgs == 10; List<Message> list=batch.stream().collect(Collectors.toList()); assert list.size() == batch.size(); int total_size=batch.stream().map(Message::getLength).reduce(0, (l, r) -> l+r); assert total_size == 0; List<Long> msg_sizes=batch.stream().map(Message::size).collect(Collectors.toList()); System.out.println("msg_sizes = " + msg_sizes); assert msg_sizes.size() == batch.stream().count(); } public void testIterator2() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int count=0; for(Message ignored : batch) count++; assert count == msgs.size(); remove(batch, 3, 5, 10); count=0; for(Message ignored : batch) count++; assert count == msgs.size() - 3; } /** Test removal via iterator */ public void testIterator3() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); for(Message msg: batch) if(msg.getHeader(UNICAST3_ID) != null) batch.remove(msg); System.out.println("batch = " + batch); assert batch.size() == 3; } public void testIterator4() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); for(Message msg: batch) { if(msg.getHeader(UNICAST3_ID) != null) batch.remove(msg); } System.out.println("batch = " + batch); assert batch.size() == 3; } public void testIterator5() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); Iterator<Message> itr=batch.iterator(); itr.remove(); assert batch.size() == msgs.size(); // didn't remove anything for(Iterator<Message> it=batch.iterator(); it.hasNext();) { Message msg=it.next(); if(msg != null && msg.getHeader(UNICAST3_ID) != null) it.remove(); } System.out.println("batch = " + batch); assert batch.size() == 3; } public void testIterator6() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); remove(batch, 1, 2, 3, 10, msgs.size()-1); System.out.println("batch = " + batch); int count=0; for(Message ignored : batch) count++; assert count == msgs.size() - 5; count=0; batch.add(new Message()).add(new Message()); System.out.println("batch = " + batch); for(Message ignored : batch) count++; assert count == msgs.size() - 5+2; } public void testIteratorOnEmptyBatch() { MessageBatch batch=new MessageBatch(3); int count=0; for(Message ignored : batch) count++; assert count == 0; } public void testIterator7() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int index=0; final Message MSG=new Message(); for(Message msg: batch) { if(index % 2 == 0) batch.replace(msg, MSG); index++; } index=0; for(Message msg: batch) { if(index % 2 == 0) assert msg == MSG; // every even index has MSG index++; } } public void testIterator8() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int index=0; for(Iterator<Message> it=batch.iterator(); it.hasNext();) { it.next(); if(index == 1 || index == 2 || index == 3 || index == 10 || index == msgs.size()-1) it.remove(); index++; } System.out.println("batch = " + batch); int count=0; for(Message ignored : batch) count++; assert count == msgs.size() - 5; } public void testIterationWithAddition() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int count=0; for(Message ignored : batch) { count++; if(count % 2 == 0) batch.add(new Message()); } System.out.println("batch = " + batch); assert count == msgs.size() : "the added messages should *not* have been included"; } public void testIterationWithAddition2() { List<Message> msgs=createMessages(); MessageBatch batch=new MessageBatch(msgs); int count=0; for(Iterator<Message> it=batch.iterator(); it.hasNext();) { it.next(); count++; if(count % 2 == 0) batch.add(new Message()); } System.out.println("batch = " + batch); assert count == msgs.size() : "the added messages should *not* have been included"; } public void testForEach() { MessageBatch batch=new MessageBatch(10); for(int i=0; i < 10; i++) batch.add(new Message(a, i)); System.out.println("batch = " + batch); assert batch.size() == 10; batch.remove(msg -> { // removes all msgs with even-numbered payloads int num=msg.getObject(); return num % 2 == 0; }); System.out.println("batch = " + batch); assert batch.size() == 5; } protected MessageBatch remove(MessageBatch batch, int ... indices) { Message[] msgs=batch.array(); for(int index: indices) msgs[index]=null; return batch; } protected MessageBatch set(MessageBatch batch, int index, Message msg) { Message[] msgs=batch.array(); msgs[index]=msg; return batch; } protected Message get(MessageBatch batch, int index) { return batch.array()[index]; } protected List<Message> createMessages() { List<Message> retval=new ArrayList<>(10); for(long seqno=1; seqno <= 5; seqno++) retval.add(new Message(b).putHeader(UNICAST3_ID, UnicastHeader3.createDataHeader(seqno, (short)22, false))); retval.add(new Message(b).putHeader(PING_ID, new PingHeader(PingHeader.GET_MBRS_RSP).clusterName("demo-cluster"))); retval.add(new Message(b).putHeader(FD_ID, new FD.FdHeader(org.jgroups.protocols.FD.FdHeader.HEARTBEAT))); retval.add(new Message(b).putHeader(MERGE_ID, MERGE3.MergeHeader.createViewResponse())); for(long seqno=6; seqno <= 10; seqno++) retval.add(new Message(b).putHeader(UNICAST3_ID, UnicastHeader3.createDataHeader(seqno, (short)22, false))); for(Message msg: retval) msg.putHeader(UDP_ID, new TpHeader("demo-cluster")); return retval; } }