/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.processors.websocket; import org.apache.nifi.processor.ProcessSessionFactory; import org.apache.nifi.processor.Processor; import org.apache.nifi.processor.Relationship; import org.apache.nifi.provenance.ProvenanceEventRecord; import org.apache.nifi.provenance.ProvenanceEventType; import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.MockProcessSession; import org.apache.nifi.util.SharedSessionState; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.apache.nifi.websocket.AbstractWebSocketSession; import org.apache.nifi.websocket.WebSocketMessage; import org.apache.nifi.websocket.WebSocketServerService; import org.apache.nifi.websocket.WebSocketSession; import org.junit.Test; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class TestListenWebSocket { protected void assertFlowFile(WebSocketSession webSocketSession, String serviceId, String endpointId, MockFlowFile ff, WebSocketMessage.Type messageType) { assertEquals(serviceId, ff.getAttribute(ATTR_WS_CS_ID)); assertEquals(webSocketSession.getSessionId(), ff.getAttribute(ATTR_WS_SESSION_ID)); assertEquals(endpointId, ff.getAttribute(ATTR_WS_ENDPOINT_ID)); assertEquals(webSocketSession.getLocalAddress().toString(), ff.getAttribute(ATTR_WS_LOCAL_ADDRESS)); assertEquals(webSocketSession.getRemoteAddress().toString(), ff.getAttribute(ATTR_WS_REMOTE_ADDRESS)); assertEquals(messageType != null ? messageType.name() : null, ff.getAttribute(ATTR_WS_MESSAGE_TYPE)); } protected Map<Relationship, List<MockFlowFile>> getAllTransferredFlowFiles(final Collection<MockProcessSession> processSessions, final Processor processor) { final Map<Relationship, List<MockFlowFile>> flowFiles = new HashMap<>(); processSessions.forEach(session -> { processor.getRelationships().forEach(rel -> { List<MockFlowFile> relFlowFiles = flowFiles.get(rel); if (relFlowFiles == null) { relFlowFiles = new ArrayList<>(); flowFiles.put(rel, relFlowFiles); } relFlowFiles.addAll(session.getFlowFilesForRelationship(rel)); }); }); return flowFiles; } @Test public void testValidationError() throws Exception { final TestRunner runner = TestRunners.newTestRunner(ListenWebSocket.class); final WebSocketServerService service = mock(WebSocketServerService.class); final String serviceId = "ws-service"; final String endpointId = "test"; when(service.getIdentifier()).thenReturn(serviceId); runner.addControllerService(serviceId, service); runner.enableControllerService(service); runner.setProperty(ListenWebSocket.PROP_WEBSOCKET_SERVER_SERVICE, serviceId); runner.setProperty(ListenWebSocket.PROP_SERVER_URL_PATH, endpointId); try { runner.run(); fail("Should fail with validation error."); } catch (AssertionError e) { assertTrue(e.toString().contains("'server-url-path' is invalid because Must starts with")); } } @Test public void testSuccess() throws Exception { final TestRunner runner = TestRunners.newTestRunner(ListenWebSocket.class); final ListenWebSocket processor = (ListenWebSocket)runner.getProcessor(); final SharedSessionState sharedSessionState = new SharedSessionState(processor, new AtomicLong(0)); // Use this custom session factory implementation so that createdSessions can be read from test case, // because MockSessionFactory doesn't expose it. final Set<MockProcessSession> createdSessions = new HashSet<>(); final ProcessSessionFactory sessionFactory = () -> { final MockProcessSession session = new MockProcessSession(sharedSessionState, processor); createdSessions.add(session); return session; }; final WebSocketServerService service = mock(WebSocketServerService.class); final WebSocketSession webSocketSession = spy(AbstractWebSocketSession.class); when(webSocketSession.getSessionId()).thenReturn("ws-session-id"); when(webSocketSession.getLocalAddress()).thenReturn(new InetSocketAddress("localhost", 12345)); when(webSocketSession.getRemoteAddress()).thenReturn(new InetSocketAddress("example.com", 80)); final String serviceId = "ws-service"; final String endpointId = "/test"; final String textMessageReceived = "message from server."; final AtomicReference<Boolean> registered = new AtomicReference<>(false); when(service.getIdentifier()).thenReturn(serviceId); doAnswer(invocation -> { registered.set(true); processor.connected(webSocketSession); // Two times. processor.consume(webSocketSession, textMessageReceived); processor.consume(webSocketSession, textMessageReceived); // Three times. final byte[] binaryMessage = textMessageReceived.getBytes(); processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); return null; }).when(service).registerProcessor(endpointId, processor); doAnswer(invocation -> registered.get()) .when(service).isProcessorRegistered(eq(endpointId), eq(processor)); doAnswer(invocation -> { registered.set(false); return null; }).when(service).deregisterProcessor(eq(endpointId), eq(processor)); runner.addControllerService(serviceId, service); runner.enableControllerService(service); runner.setProperty(ListenWebSocket.PROP_WEBSOCKET_SERVER_SERVICE, serviceId); runner.setProperty(ListenWebSocket.PROP_SERVER_URL_PATH, endpointId); processor.onTrigger(runner.getProcessContext(), sessionFactory); Map<Relationship, List<MockFlowFile>> transferredFlowFiles = getAllTransferredFlowFiles(createdSessions, processor); List<MockFlowFile> connectedFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_CONNECTED); assertEquals(1, connectedFlowFiles.size()); connectedFlowFiles.forEach(ff -> { assertFlowFile(webSocketSession, serviceId, endpointId, ff, null); }); List<MockFlowFile> textFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_MESSAGE_TEXT); assertEquals(2, textFlowFiles.size()); textFlowFiles.forEach(ff -> { assertFlowFile(webSocketSession, serviceId, endpointId, ff, WebSocketMessage.Type.TEXT); }); List<MockFlowFile> binaryFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_MESSAGE_BINARY); assertEquals(3, binaryFlowFiles.size()); binaryFlowFiles.forEach(ff -> { assertFlowFile(webSocketSession, serviceId, endpointId, ff, WebSocketMessage.Type.BINARY); }); final List<ProvenanceEventRecord> provenanceEvents = sharedSessionState.getProvenanceEvents(); assertEquals(6, provenanceEvents.size()); assertTrue(provenanceEvents.stream().allMatch(event -> ProvenanceEventType.RECEIVE.equals(event.getEventType()))); runner.clearTransferState(); runner.clearProvenanceEvents(); createdSessions.clear(); assertEquals(0, createdSessions.size()); // Simulate that the processor has started, and it get's triggered again processor.onTrigger(runner.getProcessContext(), sessionFactory); assertEquals("No session should be created", 0, createdSessions.size()); // Simulate that the processor is stopped. processor.onStopped(runner.getProcessContext()); assertEquals("No session should be created", 0, createdSessions.size()); // Simulate that the processor is restarted. // And the mock service will emit consume msg events. processor.onTrigger(runner.getProcessContext(), sessionFactory); assertEquals("Processor should register it with the service again", 6, createdSessions.size()); } }