package com.lambdaworks.redis.protocol; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; import java.io.IOException; import java.util.*; import java.util.concurrent.atomic.AtomicLong; import com.lambdaworks.redis.metrics.DefaultCommandLatencyCollector; import com.lambdaworks.redis.metrics.DefaultCommandLatencyCollectorOptions; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.LoggerContext; import org.apache.logging.log4j.core.config.Configuration; import org.apache.logging.log4j.core.config.LoggerConfig; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.springframework.test.util.ReflectionTestUtils; import com.lambdaworks.redis.ClientOptions; import com.lambdaworks.redis.ConnectionEvents; import com.lambdaworks.redis.RedisChannelHandler; import com.lambdaworks.redis.RedisException; import com.lambdaworks.redis.codec.Utf8StringCodec; import com.lambdaworks.redis.output.StatusOutput; import com.lambdaworks.redis.resource.ClientResources; import edu.umd.cs.mtc.MultithreadedTestCase; import edu.umd.cs.mtc.TestFramework; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.*; import io.netty.util.concurrent.ImmediateEventExecutor; @RunWith(MockitoJUnitRunner.class) public class CommandHandlerTest { private Queue<RedisCommand<String, String, ?>> q = new ArrayDeque<>(10); private CommandHandler<String, String> sut; private final Command<String, String, String> command = new Command<>(CommandType.APPEND, new StatusOutput<String, String>(new Utf8StringCodec()), null); @Mock private ChannelHandlerContext context; @Mock private Channel channel; @Mock private ByteBufAllocator byteBufAllocator; @Mock private ChannelPipeline pipeline; @Mock private EventLoop eventLoop; @Mock private ClientResources clientResources; @Mock private RedisChannelHandler channelHandler; @BeforeClass public static void beforeClass() { LoggerContext ctx = (LoggerContext) LogManager.getContext(); Configuration config = ctx.getConfiguration(); LoggerConfig loggerConfig = config.getLoggerConfig(CommandHandler.class.getName()); loggerConfig.setLevel(Level.ALL); } @AfterClass public static void afterClass() { LoggerContext ctx = (LoggerContext) LogManager.getContext(); Configuration config = ctx.getConfiguration(); LoggerConfig loggerConfig = config.getLoggerConfig(CommandHandler.class.getName()); loggerConfig.setLevel(null); } @Before public void before() throws Exception { when(context.channel()).thenReturn(channel); when(context.alloc()).thenReturn(byteBufAllocator); when(channel.pipeline()).thenReturn(pipeline); when(channel.eventLoop()).thenReturn(eventLoop); when(eventLoop.submit(any(Runnable.class))).thenAnswer(invocation -> { Runnable r = (Runnable) invocation.getArguments()[0]; r.run(); return null; }); when(clientResources.commandLatencyCollector()).thenReturn(new DefaultCommandLatencyCollector( DefaultCommandLatencyCollectorOptions.create())); when(channel.write(any())).thenAnswer(invocation -> { if (invocation.getArguments()[0] instanceof RedisCommand) { q.add((RedisCommand) invocation.getArguments()[0]); } if (invocation.getArguments()[0] instanceof Collection) { q.addAll((Collection) invocation.getArguments()[0]); } return new DefaultChannelPromise(channel); }); when(channel.writeAndFlush(any())).thenAnswer(invocation -> { if (invocation.getArguments()[0] instanceof RedisCommand) { q.add((RedisCommand) invocation.getArguments()[0]); } if (invocation.getArguments()[0] instanceof Collection) { q.addAll((Collection) invocation.getArguments()[0]); } return new DefaultChannelPromise(channel); }); sut = new CommandHandler<String, String>(ClientOptions.create(), clientResources, q); sut.setRedisChannelHandler(channelHandler); } @Test public void testChannelActive() throws Exception { sut.channelRegistered(context); sut.channelActive(context); verify(pipeline).fireUserEventTriggered(any(ConnectionEvents.Activated.class)); } @Test public void testChannelActiveFailureShouldCancelCommands() throws Exception { ClientOptions clientOptions = ClientOptions.builder().cancelCommandsOnReconnectFailure(true).build(); sut = new CommandHandler<String, String>(clientOptions, clientResources, q); sut.setRedisChannelHandler(channelHandler); sut.channelRegistered(context); sut.write(command); reset(context); when(context.channel()).thenThrow(new RuntimeException()); try { sut.channelActive(context); fail("Missing RuntimeException"); } catch (RuntimeException e) { } assertThat(command.isCancelled()).isTrue(); } @Test public void testChannelActiveWithBufferedAndQueuedCommands() throws Exception { Command<String, String, String> bufferedCommand = new Command<>(CommandType.GET, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> pingCommand = new Command<>(CommandType.PING, new StatusOutput<String, String>(new Utf8StringCodec()), null); q.add(bufferedCommand); AtomicLong atomicLong = (AtomicLong) ReflectionTestUtils.getField(sut, "writers"); doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { assertThat(atomicLong.get()).isEqualTo(-1); assertThat(ReflectionTestUtils.getField(sut, "exclusiveLockOwner")).isNotNull(); sut.write(pingCommand); return null; } }).when(channelHandler).activated(); when(channel.isActive()).thenReturn(true); sut.channelRegistered(context); sut.channelActive(context); assertThat(atomicLong.get()).isEqualTo(0); assertThat(ReflectionTestUtils.getField(sut, "exclusiveLockOwner")).isNull(); assertThat(q).containsSequence(pingCommand, bufferedCommand); verify(pipeline).fireUserEventTriggered(any(ConnectionEvents.Activated.class)); } @Test public void testChannelActiveWithBufferedAndQueuedCommandsRetainsOrder() throws Exception { Command<String, String, String> bufferedCommand1 = new Command<>(CommandType.SET, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> bufferedCommand2 = new Command<>(CommandType.GET, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> queuedCommand1 = new Command<>(CommandType.PING, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> queuedCommand2 = new Command<>(CommandType.AUTH, new StatusOutput<String, String>(new Utf8StringCodec()), null); q.add(queuedCommand1); q.add(queuedCommand2); Collection buffer = (Collection) ReflectionTestUtils.getField(sut, "commandBuffer"); buffer.add(bufferedCommand1); buffer.add(bufferedCommand2); reset(channel); when(channel.writeAndFlush(any())).thenAnswer(invocation -> new DefaultChannelPromise(channel)); when(channel.eventLoop()).thenReturn(eventLoop); when(channel.pipeline()).thenReturn(pipeline); sut.channelRegistered(context); sut.channelActive(context); assertThat(q).isEmpty(); assertThat(buffer).isEmpty(); ArgumentCaptor<Object> objectArgumentCaptor = ArgumentCaptor.forClass(Object.class); verify(channel).writeAndFlush(objectArgumentCaptor.capture()); assertThat((Collection) objectArgumentCaptor.getValue()).containsSequence(queuedCommand1, queuedCommand2, bufferedCommand1, bufferedCommand2); } @Test public void testChannelActiveReplayBufferedCommands() throws Exception { Command<String, String, String> bufferedCommand1 = new Command<>(CommandType.SET, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> bufferedCommand2 = new Command<>(CommandType.GET, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> queuedCommand1 = new Command<>(CommandType.PING, new StatusOutput<String, String>(new Utf8StringCodec()), null); Command<String, String, String> queuedCommand2 = new Command<>(CommandType.AUTH, new StatusOutput<String, String>(new Utf8StringCodec()), null); q.add(queuedCommand1); q.add(queuedCommand2); Collection buffer = (Collection) ReflectionTestUtils.getField(sut, "commandBuffer"); buffer.add(bufferedCommand1); buffer.add(bufferedCommand2); sut.channelRegistered(context); sut.channelActive(context); assertThat(q).containsSequence(queuedCommand1, queuedCommand2, bufferedCommand1, bufferedCommand2); assertThat(buffer).isEmpty(); } @Test public void testExceptionChannelActive() throws Exception { sut.setState(CommandHandler.LifecycleState.ACTIVE); when(channel.isActive()).thenReturn(true); sut.channelActive(context); sut.exceptionCaught(context, new Exception()); } @Test public void testIOExceptionChannelActive() throws Exception { sut.setState(CommandHandler.LifecycleState.ACTIVE); when(channel.isActive()).thenReturn(true); sut.channelActive(context); sut.exceptionCaught(context, new IOException("Connection timed out")); } @Test public void testWriteChannelDisconnected() throws Exception { when(channel.isActive()).thenReturn(true); sut.channelRegistered(context); sut.channelActive(context); sut.setState(CommandHandler.LifecycleState.DISCONNECTED); sut.write(command); Collection buffer = (Collection) ReflectionTestUtils.getField(sut, "commandBuffer"); assertThat(buffer).containsOnly(command); } @Test(expected = RedisException.class) public void testWriteChannelDisconnectedWithoutReconnect() throws Exception { sut = new CommandHandler<String, String>(ClientOptions.builder().autoReconnect(false).build(), clientResources, q); sut.setRedisChannelHandler(channelHandler); when(channel.isActive()).thenReturn(true); sut.channelRegistered(context); sut.channelActive(context); sut.setState(CommandHandler.LifecycleState.DISCONNECTED); sut.write(command); } @Test public void testExceptionChannelInactive() throws Exception { sut.setState(CommandHandler.LifecycleState.DISCONNECTED); sut.exceptionCaught(context, new Exception()); verify(context, never()).fireExceptionCaught(any(Exception.class)); } @Test public void testExceptionWithQueue() throws Exception { sut.setState(CommandHandler.LifecycleState.ACTIVE); q.clear(); sut.channelActive(context); when(channel.isActive()).thenReturn(true); q.add(command); sut.exceptionCaught(context, new Exception()); assertThat(q).isEmpty(); command.get(); assertThat(ReflectionTestUtils.getField(command, "exception")).isNotNull(); } @Test(expected = RedisException.class) public void testWriteWhenClosed() throws Exception { sut.setState(CommandHandler.LifecycleState.CLOSED); sut.write(command); } @Test public void testExceptionWhenClosed() throws Exception { sut.setState(CommandHandler.LifecycleState.CLOSED); sut.exceptionCaught(context, new Exception()); verifyZeroInteractions(context); } @Test public void isConnectedShouldReportFalseForNOT_CONNECTED() throws Exception { sut.setState(CommandHandler.LifecycleState.NOT_CONNECTED); assertThat(sut.isConnected()).isFalse(); } @Test public void isConnectedShouldReportFalseForREGISTERED() throws Exception { sut.setState(CommandHandler.LifecycleState.REGISTERED); assertThat(sut.isConnected()).isFalse(); } @Test public void isConnectedShouldReportTrueForCONNECTED() throws Exception { sut.setState(CommandHandler.LifecycleState.CONNECTED); assertThat(sut.isConnected()).isTrue(); } @Test public void isConnectedShouldReportTrueForACTIVATING() throws Exception { sut.setState(CommandHandler.LifecycleState.ACTIVATING); assertThat(sut.isConnected()).isTrue(); } @Test public void isConnectedShouldReportTrueForACTIVE() throws Exception { sut.setState(CommandHandler.LifecycleState.ACTIVE); assertThat(sut.isConnected()).isTrue(); } @Test public void isConnectedShouldReportFalseForDISCONNECTED() throws Exception { sut.setState(CommandHandler.LifecycleState.DISCONNECTED); assertThat(sut.isConnected()).isFalse(); } @Test public void isConnectedShouldReportFalseForDEACTIVATING() throws Exception { sut.setState(CommandHandler.LifecycleState.DEACTIVATING); assertThat(sut.isConnected()).isFalse(); } @Test public void isConnectedShouldReportFalseForDEACTIVATED() throws Exception { sut.setState(CommandHandler.LifecycleState.DEACTIVATED); assertThat(sut.isConnected()).isFalse(); } @Test public void isConnectedShouldReportFalseForCLOSED() throws Exception { sut.setState(CommandHandler.LifecycleState.CLOSED); assertThat(sut.isConnected()).isFalse(); } @Test public void shouldNotWriteCancelledCommands() throws Exception { command.cancel(); sut.write(context, command, null); verifyZeroInteractions(context); assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).isEmpty(); } @Test public void shouldCancelCommandOnQueueSingleFailure() throws Exception { Command<String, String, String> commandMock = mock(Command.class); RuntimeException exception = new RuntimeException(); when(commandMock.getOutput()).thenThrow(exception); ChannelPromise channelPromise = new DefaultChannelPromise(null, ImmediateEventExecutor.INSTANCE); try { sut.write(context, commandMock, channelPromise); fail("Missing RuntimeException"); } catch (RuntimeException e) { assertThat(e).isSameAs(exception); } assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).isEmpty(); verify(commandMock).completeExceptionally(exception); } @Test public void shouldCancelCommandOnQueueBatchFailure() throws Exception { Command<String, String, String> commandMock = mock(Command.class); RuntimeException exception = new RuntimeException(); when(commandMock.getOutput()).thenThrow(exception); ChannelPromise channelPromise = new DefaultChannelPromise(null, ImmediateEventExecutor.INSTANCE); try { sut.write(context, Arrays.asList(commandMock), channelPromise); fail("Missing RuntimeException"); } catch (RuntimeException e) { assertThat(e).isSameAs(exception); } assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).isEmpty(); verify(commandMock).completeExceptionally(exception); } @Test public void shouldWriteActiveCommands() throws Exception { sut.write(context, command, null); verify(context).write(command, null); assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).containsOnly(command); } @Test public void shouldNotWriteCancelledCommandBatch() throws Exception { command.cancel(); sut.write(context, Arrays.asList(command), null); verifyZeroInteractions(context); assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).isEmpty(); } @Test public void shouldWriteActiveCommandsInBatch() throws Exception { List<Command<String, String, String>> commands = Arrays.asList(command); sut.write(context, commands, null); verify(context).write(commands, null); assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).containsOnly(command); } @Test public void shouldWriteActiveCommandsInMixedBatch() throws Exception { Command<String, String, String> command2 = new Command<>(CommandType.APPEND, new StatusOutput<String, String>(new Utf8StringCodec()), null); command.cancel(); sut.write(context, Arrays.asList(command, command2), null); ArgumentCaptor<List> captor = ArgumentCaptor.forClass(List.class); verify(context).write(captor.capture(), any()); assertThat(captor.getValue()).containsOnly(command2); assertThat((Collection) ReflectionTestUtils.getField(sut, "queue")).containsOnly(command2); } @Test public void shouldIgnoreNonReadableBuffers() throws Exception { ByteBuf byteBufMock = mock(ByteBuf.class); when(byteBufMock.isReadable()).thenReturn(false); sut.channelRead(context, byteBufMock); verify(byteBufMock, never()).release(); } @Test public void shouldSetLatency() throws Exception { sut.write(context, Arrays.asList(command), null); assertThat(command.sentNs).isNotEqualTo(-1); assertThat(command.firstResponseNs).isEqualTo(-1); } @Test public void testMTCConcurrentWriteThenReset() throws Throwable { TestFramework.runOnce(new MTCConcurrentWriteThenReset(clientResources, q, command)); } @Test public void testMTCConcurrentResetThenWrite() throws Throwable { TestFramework.runOnce(new MTCConcurrentResetThenWrite(clientResources, q, command)); } @Test public void testMTCConcurrentConcurrentWrite() throws Throwable { TestFramework.runOnce(new MTCConcurrentConcurrentWrite(clientResources, q, command)); } /** * Test of concurrent access to locks. write call wins over reset call. */ static class MTCConcurrentWriteThenReset extends MultithreadedTestCase { private final Command<String, String, String> command; private TestableCommandHandler handler; private List<Thread> expectedThreadOrder = Collections.synchronizedList(new ArrayList<>()); private List<Thread> entryThreadOrder = Collections.synchronizedList(new ArrayList<>()); private List<Thread> exitThreadOrder = Collections.synchronizedList(new ArrayList<>()); public MTCConcurrentWriteThenReset(ClientResources clientResources, Queue<RedisCommand<String, String, ?>> queue, Command<String, String, String> command) { this.command = command; handler = new TestableCommandHandler(ClientOptions.create(), clientResources, queue) { @Override protected void incrementWriters() { waitForTick(2); super.incrementWriters(); waitForTick(4); } @Override protected void lockWritersExclusive() { waitForTick(4); super.lockWritersExclusive(); } @Override protected <C extends RedisCommand<String, String, T>, T> void writeToBuffer(C command) { entryThreadOrder.add(Thread.currentThread()); super.writeToBuffer(command); } @Override protected List<RedisCommand<String, String, ?>> prepareReset() { entryThreadOrder.add(Thread.currentThread()); return super.prepareReset(); } @Override protected void unlockWritersExclusive() { exitThreadOrder.add(Thread.currentThread()); super.unlockWritersExclusive(); } @Override protected void decrementWriters() { exitThreadOrder.add(Thread.currentThread()); super.decrementWriters(); } }; } public void thread1() throws InterruptedException { waitForTick(1); expectedThreadOrder.add(Thread.currentThread()); handler.write(command); } public void thread2() throws InterruptedException { waitForTick(3); expectedThreadOrder.add(Thread.currentThread()); handler.reset(); } @Override public void finish() { assertThat(entryThreadOrder).containsExactlyElementsOf(expectedThreadOrder); assertThat(exitThreadOrder).containsExactlyElementsOf(expectedThreadOrder); } } /** * Test of concurrent access to locks. write call wins over flush call. */ static class MTCConcurrentResetThenWrite extends MultithreadedTestCase { private final Command<String, String, String> command; private TestableCommandHandler handler; private List<Thread> expectedThreadOrder = Collections.synchronizedList(new ArrayList<>()); private List<Thread> entryThreadOrder = Collections.synchronizedList(new ArrayList<>()); private List<Thread> exitThreadOrder = Collections.synchronizedList(new ArrayList<>()); public MTCConcurrentResetThenWrite(ClientResources clientResources, Queue<RedisCommand<String, String, ?>> queue, Command<String, String, String> command) { this.command = command; handler = new TestableCommandHandler(ClientOptions.create(), clientResources, queue) { @Override protected void incrementWriters() { waitForTick(4); super.incrementWriters(); } @Override protected void lockWritersExclusive() { waitForTick(2); super.lockWritersExclusive(); waitForTick(4); } @Override protected <C extends RedisCommand<String, String, T>, T> void writeToBuffer(C command) { entryThreadOrder.add(Thread.currentThread()); super.writeToBuffer(command); } @Override protected List<RedisCommand<String, String, ?>> prepareReset() { entryThreadOrder.add(Thread.currentThread()); return super.prepareReset(); } @Override protected void unlockWritersExclusive() { exitThreadOrder.add(Thread.currentThread()); super.unlockWritersExclusive(); } @Override protected void decrementWriters() { exitThreadOrder.add(Thread.currentThread()); super.decrementWriters(); } }; } public void thread1() throws InterruptedException { waitForTick(1); expectedThreadOrder.add(Thread.currentThread()); handler.reset(); } public void thread2() throws InterruptedException { waitForTick(3); expectedThreadOrder.add(Thread.currentThread()); handler.write(command); } @Override public void finish() { assertThat(entryThreadOrder).containsExactlyElementsOf(expectedThreadOrder); assertThat(exitThreadOrder).containsExactlyElementsOf(expectedThreadOrder); } } /** * Test of concurrent access to locks. Two concurrent writes. */ static class MTCConcurrentConcurrentWrite extends MultithreadedTestCase { private final Command<String, String, String> command; private TestableCommandHandler handler; public MTCConcurrentConcurrentWrite(ClientResources clientResources, Queue<RedisCommand<String, String, ?>> queue, Command<String, String, String> command) { this.command = command; handler = new TestableCommandHandler(ClientOptions.create(), clientResources, queue) { @Override protected <C extends RedisCommand<String, String, T>, T> void writeToBuffer(C command) { waitForTick(2); assertThat(writers.get()).isEqualTo(2); waitForTick(3); super.writeToBuffer(command); } }; } public void thread1() throws InterruptedException { waitForTick(1); handler.write(command); } public void thread2() throws InterruptedException { waitForTick(1); handler.write(command); } } static class TestableCommandHandler extends CommandHandler<String, String> { public TestableCommandHandler(ClientOptions clientOptions, ClientResources clientResources, Queue<RedisCommand<String, String, ?>> queue) { super(clientOptions, clientResources, queue); } } }