package org.infinispan.remoting; import static org.infinispan.test.TestingUtil.extractCommandsFactory; import static org.infinispan.test.TestingUtil.extractGlobalComponent; import static org.infinispan.test.fwk.TestCacheManagerFactory.createClusteredCacheManager; import static org.infinispan.test.fwk.TestCacheManagerFactory.getDefaultCacheConfiguration; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.util.Collections; import java.util.List; import java.util.concurrent.AbstractExecutorService; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import org.infinispan.Cache; import org.infinispan.commands.CommandsFactory; import org.infinispan.commands.ReplicableCommand; import org.infinispan.commands.remote.ClusteredGetCommand; import org.infinispan.commands.remote.SingleRpcCommand; import org.infinispan.commons.io.ByteBuffer; import org.infinispan.commons.marshall.StreamingMarshaller; import org.infinispan.commons.util.EnumUtil; import org.infinispan.configuration.cache.CacheMode; import org.infinispan.configuration.cache.ConfigurationBuilder; import org.infinispan.factories.ComponentRegistry; import org.infinispan.factories.GlobalComponentRegistry; import org.infinispan.factories.KnownComponentNames; import org.infinispan.manager.EmbeddedCacheManager; import org.infinispan.remoting.transport.Transport; import org.infinispan.remoting.transport.jgroups.CommandAwareRpcDispatcher; import org.infinispan.remoting.transport.jgroups.JGroupsTransport; import org.infinispan.stream.impl.StreamRequestCommand; import org.infinispan.test.AbstractInfinispanTest; import org.infinispan.test.TestingUtil; import org.infinispan.topology.CacheTopologyControlCommand; import org.infinispan.util.ByteString; import org.infinispan.util.concurrent.BlockingTaskAwareExecutorService; import org.infinispan.util.concurrent.BlockingTaskAwareExecutorServiceImpl; import org.jgroups.Address; import org.jgroups.Message; import org.jgroups.blocks.Response; import org.testng.Assert; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; /** * Tests the Asynchronous Invocation API and checks if the commands are correctly processed (or JGroups or Infinispan * thread pool) * * @author Pedro Ruivo * @since 5.3 */ @Test(groups = "functional", testName = "remoting.AsynchronousInvocationTest") public class AsynchronousInvocationTest extends AbstractInfinispanTest { private EmbeddedCacheManager cacheManager; private DummyTaskCountExecutorService executorService; private CommandAwareRpcDispatcher commandAwareRpcDispatcher; private Address address; private StreamingMarshaller marshaller; private CommandsFactory commandsFactory; private ReplicableCommand blockingCacheRpcCommand; private ReplicableCommand nonBlockingCacheRpcCommand; private ReplicableCommand blockingNonCacheRpcCommand; private ReplicableCommand nonBlockingNonCacheRpcCommand; private ReplicableCommand blockingSingleRpcCommand; private ReplicableCommand nonBlockingSingleRpcCommand; private static ReplicableCommand mockReplicableCommand(boolean blocking) throws Throwable { ReplicableCommand mock = mock(ReplicableCommand.class); when(mock.canBlock()).thenReturn(blocking); doReturn(null).when(mock).invokeAsync(); return mock; } @BeforeClass public void setUp() throws Throwable { executorService = new DummyTaskCountExecutorService(); final BlockingTaskAwareExecutorService remoteExecutorService = new BlockingTaskAwareExecutorServiceImpl("AsynchronousInvocationTest-Controller", executorService, TIME_SERVICE); ConfigurationBuilder builder = getDefaultCacheConfiguration(false); builder.clustering().cacheMode(CacheMode.DIST_SYNC); cacheManager = createClusteredCacheManager(builder); Cache<Object, Object> cache = cacheManager.getCache(); ByteString cacheName = ByteString.fromString(cache.getName()); Transport transport = extractGlobalComponent(cacheManager, Transport.class); if (transport instanceof JGroupsTransport) { commandAwareRpcDispatcher = ((JGroupsTransport) transport).getCommandAwareRpcDispatcher(); address = ((JGroupsTransport) transport).getChannel().getAddress(); marshaller = TestingUtil.extractGlobalMarshaller(cacheManager); } else { Assert.fail("Expected a JGroups Transport"); } ComponentRegistry registry = cache.getAdvancedCache().getComponentRegistry(); registry.registerComponent(remoteExecutorService, KnownComponentNames.REMOTE_COMMAND_EXECUTOR); registry.rewire(); GlobalComponentRegistry globalRegistry = cache.getCacheManager().getGlobalComponentRegistry(); globalRegistry.registerComponent(remoteExecutorService, KnownComponentNames.REMOTE_COMMAND_EXECUTOR); globalRegistry.rewire(); commandsFactory = extractCommandsFactory(cache); ReplicableCommand nonBlockingReplicableCommand = mockReplicableCommand(false); ReplicableCommand blockingReplicableCommand = mockReplicableCommand(true); //populate commands blockingCacheRpcCommand = new StreamRequestCommand<>(cacheName); nonBlockingCacheRpcCommand = new ClusteredGetCommand("key", cacheName, EnumUtil.EMPTY_BIT_SET); blockingNonCacheRpcCommand = new CacheTopologyControlCommand(null, CacheTopologyControlCommand.Type.POLICY_GET_STATUS, null, 0); //the GetKeyValueCommand is not replicated, but I only need a command that returns false in canBlock() nonBlockingNonCacheRpcCommand = new ClusteredGetCommand("key", cacheName, EnumUtil.EMPTY_BIT_SET); blockingSingleRpcCommand = new SingleRpcCommand(cacheName, blockingReplicableCommand); nonBlockingSingleRpcCommand = new SingleRpcCommand(cacheName, nonBlockingReplicableCommand); } @AfterClass public void tearDown() { if (cacheManager != null) { cacheManager.getGlobalComponentRegistry().getComponent(ExecutorService.class, KnownComponentNames.REMOTE_COMMAND_EXECUTOR).shutdownNow(); cacheManager.stop(); } } public void testCommands() { //if some of these tests fails, we need to pick another command to make the assertions true Assert.assertTrue(blockingCacheRpcCommand.canBlock()); Assert.assertTrue(blockingNonCacheRpcCommand.canBlock()); Assert.assertTrue(blockingSingleRpcCommand.canBlock()); Assert.assertFalse(nonBlockingCacheRpcCommand.canBlock()); Assert.assertFalse(nonBlockingNonCacheRpcCommand.canBlock()); Assert.assertFalse(nonBlockingSingleRpcCommand.canBlock()); } public void testCacheRpcCommands() throws Exception { assertDispatchForCommand(blockingCacheRpcCommand, true); assertDispatchForCommand(nonBlockingCacheRpcCommand, false); } public void testSingleRpcCommand() throws Exception { assertDispatchForCommand(blockingSingleRpcCommand, true); assertDispatchForCommand(nonBlockingSingleRpcCommand, false); } public void testNonCacheRpcCommands() throws Exception { assertDispatchForCommand(blockingNonCacheRpcCommand, true); assertDispatchForCommand(nonBlockingNonCacheRpcCommand, false); } private void assertDispatchForCommand(ReplicableCommand command, boolean expected) throws Exception { log.debugf("Testing " + command.getClass().getCanonicalName()); commandsFactory.initializeReplicableCommand(command, true); Message oobRequest = serialize(command, true, address); if (oobRequest == null) { log.debugf("Don't test " + command.getClass() + ". it is not Serializable"); return; } executorService.reset(); CountDownLatchResponse response = new CountDownLatchResponse(); commandAwareRpcDispatcher.handle(oobRequest, response); response.await(30, TimeUnit.SECONDS); Assert.assertEquals(executorService.hasExecutedCommand, expected, "Command " + command.getClass() + " dispatched wrongly."); Message nonOobRequest = serialize(command, false, address); if (nonOobRequest == null) { log.debugf("Don't test " + command.getClass() + ". it is not Serializable"); return; } executorService.reset(); response = new CountDownLatchResponse(); commandAwareRpcDispatcher.handle(nonOobRequest, response); response.await(30, TimeUnit.SECONDS); Assert.assertFalse(executorService.hasExecutedCommand, "Command " + command.getClass() + " dispatched wrongly."); } private Message serialize(ReplicableCommand command, boolean oob, Address from) { ByteBuffer buffer; try { buffer = marshaller.objectToBuffer(command); } catch (Exception e) { //ignore, it will not be replicated return null; } Message message = new Message(null, buffer.getBuf(), buffer.getOffset(), buffer.getLength()); message.setFlag(Message.Flag.NO_TOTAL_ORDER); if (oob) { message.setFlag(Message.Flag.OOB); } message.src(from); return message; } private class DummyTaskCountExecutorService extends AbstractExecutorService { private volatile boolean hasExecutedCommand; @Override public void execute(Runnable command) { hasExecutedCommand = true; command.run(); } public void reset() { hasExecutedCommand = false; } @Override public void shutdown() { //no-op } @Override public List<Runnable> shutdownNow() { return Collections.emptyList(); //no-op } @Override public boolean isShutdown() { return false; //no-op } @Override public boolean isTerminated() { return false; //no-op } @Override public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { return false; //no-op } } private static class CountDownLatchResponse implements Response { private final CountDownLatch countDownLatch; private CountDownLatchResponse() { countDownLatch = new CountDownLatch(1); } @Override public void send(Object reply, boolean is_exception) { countDownLatch.countDown(); } @Override public void send(Message reply, boolean is_exception) { countDownLatch.countDown(); } public boolean await(long time, TimeUnit unit) throws InterruptedException { return countDownLatch.await(time, unit); } } }