/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.cluster.protocol; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; import org.apache.nifi.cluster.protocol.message.ClusterWorkloadRequestMessage; import org.apache.nifi.cluster.protocol.message.ClusterWorkloadResponseMessage; import org.apache.nifi.cluster.protocol.message.ConnectionRequestMessage; import org.apache.nifi.cluster.protocol.message.ConnectionResponseMessage; import org.apache.nifi.cluster.protocol.message.HeartbeatMessage; import org.apache.nifi.cluster.protocol.message.HeartbeatResponseMessage; import org.apache.nifi.cluster.protocol.message.ProtocolMessage; import org.apache.nifi.cluster.protocol.message.ProtocolMessage.MessageType; import org.apache.nifi.io.socket.SocketConfiguration; import org.apache.nifi.io.socket.SocketUtils; public abstract class AbstractNodeProtocolSender implements NodeProtocolSender { private final SocketConfiguration socketConfiguration; private final ProtocolContext<ProtocolMessage> protocolContext; public AbstractNodeProtocolSender(final SocketConfiguration socketConfiguration, final ProtocolContext<ProtocolMessage> protocolContext) { this.socketConfiguration = socketConfiguration; this.protocolContext = protocolContext; } @Override public ConnectionResponseMessage requestConnection(final ConnectionRequestMessage msg) throws ProtocolException, UnknownServiceAddressException { Socket socket = null; try { socket = createSocket(); try { // marshal message to output stream final ProtocolMessageMarshaller<ProtocolMessage> marshaller = protocolContext.createMarshaller(); marshaller.marshal(msg, socket.getOutputStream()); } catch (final IOException ioe) { throw new ProtocolException("Failed marshalling '" + msg.getType() + "' protocol message due to: " + ioe, ioe); } final ProtocolMessage response; try { // unmarshall response and return final ProtocolMessageUnmarshaller<ProtocolMessage> unmarshaller = protocolContext.createUnmarshaller(); response = unmarshaller.unmarshal(socket.getInputStream()); } catch (final IOException ioe) { throw new ProtocolException("Failed unmarshalling '" + MessageType.CONNECTION_RESPONSE + "' protocol message from " + socket.getRemoteSocketAddress() + " due to: " + ioe, ioe); } if (MessageType.CONNECTION_RESPONSE == response.getType()) { final ConnectionResponseMessage connectionResponse = (ConnectionResponseMessage) response; return connectionResponse; } else { throw new ProtocolException("Expected message type '" + MessageType.CONNECTION_RESPONSE + "' but found '" + response.getType() + "'"); } } finally { SocketUtils.closeQuietly(socket); } } @Override public HeartbeatResponseMessage heartbeat(final HeartbeatMessage msg, final String address) throws ProtocolException { final String hostname; final int port; try { final String[] parts = address.split(":"); hostname = parts[0]; port = Integer.parseInt(parts[1]); } catch (final Exception e) { throw new IllegalArgumentException("Cannot send heartbeat to address [" + address + "]. Address must be in <hostname>:<port> format"); } final ProtocolMessage responseMessage = sendProtocolMessage(msg, hostname, port); if (MessageType.HEARTBEAT_RESPONSE == responseMessage.getType()) { return (HeartbeatResponseMessage) responseMessage; } throw new ProtocolException("Expected message type '" + MessageType.HEARTBEAT_RESPONSE + "' but found '" + responseMessage.getType() + "'"); } @Override public ClusterWorkloadResponseMessage clusterWorkload(final ClusterWorkloadRequestMessage msg) throws ProtocolException { final InetSocketAddress serviceAddress; try { serviceAddress = getServiceAddress(); } catch (IOException e) { throw new ProtocolException("Failed to getServiceAddress due to " + e, e); } final ProtocolMessage responseMessage = sendProtocolMessage(msg, serviceAddress.getHostName(), serviceAddress.getPort()); if (MessageType.CLUSTER_WORKLOAD_RESPONSE == responseMessage.getType()) { return (ClusterWorkloadResponseMessage) responseMessage; } throw new ProtocolException("Expected message type '" + MessageType.CLUSTER_WORKLOAD_RESPONSE + "' but found '" + responseMessage.getType() + "'"); } private Socket createSocket() { InetSocketAddress socketAddress = null; try { // create a socket socketAddress = getServiceAddress(); return SocketUtils.createSocket(socketAddress, socketConfiguration); } catch (final IOException ioe) { if (socketAddress == null) { throw new ProtocolException("Failed to create socket due to: " + ioe, ioe); } else { throw new ProtocolException("Failed to create socket to " + socketAddress + " due to: " + ioe, ioe); } } } public SocketConfiguration getSocketConfiguration() { return socketConfiguration; } private ProtocolMessage sendProtocolMessage(final ProtocolMessage msg, final String hostname, final int port) { Socket socket = null; try { try { socket = SocketUtils.createSocket(new InetSocketAddress(hostname, port), socketConfiguration); } catch (IOException e) { throw new ProtocolException("Failed to send message to Cluster Coordinator due to: " + e, e); } try { // marshal message to output stream final ProtocolMessageMarshaller<ProtocolMessage> marshaller = protocolContext.createMarshaller(); marshaller.marshal(msg, socket.getOutputStream()); } catch (final IOException ioe) { throw new ProtocolException("Failed marshalling '" + msg.getType() + "' protocol message due to: " + ioe, ioe); } final ProtocolMessage response; try { // unmarshall response and return final ProtocolMessageUnmarshaller<ProtocolMessage> unmarshaller = protocolContext.createUnmarshaller(); response = unmarshaller.unmarshal(socket.getInputStream()); } catch (final IOException ioe) { throw new ProtocolException("Failed unmarshalling '" + MessageType.CONNECTION_RESPONSE + "' protocol message from " + socket.getRemoteSocketAddress() + " due to: " + ioe, ioe); } return response; } finally { SocketUtils.closeQuietly(socket); } } protected abstract InetSocketAddress getServiceAddress() throws IOException; }