/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.mqtt; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willReturn; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Properties; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import javax.net.SocketFactory; import org.aopalliance.intercept.MethodInterceptor; import org.apache.commons.logging.Log; import org.eclipse.paho.client.mqttv3.IMqttClient; import org.eclipse.paho.client.mqttv3.IMqttToken; import org.eclipse.paho.client.mqttv3.MqttAsyncClient; import org.eclipse.paho.client.mqttv3.MqttCallback; import org.eclipse.paho.client.mqttv3.MqttClient; import org.eclipse.paho.client.mqttv3.MqttConnectOptions; import org.eclipse.paho.client.mqttv3.MqttDeliveryToken; import org.eclipse.paho.client.mqttv3.MqttException; import org.eclipse.paho.client.mqttv3.MqttMessage; import org.eclipse.paho.client.mqttv3.MqttSecurityException; import org.eclipse.paho.client.mqttv3.MqttToken; import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence; import org.junit.Test; import org.mockito.internal.stubbing.answers.CallsRealMethods; import org.springframework.aop.framework.ProxyFactoryBean; import org.springframework.beans.DirectFieldAccessor; import org.springframework.beans.factory.BeanFactory; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.integration.channel.NullChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.handler.MessageProcessor; import org.springframework.integration.mqtt.core.ConsumerStopAction; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory.Will; import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; import org.springframework.integration.mqtt.event.MqttIntegrationEvent; import org.springframework.integration.mqtt.event.MqttSubscribedEvent; import org.springframework.integration.mqtt.inbound.MqttPahoMessageDrivenChannelAdapter; import org.springframework.integration.mqtt.outbound.MqttPahoMessageHandler; import org.springframework.integration.mqtt.support.DefaultPahoMessageConverter; import org.springframework.integration.test.util.TestUtils; import org.springframework.messaging.Message; import org.springframework.messaging.support.GenericMessage; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ReflectionUtils; /** * @author Gary Russell * @author Artem Bilan * @since 4.0 * */ public class MqttAdapterTests { private IMqttToken alwaysComplete; { ProxyFactoryBean pfb = new ProxyFactoryBean(); pfb.addAdvice((MethodInterceptor) invocation -> null); pfb.setInterfaces(IMqttToken.class); this.alwaysComplete = (IMqttToken) pfb.getObject(); } @Test public void testPahoConnectOptions() { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); factory.setCleanSession(false); factory.setConnectionTimeout(23); factory.setKeepAliveInterval(45); factory.setPassword("pass"); SocketFactory socketFactory = mock(SocketFactory.class); factory.setSocketFactory(socketFactory); Properties props = new Properties(); factory.setSslProperties(props); factory.setUserName("user"); Will will = new Will("foo", "bar".getBytes(), 2, true); factory.setWill(will); MqttConnectOptions options = factory.getConnectionOptions(); assertEquals(23, options.getConnectionTimeout()); assertEquals(45, options.getKeepAliveInterval()); assertEquals("pass", new String(options.getPassword())); assertSame(socketFactory, options.getSocketFactory()); assertSame(props, options.getSSLProperties()); assertEquals("user", options.getUserName()); assertEquals("foo", options.getWillDestination()); assertEquals("bar", new String(options.getWillMessage().getPayload())); assertEquals(2, options.getWillMessage().getQos()); } @Test public void testOutboundOptionsApplied() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); factory.setCleanSession(false); factory.setConnectionTimeout(23); factory.setKeepAliveInterval(45); factory.setPassword("pass"); MemoryPersistence persistence = new MemoryPersistence(); factory.setPersistence(persistence); final SocketFactory socketFactory = mock(SocketFactory.class); factory.setSocketFactory(socketFactory); final Properties props = new Properties(); factory.setSslProperties(props); factory.setUserName("user"); Will will = new Will("foo", "bar".getBytes(), 2, true); factory.setWill(will); factory = spy(factory); final MqttAsyncClient client = mock(MqttAsyncClient.class); willAnswer(invocation -> client).given(factory).getAsyncClientInstance(anyString(), anyString()); MqttPahoMessageHandler handler = new MqttPahoMessageHandler("foo", "bar", factory); handler.setDefaultTopic("mqtt-foo"); handler.setBeanFactory(mock(BeanFactory.class)); handler.afterPropertiesSet(); handler.start(); final MqttToken token = mock(MqttToken.class); final AtomicBoolean connectCalled = new AtomicBoolean(); willAnswer(invocation -> { MqttConnectOptions options = invocation.getArgument(0); assertEquals(23, options.getConnectionTimeout()); assertEquals(45, options.getKeepAliveInterval()); assertEquals("pass", new String(options.getPassword())); assertSame(socketFactory, options.getSocketFactory()); assertSame(props, options.getSSLProperties()); assertEquals("user", options.getUserName()); assertEquals("foo", options.getWillDestination()); assertEquals("bar", new String(options.getWillMessage().getPayload())); assertEquals(2, options.getWillMessage().getQos()); connectCalled.set(true); return token; }).given(client).connect(any(MqttConnectOptions.class)); willReturn(token).given(client).subscribe(any(String[].class), any(int[].class)); final MqttDeliveryToken deliveryToken = mock(MqttDeliveryToken.class); final AtomicBoolean publishCalled = new AtomicBoolean(); willAnswer(invocation -> { assertEquals("mqtt-foo", invocation.getArguments()[0]); MqttMessage message = invocation.getArgument(1); assertEquals("Hello, world!", new String(message.getPayload())); publishCalled.set(true); return deliveryToken; }).given(client).publish(anyString(), any(MqttMessage.class)); handler.handleMessage(new GenericMessage<String>("Hello, world!")); verify(client, times(1)).connect(any(MqttConnectOptions.class)); assertTrue(connectCalled.get()); } @Test public void testInboundOptionsApplied() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); factory.setCleanSession(false); factory.setConnectionTimeout(23); factory.setKeepAliveInterval(45); factory.setPassword("pass"); MemoryPersistence persistence = new MemoryPersistence(); factory.setPersistence(persistence); final SocketFactory socketFactory = mock(SocketFactory.class); factory.setSocketFactory(socketFactory); final Properties props = new Properties(); factory.setSslProperties(props); factory.setUserName("user"); Will will = new Will("foo", "bar".getBytes(), 2, true); factory.setWill(will); factory = spy(factory); final IMqttClient client = mock(IMqttClient.class); willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); final AtomicBoolean connectCalled = new AtomicBoolean(); final AtomicBoolean failConnection = new AtomicBoolean(); final CountDownLatch waitToFail = new CountDownLatch(1); final CountDownLatch failInProcess = new CountDownLatch(1); final CountDownLatch goodConnection = new CountDownLatch(2); final MqttException reconnectException = new MqttException(MqttException.REASON_CODE_SERVER_CONNECT_ERROR); willAnswer(invocation -> { if (failConnection.get()) { failInProcess.countDown(); waitToFail.await(10, TimeUnit.SECONDS); throw reconnectException; } MqttConnectOptions options = invocation.getArgument(0); assertEquals(23, options.getConnectionTimeout()); assertEquals(45, options.getKeepAliveInterval()); assertEquals("pass", new String(options.getPassword())); assertSame(socketFactory, options.getSocketFactory()); assertSame(props, options.getSSLProperties()); assertEquals("user", options.getUserName()); assertEquals("foo", options.getWillDestination()); assertEquals("bar", new String(options.getWillMessage().getPayload())); assertEquals(2, options.getWillMessage().getQos()); connectCalled.set(true); goodConnection.countDown(); return null; }).given(client).connect(any(MqttConnectOptions.class)); final AtomicReference<MqttCallback> callback = new AtomicReference<MqttCallback>(); willAnswer(invocation -> { callback.set(invocation.getArgument(0)); return null; }).given(client).setCallback(any(MqttCallback.class)); given(client.isConnected()).willReturn(true); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("foo", "bar", factory, "baz", "fix"); QueueChannel outputChannel = new QueueChannel(); adapter.setOutputChannel(outputChannel); ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); taskScheduler.initialize(); adapter.setTaskScheduler(taskScheduler); adapter.setBeanFactory(mock(BeanFactory.class)); ApplicationEventPublisher applicationEventPublisher = mock(ApplicationEventPublisher.class); final BlockingQueue<MqttIntegrationEvent> events = new LinkedBlockingQueue<MqttIntegrationEvent>(); willAnswer(invocation -> { events.add(invocation.getArgument(0)); return null; }).given(applicationEventPublisher).publishEvent(any(MqttIntegrationEvent.class)); adapter.setApplicationEventPublisher(applicationEventPublisher); adapter.setRecoveryInterval(500); adapter.afterPropertiesSet(); adapter.start(); verify(client, times(1)).connect(any(MqttConnectOptions.class)); assertTrue(connectCalled.get()); MqttMessage message = new MqttMessage("qux".getBytes()); callback.get().messageArrived("baz", message); Message<?> outMessage = outputChannel.receive(0); assertNotNull(outMessage); assertEquals("qux", outMessage.getPayload()); MqttIntegrationEvent event = events.poll(10, TimeUnit.SECONDS); assertThat(event, instanceOf(MqttSubscribedEvent.class)); assertEquals("Connected and subscribed to [baz, fix]", ((MqttSubscribedEvent) event).getMessage()); // lose connection and make first reconnect fail failConnection.set(true); RuntimeException e = new RuntimeException("foo"); adapter.connectionLost(e); event = events.poll(10, TimeUnit.SECONDS); assertThat(event, instanceOf(MqttConnectionFailedEvent.class)); assertSame(event.getCause(), e); assertTrue(failInProcess.await(10, TimeUnit.SECONDS)); waitToFail.countDown(); failConnection.set(false); event = events.poll(10, TimeUnit.SECONDS); assertThat(event, instanceOf(MqttConnectionFailedEvent.class)); assertSame(event.getCause(), reconnectException); // reconnect can now succeed; however, we might have other failures on a slow server (500ms retry). assertTrue(goodConnection.await(10, TimeUnit.SECONDS)); int n = 0; while (!(event instanceof MqttSubscribedEvent) && n++ < 20) { event = events.poll(10, TimeUnit.SECONDS); } assertThat(event, instanceOf(MqttSubscribedEvent.class)); assertEquals("Connected and subscribed to [baz, fix]", ((MqttSubscribedEvent) event).getMessage()); taskScheduler.destroy(); } @Test public void testStopActionDefault() throws Exception { final IMqttClient client = mock(IMqttClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapter(client, null, null); adapter.start(); adapter.stop(); verifyUnsubscribe(client); } @Test public void testStopActionDefaultNotClean() throws Exception { final IMqttClient client = mock(IMqttClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapter(client, false, null); adapter.start(); adapter.stop(); verifyNotUnsubscribe(client); } @Test public void testStopActionAlways() throws Exception { final IMqttClient client = mock(IMqttClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapter(client, false, ConsumerStopAction.UNSUBSCRIBE_ALWAYS); adapter.start(); adapter.stop(); verifyUnsubscribe(client); } @Test public void testStopActionNever() throws Exception { final IMqttClient client = mock(IMqttClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapter(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); adapter.start(); adapter.stop(); verifyNotUnsubscribe(client); } @SuppressWarnings("unchecked") @Test public void testCustomExpressions() { AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext(Config.class); MqttPahoMessageHandler handler = ctx.getBean("handler", MqttPahoMessageHandler.class); GenericMessage<String> message = new GenericMessage<>("foo"); assertEquals("fooTopic", TestUtils.getPropertyValue(handler, "topicProcessor", MessageProcessor.class).processMessage(message)); assertEquals(1, TestUtils.getPropertyValue(handler, "converter.qosProcessor", MessageProcessor.class) .processMessage(message)); assertEquals(Boolean.TRUE, TestUtils.getPropertyValue(handler, "converter.retainedProcessor", MessageProcessor.class) .processMessage(message)); handler = ctx.getBean("handlerWithNullExpressions", MqttPahoMessageHandler.class); assertEquals(1, TestUtils.getPropertyValue(handler, "converter", DefaultPahoMessageConverter.class) .fromMessage(message, null).getQos()); assertEquals(Boolean.TRUE, TestUtils.getPropertyValue(handler, "converter", DefaultPahoMessageConverter.class) .fromMessage(message, null).isRetained()); ctx.close(); } @Test public void testReconnect() throws Exception { final IMqttClient client = mock(IMqttClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapter(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); adapter.setRecoveryInterval(10); Log logger = spy(TestUtils.getPropertyValue(adapter, "logger", Log.class)); new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); given(logger.isDebugEnabled()).willReturn(true); final AtomicInteger attemptingReconnectCount = new AtomicInteger(); willAnswer(i -> { if (attemptingReconnectCount.getAndIncrement() == 0) { adapter.connectionLost(new RuntimeException("while schedule running")); } i.callRealMethod(); return null; }).given(logger).debug("Attempting reconnect"); ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); taskScheduler.initialize(); adapter.setTaskScheduler(taskScheduler); adapter.start(); adapter.connectionLost(new RuntimeException("initial")); Thread.sleep(1000); // the following assertion should be equalTo, but leq to protect against a slow CI server assertThat(attemptingReconnectCount.get(), lessThanOrEqualTo(2)); adapter.stop(); taskScheduler.destroy(); } @Test public void testSubscribeFailure() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); factory.setCleanSession(false); factory.setConnectionTimeout(23); factory.setKeepAliveInterval(45); factory.setPassword("pass"); MemoryPersistence persistence = new MemoryPersistence(); factory.setPersistence(persistence); final SocketFactory socketFactory = mock(SocketFactory.class); factory.setSocketFactory(socketFactory); final Properties props = new Properties(); factory.setSslProperties(props); factory.setUserName("user"); Will will = new Will("foo", "bar".getBytes(), 2, true); factory.setWill(will); factory = spy(factory); MqttAsyncClient aClient = mock(MqttAsyncClient.class); final MqttClient client = mock(MqttClient.class); willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); given(client.isConnected()).willReturn(true); new DirectFieldAccessor(client).setPropertyValue("aClient", aClient); willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class)); willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class)); willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any()); IMqttToken token = mock(IMqttToken.class); given(token.getGrantedQos()).willReturn(new int[] { 0x80 }); willReturn(token).given(aClient).subscribe(any(String[].class), any(int[].class), any(), any()); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("foo", "bar", factory, "baz", "fix"); AtomicReference<Method> method = new AtomicReference<>(); ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { m.setAccessible(true); method.set(m); }, m -> m.getName().equals("connectAndSubscribe")); assertNotNull(method.get()); try { method.get().invoke(adapter); fail("Expected InvocationTargetException"); } catch (InvocationTargetException e) { assertThat(e.getCause(), instanceOf(MqttException.class)); assertThat(((MqttException) e.getCause()).getReasonCode(), equalTo((int) MqttException.REASON_CODE_SUBSCRIBE_FAILED)); } } @Test public void testDifferentQos() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); factory.setCleanSession(false); factory.setConnectionTimeout(23); factory.setKeepAliveInterval(45); factory.setPassword("pass"); MemoryPersistence persistence = new MemoryPersistence(); factory.setPersistence(persistence); final SocketFactory socketFactory = mock(SocketFactory.class); factory.setSocketFactory(socketFactory); final Properties props = new Properties(); factory.setSslProperties(props); factory.setUserName("user"); Will will = new Will("foo", "bar".getBytes(), 2, true); factory.setWill(will); factory = spy(factory); MqttAsyncClient aClient = mock(MqttAsyncClient.class); final MqttClient client = mock(MqttClient.class); willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); given(client.isConnected()).willReturn(true); new DirectFieldAccessor(client).setPropertyValue("aClient", aClient); willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class)); willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class)); willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any()); IMqttToken token = mock(IMqttToken.class); given(token.getGrantedQos()).willReturn(new int[] { 2, 0 }); willReturn(token).given(aClient).subscribe(any(String[].class), any(int[].class), any(), any()); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("foo", "bar", factory, "baz", "fix"); AtomicReference<Method> method = new AtomicReference<>(); ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { m.setAccessible(true); method.set(m); }, m -> m.getName().equals("connectAndSubscribe")); assertNotNull(method.get()); Log logger = spy(TestUtils.getPropertyValue(adapter, "logger", Log.class)); new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); given(logger.isWarnEnabled()).willReturn(true); method.get().invoke(adapter); verify(logger, atLeastOnce()) .warn("Granted QOS different to Requested QOS; topics: [baz, fix] requested: [1, 1] granted: [2, 0]"); verify(client).setTimeToWait(30_000L); } private MqttPahoMessageDrivenChannelAdapter buildAdapter(final IMqttClient client, Boolean cleanSession, ConsumerStopAction action) throws MqttException, MqttSecurityException { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() { @Override public IMqttClient getClientInstance(String uri, String clientId) throws MqttException { return client; } }; factory.setServerURIs("tcp://localhost:1883"); if (cleanSession != null) { factory.setCleanSession(cleanSession); } if (action != null) { factory.setConsumerStopAction(action); } given(client.isConnected()).willReturn(true); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("client", factory, "foo"); adapter.setApplicationEventPublisher(mock(ApplicationEventPublisher.class)); adapter.setOutputChannel(new NullChannel()); adapter.setTaskScheduler(mock(TaskScheduler.class)); adapter.afterPropertiesSet(); return adapter; } private void verifyUnsubscribe(IMqttClient client) throws Exception { verify(client).connect(any(MqttConnectOptions.class)); verify(client).subscribe(any(String[].class), any(int[].class)); verify(client).unsubscribe(any(String[].class)); verify(client).disconnectForcibly(anyLong()); } private void verifyNotUnsubscribe(IMqttClient client) throws Exception { verify(client).connect(any(MqttConnectOptions.class)); verify(client).subscribe(any(String[].class), any(int[].class)); verify(client, never()).unsubscribe(any(String[].class)); verify(client).disconnectForcibly(anyLong()); } @Configuration public static class Config { @Bean public MqttPahoMessageHandler handler() { MqttPahoMessageHandler handler = new MqttPahoMessageHandler("tcp://localhost:1883", "bar"); handler.setTopicExpressionString("@topic"); handler.setQosExpressionString("@qos"); handler.setRetainedExpressionString("@retained"); return handler; } @Bean public String topic() { return "fooTopic"; } @Bean public Integer qos() { return 1; } @Bean public Boolean retained() { return true; } @Bean public MqttPahoMessageHandler handlerWithNullExpressions() { MqttPahoMessageHandler handler = new MqttPahoMessageHandler("tcp://localhost:1883", "bar"); handler.setDefaultQos(1); handler.setQosExpressionString("null"); handler.setDefaultRetained(true); handler.setRetainedExpressionString("null"); return handler; } } }