/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.processor.util.listen.dispatcher; import org.apache.commons.io.IOUtils; import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.processor.util.listen.event.Event; import org.apache.nifi.processor.util.listen.event.EventFactory; import org.apache.nifi.processor.util.listen.handler.ChannelHandlerFactory; import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel; import org.apache.nifi.security.util.SslContextFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.Channel; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.charset.Charset; import java.util.Iterator; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** * Accepts Socket connections on the given port and creates a handler for each connection to * be executed by a thread pool. */ public class SocketChannelDispatcher<E extends Event<SocketChannel>> implements AsyncChannelDispatcher { private final EventFactory<E> eventFactory; private final ChannelHandlerFactory<E, AsyncChannelDispatcher> handlerFactory; private final BlockingQueue<ByteBuffer> bufferPool; private final BlockingQueue<E> events; private final ComponentLog logger; private final int maxConnections; private final SSLContext sslContext; private final SslContextFactory.ClientAuth clientAuth; private final Charset charset; private ExecutorService executor; private volatile boolean stopped = false; private Selector selector; private final BlockingQueue<SelectionKey> keyQueue; private final AtomicInteger currentConnections = new AtomicInteger(0); public SocketChannelDispatcher(final EventFactory<E> eventFactory, final ChannelHandlerFactory<E, AsyncChannelDispatcher> handlerFactory, final BlockingQueue<ByteBuffer> bufferPool, final BlockingQueue<E> events, final ComponentLog logger, final int maxConnections, final SSLContext sslContext, final Charset charset) { this(eventFactory, handlerFactory, bufferPool, events, logger, maxConnections, sslContext, SslContextFactory.ClientAuth.REQUIRED, charset); } public SocketChannelDispatcher(final EventFactory<E> eventFactory, final ChannelHandlerFactory<E, AsyncChannelDispatcher> handlerFactory, final BlockingQueue<ByteBuffer> bufferPool, final BlockingQueue<E> events, final ComponentLog logger, final int maxConnections, final SSLContext sslContext, final SslContextFactory.ClientAuth clientAuth, final Charset charset) { this.eventFactory = eventFactory; this.handlerFactory = handlerFactory; this.bufferPool = bufferPool; this.events = events; this.logger = logger; this.maxConnections = maxConnections; this.keyQueue = new LinkedBlockingQueue<>(maxConnections); this.sslContext = sslContext; this.clientAuth = clientAuth; this.charset = charset; if (bufferPool == null || bufferPool.size() == 0 || bufferPool.size() != maxConnections) { throw new IllegalArgumentException( "A pool of available ByteBuffers equal to the maximum number of connections is required"); } } @Override public void open(final InetAddress nicAddress, final int port, final int maxBufferSize) throws IOException { stopped = false; executor = Executors.newFixedThreadPool(maxConnections); final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); if (maxBufferSize > 0) { serverSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, maxBufferSize); final int actualReceiveBufSize = serverSocketChannel.getOption(StandardSocketOptions.SO_RCVBUF); if (actualReceiveBufSize < maxBufferSize) { logger.warn("Attempted to set Socket Buffer Size to " + maxBufferSize + " bytes but could only set to " + actualReceiveBufSize + "bytes. You may want to consider changing the Operating System's " + "maximum receive buffer"); } } serverSocketChannel.socket().bind(new InetSocketAddress(nicAddress, port)); selector = Selector.open(); serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); } @Override public void run() { while (!stopped) { try { int selected = selector.select(); // if stopped the selector could already be closed which would result in a ClosedSelectorException if (selected > 0 && !stopped){ Iterator<SelectionKey> selectorKeys = selector.selectedKeys().iterator(); // if stopped we don't want to modify the keys because close() may still be in progress while (selectorKeys.hasNext() && !stopped) { SelectionKey key = selectorKeys.next(); selectorKeys.remove(); if (!key.isValid()){ continue; } if (key.isAcceptable()) { // Handle new connections coming in final ServerSocketChannel channel = (ServerSocketChannel) key.channel(); final SocketChannel socketChannel = channel.accept(); // Check for available connections if (currentConnections.incrementAndGet() > maxConnections){ currentConnections.decrementAndGet(); logger.warn("Rejecting connection from {} because max connections has been met", new Object[]{ socketChannel.getRemoteAddress().toString() }); IOUtils.closeQuietly(socketChannel); continue; } logger.debug("Accepted incoming connection from {}", new Object[]{socketChannel.getRemoteAddress().toString()}); // Set socket to non-blocking, and register with selector socketChannel.configureBlocking(false); SelectionKey readKey = socketChannel.register(selector, SelectionKey.OP_READ); // Prepare the byte buffer for the reads, clear it out ByteBuffer buffer = bufferPool.poll(); buffer.clear(); buffer.mark(); // If we have an SSLContext then create an SSLEngine for the channel SSLSocketChannel sslSocketChannel = null; if (sslContext != null) { final SSLEngine sslEngine = sslContext.createSSLEngine(); sslEngine.setUseClientMode(false); switch (clientAuth) { case REQUIRED: sslEngine.setNeedClientAuth(true); break; case WANT: sslEngine.setWantClientAuth(true); break; case NONE: sslEngine.setNeedClientAuth(false); sslEngine.setWantClientAuth(false); break; } sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel); } // Attach the buffer and SSLSocketChannel to the key SocketChannelAttachment attachment = new SocketChannelAttachment(buffer, sslSocketChannel); readKey.attach(attachment); } else if (key.isReadable()) { // Clear out the operations the select is interested in until done reading key.interestOps(0); // Create a handler based on the protocol and whether an SSLEngine was provided or not final Runnable handler; if (sslContext != null) { handler = handlerFactory.createSSLHandler(key, this, charset, eventFactory, events, logger); } else { handler = handlerFactory.createHandler(key, this, charset, eventFactory, events, logger); } // run the handler executor.execute(handler); } } } // Add back all idle sockets to the select SelectionKey key; while((key = keyQueue.poll()) != null){ key.interestOps(SelectionKey.OP_READ); } } catch (IOException e) { logger.error("Error accepting connection from SocketChannel", e); } } } @Override public int getPort() { // Return the port for the key listening for accepts for(SelectionKey key : selector.keys()){ if (key.isValid()) { final Channel channel = key.channel(); if (channel instanceof ServerSocketChannel) { return ((ServerSocketChannel)channel).socket().getLocalPort(); } } } return 0; } @Override public void close() { stopped = true; if (selector != null) { selector.wakeup(); } if (executor != null) { executor.shutdown(); try { // Wait a while for existing tasks to terminate if (!executor.awaitTermination(1000L, TimeUnit.MILLISECONDS)) { executor.shutdownNow(); } } catch (InterruptedException ie) { // (Re-)Cancel if current thread also interrupted executor.shutdownNow(); // Preserve interrupt status Thread.currentThread().interrupt(); } } if (selector != null) { synchronized (selector.keys()) { for (SelectionKey key : selector.keys()) { IOUtils.closeQuietly(key.channel()); } } } IOUtils.closeQuietly(selector); } @Override public void completeConnection(SelectionKey key) { // connection is done. Return the buffer to the pool SocketChannelAttachment attachment = (SocketChannelAttachment) key.attachment(); try { bufferPool.put(attachment.getByteBuffer()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } currentConnections.decrementAndGet(); } @Override public void addBackForSelection(SelectionKey key) { keyQueue.offer(key); selector.wakeup(); } }