/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.cluster.protocol.impl; import java.io.IOException; import java.io.InputStream; import java.net.Socket; import java.security.cert.CertificateException; import java.util.Collection; import java.util.Collections; import java.util.UUID; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import org.apache.nifi.cluster.protocol.NodeIdentifier; import org.apache.nifi.cluster.protocol.ProtocolContext; import org.apache.nifi.cluster.protocol.ProtocolException; import org.apache.nifi.cluster.protocol.ProtocolHandler; import org.apache.nifi.cluster.protocol.ProtocolListener; import org.apache.nifi.cluster.protocol.ProtocolMessageMarshaller; import org.apache.nifi.cluster.protocol.ProtocolMessageUnmarshaller; import org.apache.nifi.cluster.protocol.message.ConnectionRequestMessage; import org.apache.nifi.cluster.protocol.message.DisconnectMessage; import org.apache.nifi.cluster.protocol.message.FlowRequestMessage; import org.apache.nifi.cluster.protocol.message.HeartbeatMessage; import org.apache.nifi.cluster.protocol.message.ProtocolMessage; import org.apache.nifi.cluster.protocol.message.ReconnectionRequestMessage; import 
org.apache.nifi.events.BulletinFactory;
import org.apache.nifi.io.socket.ServerSocketConfiguration;
import org.apache.nifi.io.socket.SocketListener;
import org.apache.nifi.reporting.Bulletin;
import org.apache.nifi.reporting.BulletinRepository;
import org.apache.nifi.security.util.CertificateUtils;
import org.apache.nifi.util.StopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Implements a listener for protocol messages sent over a unicast socket.
 * <p>
 * Each socket passed to {@link #dispatchRequest(Socket)} is read as a single
 * {@link ProtocolMessage}, which is handed to the first registered
 * {@link ProtocolHandler} whose {@code canHandle} returns {@code true}. If that
 * handler returns a non-null response, the response is marshalled back onto the
 * same socket.
 * <p>
 * Handlers are kept in a {@link CopyOnWriteArrayList}, so they may be added or
 * removed while requests are being dispatched.
 */
public class SocketProtocolListener extends SocketListener implements ProtocolListener {

    private static final Logger logger = LoggerFactory.getLogger(SocketProtocolListener.class);

    /**
     * Maximum number of bytes of an incoming message that are buffered by the
     * {@code CopyingInputStream} (used for debug logging and the reported length).
     */
    private static final int MAX_MSG_BUFFER_BYTES = 1024 * 1024;

    private final ProtocolContext<ProtocolMessage> protocolContext;
    private final Collection<ProtocolHandler> handlers = new CopyOnWriteArrayList<>();
    private volatile BulletinRepository bulletinRepository;

    /**
     * Creates a listener for protocol messages.
     *
     * @param numThreads passed through to {@link SocketListener}
     * @param port the port to listen on, passed through to {@link SocketListener}
     * @param configuration the server socket configuration, passed through to {@link SocketListener}
     * @param protocolContext supplies the unmarshaller used to read requests and the
     *            marshaller used to write responses; must not be null
     * @throws IllegalArgumentException if {@code protocolContext} is null
     */
    public SocketProtocolListener(
            final int numThreads,
            final int port,
            final ServerSocketConfiguration configuration,
            final ProtocolContext<ProtocolMessage> protocolContext) {

        super(numThreads, port, configuration);

        if (protocolContext == null) {
            throw new IllegalArgumentException("Protocol Context may not be null.");
        }

        this.protocolContext = protocolContext;
    }

    /**
     * Sets the repository to which a warning bulletin is reported whenever a request
     * fails. May be null, in which case failures are only logged.
     *
     * @param bulletinRepository the repository to report failures to, or null
     */
    @Override
    public void setBulletinRepository(final BulletinRepository bulletinRepository) {
        this.bulletinRepository = bulletinRepository;
    }

    /**
     * Starts the underlying socket listener.
     *
     * @throws IllegalStateException if this listener is already running
     * @throws IOException propagated from {@link SocketListener#start()}
     */
    @Override
    public void start() throws IOException {
        if (super.isRunning()) {
            throw new IllegalStateException("Instance is already started.");
        }
        super.start();
    }

    /**
     * Stops the underlying socket listener.
     *
     * @throws IOException if this listener is not running, or as propagated from
     *             {@link SocketListener#stop()}
     */
    @Override
    public void stop() throws IOException {
        if (!super.isRunning()) {
            throw new IOException("Instance is already stopped.");
        }
        super.stop();
    }

    /**
     * @return an unmodifiable view of the registered handlers
     */
    @Override
    public Collection<ProtocolHandler> getHandlers() {
        return Collections.unmodifiableCollection(handlers);
    }

    /**
     * Registers a handler. Handlers are consulted in registration order.
     *
     * @param handler the handler to add
     * @throws NullPointerException if {@code handler} is null
     */
    @Override
    public void addHandler(final ProtocolHandler handler) {
        if (handler == null) {
            throw new NullPointerException("Protocol handler may not be null.");
        }
        handlers.add(handler);
    }

    /**
     * Unregisters a handler.
     *
     * @param handler the handler to remove
     * @return true if the handler was registered and has been removed
     */
    @Override
    public boolean removeHandler(final ProtocolHandler handler) {
        return handlers.remove(handler);
    }

    /**
     * Reads one protocol message from the given socket, dispatches it to the first
     * registered handler able to process it, and writes that handler's response (if
     * any) back to the socket.
     * <p>
     * {@link IOException}s and {@link ProtocolException}s raised while doing so are
     * logged as warnings and, when a bulletin repository is set, reported as a
     * "Clustering" bulletin, rather than being propagated to the caller.
     *
     * @param socket the socket from which to read the request
     */
    @Override
    public void dispatchRequest(final Socket socket) {
        String hostname = null;
        try {
            final StopWatch stopWatch = new StopWatch(true);
            hostname = socket.getInetAddress().getHostName();
            final String requestId = UUID.randomUUID().toString();
            logger.debug("Received request {} from {}", requestId, hostname);
            final String requestorDn = getRequestorDN(socket);

            // unmarshal the message, keeping a bounded copy of the raw bytes for logging
            final ProtocolMessageUnmarshaller<ProtocolMessage> unmarshaller = protocolContext.createUnmarshaller();
            final InputStream inStream = socket.getInputStream();
            final CopyingInputStream copyingInputStream = new CopyingInputStream(inStream, MAX_MSG_BUFFER_BYTES);
            final ProtocolMessage request;
            byte[] receivedMessage;
            try {
                request = unmarshaller.unmarshal(copyingInputStream);
            } finally {
                // runs even when unmarshalling fails, so the raw bytes can still be logged
                receivedMessage = copyingInputStream.getBytesRead();
                if (logger.isDebugEnabled()) {
                    // decode explicitly instead of using the JVM's platform-default charset;
                    // NOTE(review): assumes the protocol context marshals messages as UTF-8 — confirm
                    logger.debug("Received message: {}",
                            new String(receivedMessage, java.nio.charset.StandardCharsets.UTF_8));
                }
            }

            request.setRequestorDN(requestorDn);

            final Collection<ProtocolHandler> registeredHandlers = getHandlers();
            final ProtocolHandler desiredHandler = findHandler(request, registeredHandlers);
            if (desiredHandler == null) {
                logger.error("Received request of type {} but none of the following Protocol Handlers were able to process the request: {}",
                        request.getType(), registeredHandlers);
                throw new ProtocolException("No handler assigned to handle message type: " + request.getType());
            }

            final ProtocolMessage response = desiredHandler.handle(request);
            if (response != null) {
                sendResponse(socket, requestId, request, response);
            }

            stopWatch.stop();
            final NodeIdentifier nodeId = getNodeIdentifier(request);
            final String from = nodeId == null ? hostname : nodeId.toString();
            // length is the number of bytes captured by the copying stream, capped at MAX_MSG_BUFFER_BYTES
            logger.info("Finished processing request {} (type={}, length={} bytes) from {} in {} millis",
                    requestId, request.getType(), receivedMessage.length, from, stopWatch.getDuration(TimeUnit.MILLISECONDS));
        } catch (final IOException | ProtocolException e) {
            logger.warn("Failed processing protocol message from " + hostname + " due to " + e, e);
            reportFailureBulletin(hostname, e);
        }
    }

    /**
     * Returns the first handler, in iteration order, that reports it can handle the request.
     *
     * @param request the unmarshalled request
     * @param candidates the handlers to consult
     * @return the first capable handler, or null if none can handle the request
     */
    private static ProtocolHandler findHandler(final ProtocolMessage request, final Collection<ProtocolHandler> candidates) {
        for (final ProtocolHandler handler : candidates) {
            if (handler.canHandle(request)) {
                return handler;
            }
        }
        return null;
    }

    /**
     * Marshals {@code response} onto the socket's output stream.
     *
     * @param socket the socket to write to
     * @param requestId identifier of the request being answered (used only for logging)
     * @param request the request being answered (used only in the error message)
     * @param response the response to write
     * @throws ProtocolException wrapping any {@link IOException} raised while writing
     */
    private void sendResponse(final Socket socket, final String requestId, final ProtocolMessage request, final ProtocolMessage response) {
        try {
            logger.debug("Sending response for request {}", requestId);
            final ProtocolMessageMarshaller<ProtocolMessage> marshaller = protocolContext.createMarshaller();
            marshaller.marshal(response, socket.getOutputStream());
        } catch (final IOException ioe) {
            throw new ProtocolException("Failed marshalling protocol message in response to message type: "
                    + request.getType() + " due to " + ioe, ioe);
        }
    }

    /**
     * Adds a "Clustering" warning bulletin describing a failed request, if a bulletin
     * repository has been set.
     *
     * @param hostname the host the failed request came from (may be null)
     * @param cause the failure
     */
    private void reportFailureBulletin(final String hostname, final Exception cause) {
        // read the volatile field once so a concurrent set to null cannot cause an NPE
        final BulletinRepository repository = bulletinRepository;
        if (repository == null) {
            return;
        }
        final Bulletin bulletin = BulletinFactory.createBulletin("Clustering", "WARNING",
                String.format("Failed to process protocol message from %s due to: %s", hostname, cause.toString()));
        repository.addBulletin(bulletin);
    }

    /**
     * Extracts the identifier of the node that sent the message, for message types
     * that carry one.
     *
     * @param message the message, may be null
     * @return the sender's node identifier, or null if the message is null or its
     *         type does not carry one
     */
    private NodeIdentifier getNodeIdentifier(final ProtocolMessage message) {
        if (message == null) {
            return null;
        }

        switch (message.getType()) {
            case CONNECTION_REQUEST:
                return ((ConnectionRequestMessage) message).getConnectionRequest().getProposedNodeIdentifier();
            case HEARTBEAT:
                return ((HeartbeatMessage) message).getHeartbeat().getNodeIdentifier();
            case DISCONNECTION_REQUEST:
                return ((DisconnectMessage) message).getNodeId();
            case FLOW_REQUEST:
                return ((FlowRequestMessage) message).getNodeId();
            case RECONNECTION_REQUEST:
                return ((ReconnectionRequestMessage) message).getNodeId();
            default:
                return null;
        }
    }

    /**
     * Extracts the peer's distinguished name from the socket via
     * {@code CertificateUtils.extractPeerDNFromSSLSocket}.
     *
     * @param socket the socket the request arrived on
     * @return the value returned by {@code CertificateUtils.extractPeerDNFromSSLSocket}
     * @throws ProtocolException wrapping any {@link CertificateException}
     */
    private String getRequestorDN(final Socket socket) {
        try {
            return CertificateUtils.extractPeerDNFromSSLSocket(socket);
        } catch (final CertificateException e) {
            throw new ProtocolException(e);
        }
    }
}