/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.remote.protocol.socket; import org.apache.nifi.remote.Peer; import org.apache.nifi.remote.PeerDescription; import org.apache.nifi.remote.StandardVersionNegotiator; import org.apache.nifi.remote.cluster.ClusterNodeInformation; import org.apache.nifi.remote.cluster.NodeInformation; import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession; import org.apache.nifi.remote.io.socket.SocketChannelInput; import org.apache.nifi.remote.io.socket.SocketChannelOutput; import org.apache.nifi.remote.protocol.HandshakeProperties; import org.apache.nifi.remote.protocol.HandshakeProperty; import org.apache.nifi.remote.protocol.Response; import org.apache.nifi.remote.protocol.ResponseCode; import org.apache.nifi.util.NiFiProperties; import org.junit.BeforeClass; import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class TestSocketFlowFileServerProtocol { @BeforeClass public static void setup() throws Exception { System.setProperty(NiFiProperties.PROPERTIES_FILE_PATH, "src/test/resources/nifi.properties"); System.setProperty("org.slf4j.simpleLogger.log.org.apache.nifi.remote", "DEBUG"); } private Peer getDefaultPeer(final HandshakeProperties handshakeProperties, final OutputStream outputStream) throws IOException { final PeerDescription description = new PeerDescription("peer-host", 8080, false); final byte[] inputBytes; try (final ByteArrayOutputStream bos = new ByteArrayOutputStream(); final DataOutputStream dos = new DataOutputStream(bos)) { dos.writeUTF(handshakeProperties.getCommsIdentifier()); dos.writeUTF(handshakeProperties.getTransitUriPrefix()); dos.writeInt(1); // num of properties dos.writeUTF(HandshakeProperty.GZIP.name()); dos.writeUTF(String.valueOf(handshakeProperties.isUseGzip())); dos.flush(); inputBytes = bos.toByteArray(); } final InputStream inputStream = new ByteArrayInputStream(inputBytes); final SocketChannelCommunicationsSession commsSession = mock(SocketChannelCommunicationsSession.class); final SocketChannelInput channelInput = mock(SocketChannelInput.class); final SocketChannelOutput channelOutput = mock(SocketChannelOutput.class); when(commsSession.getInput()).thenReturn(channelInput); when(commsSession.getOutput()).thenReturn(channelOutput); when(channelInput.getInputStream()).thenReturn(inputStream); when(channelOutput.getOutputStream()).thenReturn(outputStream); final String peerUrl = "http://peer-host:8080/"; final String clusterUrl = "cluster-url"; return new Peer(description, commsSession, peerUrl, clusterUrl); } private SocketFlowFileServerProtocol getDefaultSocketFlowFileServerProtocol() { final StandardVersionNegotiator versionNegotiator = new StandardVersionNegotiator(5, 4, 3, 2, 1); final SocketFlowFileServerProtocol protocol = spy(new SocketFlowFileServerProtocol()); return protocol; } @Test public void testSendPeerListStandalone() throws Exception { final SocketFlowFileServerProtocol protocol = getDefaultSocketFlowFileServerProtocol(); final Optional<ClusterNodeInformation> clusterNodeInfo = Optional.empty(); final String siteToSiteHostname = "node1.example.com"; final Integer siteToSitePort = 8081; final Integer siteToSiteHttpPort = null; final int apiPort = 8080; final boolean isSiteToSiteSecure = true; final int numOfQueuedFlowFiles = 100; final NodeInformation self = new NodeInformation(siteToSiteHostname, siteToSitePort, siteToSiteHttpPort, apiPort, isSiteToSiteSecure, numOfQueuedFlowFiles); final HandshakeProperties handshakeProperties = new HandshakeProperties(); handshakeProperties.setCommsIdentifier("communication-identifier"); handshakeProperties.setTransitUriPrefix("uri-prefix"); final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final Peer peer = getDefaultPeer(handshakeProperties, outputStream); protocol.handshake(peer); protocol.sendPeerList(peer, clusterNodeInfo, self); try (final DataInputStream dis = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()))) { final Response handshakeResponse = Response.read(dis); assertEquals(ResponseCode.PROPERTIES_OK, handshakeResponse.getCode()); final int numPeers = dis.readInt(); assertEquals(1, numPeers); assertEquals(siteToSiteHostname, dis.readUTF()); assertEquals(siteToSitePort.intValue(), dis.readInt()); assertEquals(isSiteToSiteSecure, dis.readBoolean()); assertEquals(numOfQueuedFlowFiles, dis.readInt()); } } @Test public void testSendPeerListCluster() throws Exception { final SocketFlowFileServerProtocol protocol = getDefaultSocketFlowFileServerProtocol(); final List<NodeInformation> nodeInfoList = new ArrayList<>(); final ClusterNodeInformation clusterNodeInformation = new ClusterNodeInformation(); clusterNodeInformation.setNodeInformation(nodeInfoList); final Optional<ClusterNodeInformation> clusterNodeInfo = Optional.of(clusterNodeInformation); for (int i = 0; i < 3; i++) { final String siteToSiteHostname = String.format("node%d.example.com", i); final Integer siteToSitePort = 8081; final Integer siteToSiteHttpPort = null; final int apiPort = 8080; final boolean isSiteToSiteSecure = true; final int numOfQueuedFlowFiles = 100 + i; final NodeInformation nodeInformation = new NodeInformation(siteToSiteHostname, siteToSitePort, siteToSiteHttpPort, apiPort, isSiteToSiteSecure, numOfQueuedFlowFiles); nodeInfoList.add(nodeInformation); } final NodeInformation self = nodeInfoList.get(0); final HandshakeProperties handshakeProperties = new HandshakeProperties(); handshakeProperties.setCommsIdentifier("communication-identifier"); handshakeProperties.setTransitUriPrefix("uri-prefix"); final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final Peer peer = getDefaultPeer(handshakeProperties, outputStream); protocol.handshake(peer); protocol.sendPeerList(peer, clusterNodeInfo, self); try (final DataInputStream dis = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()))) { final Response handshakeResponse = Response.read(dis); assertEquals(ResponseCode.PROPERTIES_OK, handshakeResponse.getCode()); final int numPeers = dis.readInt(); assertEquals(nodeInfoList.size(), numPeers); for (int i = 0; i < nodeInfoList.size(); i++) { final NodeInformation node = nodeInfoList.get(i); assertEquals(node.getSiteToSiteHostname(), dis.readUTF()); assertEquals(node.getSiteToSitePort().intValue(), dis.readInt()); assertEquals(node.isSiteToSiteSecure(), dis.readBoolean()); assertEquals(node.getTotalFlowFiles(), dis.readInt()); } } } }