/** * Copyright 2014 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.comcast.viper.flume2storm.utility.forwarder; import java.io.IOException; import java.net.BindException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; import java.util.Iterator; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.MDC; /** * Simple implementation of the {@link TCPForwarder} based on Java NIO library * that link 1 input socket to 1 output socket. We synchronize outside * operations (start/stop/freeze/resume) in order to avoid potential armful * state changes. */ class TCPForwarderImpl implements TCPForwarder { protected static final Logger LOG = LoggerFactory.getLogger(TCPForwarder.class); private static final String MAIN_THREAD_NAME = "TCPForwarderAcceptThread"; private static final long TERMINATION_TIMEOUT = 10000; private static final int MAX_TCP_PACKET_SIZE = 64 * 1024; final TCPForwarderConfig configuration; final AtomicReference<MyState> state; MainThread mainThread; final AtomicInteger connectionDelay; final AtomicInteger clientSendDelay; final AtomicInteger serverSendDelay; enum MyState { STOPPED, STARTING, STARTED, FROZEN, STOPPING; } class MainThread extends Thread { final ExecutorService executor; private final ConcurrentLinkedQueue<SocketForwarder> socketForwarders; private Selector selector; private ServerSocketChannel serverSocketChannel; public MainThread() { setName(MAIN_THREAD_NAME); setDaemon(true); executor = Executors.newFixedThreadPool(configuration.getMaxWorkerThread()); socketForwarders = new ConcurrentLinkedQueue<SocketForwarder>(); } protected void wakeup() { if (selector != null) selector.wakeup(); } protected void onSocketForwarderTerminated(final SocketForwarder sf) { socketForwarders.remove(sf); } @Override public void run() { try { LOG.debug("Thread started"); open(configuration.getListenAddress(), configuration.getInputPort()); setState(MyState.STARTED); boolean terminate = false; while (!terminate) { switch (state.get()) { case FROZEN: try { Thread.sleep(50); } catch (final InterruptedException e) { LOG.debug("Interrupted"); } break; case STOPPING: case STOPPED: terminate = true; break; case STARTED: LOG.trace("Waiting for new connection...."); if (selector.select(100) > 0) { LOG.trace("Processing potential connection event..."); final Iterator<SelectionKey> i = selector.selectedKeys().iterator(); while (i.hasNext()) { final SelectionKey sk = i.next(); i.remove(); if (!sk.isValid()) { LOG.warn("Skipping invalid event"); } else if (sk.isAcceptable()) { handleConnection(sk); } } } break; case STARTING: assert false : "This is unlikely..."; break; default: throw new AssertionError("Forgetting state?"); } } LOG.debug("Exiting main thead loop..."); } catch (final BindException e) { LOG.error("Failed to bind socket: " + e.getLocalizedMessage(), e); } catch (final Exception e) { LOG.error("Processing exception: " + e.getMessage(), e); } finally { close(); setState(MyState.STOPPED); LOG.info("Thread terminated"); } } private void open(final String serverAddress, final int serverPort) throws IOException { final String logServerAddress = serverAddress == null ? "0.0.0.0" : serverAddress; LOG.debug("Opening server socket on {} TCP port {}...", logServerAddress, serverPort); selector = SelectorProvider.provider().openSelector(); serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); InetSocketAddress isa = null; if (serverAddress == null) { isa = new InetSocketAddress(serverPort); } else { isa = new InetSocketAddress(serverAddress, serverPort); } serverSocketChannel.socket().bind(isa); serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); LOG.info("Listening on {} TCP port {}", logServerAddress, serverPort); } private void close() { // Closing server socket channel if (serverSocketChannel != null) { try { LOG.debug("Closing server socket channel..."); serverSocketChannel.close(); } catch (final IOException e) { LOG.error("Failed to close server socket channel", e); } } if (selector != null) { try { LOG.debug("Closing server socket selector..."); selector.close(); } catch (final IOException e) { LOG.error("Failed to close server socket selector", e); } } // Waking up all the SocketForwarder. Each of them should terminate // and remove themselves from the list executor.shutdown(); if (!socketForwarders.isEmpty()) { LOG.debug("Shutting down SocketForwarders..."); for (final SocketForwarder sf : socketForwarders) { sf.wakeup(); } } LOG.debug("Waiting for executor termination..."); try { executor.awaitTermination(TERMINATION_TIMEOUT, TimeUnit.MILLISECONDS); } catch (final InterruptedException e) { // Nothing to see here } socketForwarders.clear(); LOG.debug("Executor terminated"); } @SuppressWarnings("resource") private void handleConnection(final SelectionKey sk) { // Delay while connecting final int d = connectionDelay.get(); if (d > 0) { LOG.trace("Waiting {} ms before accepting connection...", d); try { Thread.sleep(d); } catch (final InterruptedException e) { LOG.info("Connection wait interrupted"); } } // Received connection attempt. Trying to connect output server SocketChannel outputSocketChannel = null; try { LOG.debug("Received connection event, opening socket to output server {}:{}...", configuration.getOutputServer(), configuration.getOutputPort()); outputSocketChannel = SocketChannel.open(); outputSocketChannel.configureBlocking(true); final InetSocketAddress outputAddress = new InetSocketAddress(configuration.getOutputServer(), configuration.getOutputPort()); outputSocketChannel.connect(outputAddress); outputSocketChannel.configureBlocking(false); LOG.info("Connected output server: {}", getFullAddress(outputSocketChannel, Direction.OUTGOING)); } catch (final IOException e) { LOG.error("Failed to connect output server: " + e.getMessage(), e); if (outputSocketChannel != null) { LOG.debug("Closing output connection to server..."); try { outputSocketChannel.close(); } catch (final IOException e2) { LOG.warn("Failed to close output connection to server: " + e.getMessage(), e); } finally { outputSocketChannel = null; } } } // Accepting incoming connection SocketChannel inputSocketChannel = null; try { // Programming note: This needs to be done whether or not the // output connection has been established inputSocketChannel = serverSocketChannel.accept(); inputSocketChannel.configureBlocking(false); LOG.info("Accepted connection from {}", inputSocketChannel.socket().getRemoteSocketAddress()); } catch (final Exception e) { LOG.error("Failed to accept incoming connection: " + e.getMessage(), e); } // server socket failed - closing input connection if (outputSocketChannel == null && inputSocketChannel != null) { try { inputSocketChannel.close(); } catch (final IOException e) { LOG.error( "Failed to close input connection following output connection establishment failure:" + e.getMessage(), e); } } // Linking the 2 sockets if (inputSocketChannel != null && outputSocketChannel != null) { try { LOG.debug("Linking input {} and output {} channels...", getFullAddress(inputSocketChannel, Direction.INCOMING), getFullAddress(outputSocketChannel, Direction.OUTGOING)); final SocketForwarder inputSocketForwarder = new SocketForwarder(inputSocketChannel, Direction.INCOMING); socketForwarders.add(inputSocketForwarder); final SocketForwarder outputSocketForwarder = new SocketForwarder(outputSocketChannel, Direction.OUTGOING); socketForwarders.add(outputSocketForwarder); inputSocketForwarder.bindTo(outputSocketForwarder); outputSocketForwarder.bindTo(inputSocketForwarder); executor.execute(inputSocketForwarder); executor.execute(outputSocketForwarder); LOG.info("Linked input {} and output {} channels", getFullAddress(inputSocketChannel, Direction.INCOMING), getFullAddress(outputSocketChannel, Direction.OUTGOING)); } catch (final Exception e) { LOG.error("Failed to link input and output channels: " + e.getMessage(), e); } } } } private enum Direction { INCOMING, OUTGOING } protected static String getFullAddress(final SocketChannel channel, final Direction direction) { try { String origin, destination; if (direction == Direction.INCOMING) { origin = channel.socket().getRemoteSocketAddress().toString(); destination = channel.socket().getLocalSocketAddress().toString(); } else { origin = channel.socket().getLocalSocketAddress().toString(); destination = channel.socket().getRemoteSocketAddress().toString(); } return new StringBuilder("'").append(origin).append("->").append(destination).append("'").toString(); } catch (Exception e) { return "[Unknown]"; } } /** * This reads from a channel and forward what is received to the channel bound * to it */ class SocketForwarder implements Runnable { final Selector selector; final SocketChannel socketChannel; private final AtomicReference<SocketForwarder> boundTo; private final AtomicInteger delay; private final AtomicBoolean terminate; private final Direction direction; public SocketForwarder(final SocketChannel socketChannel, final Direction type) throws IOException { assert socketChannel != null; MDC.put("ip", getFullAddress(socketChannel, type)); selector = SelectorProvider.provider().openSelector(); this.socketChannel = socketChannel; this.delay = type == Direction.INCOMING ? clientSendDelay : serverSendDelay; boundTo = new AtomicReference<SocketForwarder>(); terminate = new AtomicBoolean(false); this.direction = type; } protected void wakeup() { selector.wakeup(); } protected void bindTo(final SocketForwarder other) { boundTo.set(other); } /** * @return The {@link SocketForwarder} is was bound to */ protected SocketForwarder unbind() { return boundTo.getAndSet(null); } /** * @see java.lang.Runnable#run() */ @Override public void run() { Thread.currentThread().setName(getFullAddress(socketChannel, direction)); try { socketChannel.register(this.selector, SelectionKey.OP_READ); while (!terminate.get()) { switch (state.get()) { case FROZEN: try { Thread.sleep(100); } catch (final InterruptedException e) { LOG.debug("Interrupted"); } break; case STOPPING: case STOPPED: terminate.set(true); break; case STARTED: LOG.trace("Waiting for socket event...."); int nbKeys = selector.select(50); if (nbKeys == 0) { LOG.trace("Forwarder has no keys... continuing"); } else { LOG.trace("Processing potential socket event..."); final Iterator<SelectionKey> i = selector.selectedKeys().iterator(); while (i.hasNext()) { final SelectionKey sk = i.next(); i.remove(); if (!sk.isValid()) { LOG.warn("Skipping invalid selection key"); } else if (sk.isReadable()) { doRead(sk); } } } break; case STARTING: assert false : "This is unlikely..."; break; default: throw new AssertionError("Forgetting state?"); } } } catch (final Exception e) { LOG.error("Processing exception: " + e.getMessage(), e); } finally { close(); LOG.info("Task terminated"); } } private void doRead(final SelectionKey sk) { try { final ByteBuffer readBuffer = ByteBuffer.allocate(MAX_TCP_PACKET_SIZE); if (socketChannel.read(readBuffer) < 0) { disconnect(); } else { readBuffer.flip(); if (boundTo != null) { boundTo.get().send(readBuffer, delay.get()); } } } catch (final Exception e) { LOG.error("Failed to read socket: " + e.getMessage(), e); } } protected void send(final ByteBuffer toSend, final int sendDelay) { mainThread.executor.execute(new Runnable() { @Override public void run() { Thread.currentThread().setName(socketChannel.socket().getRemoteSocketAddress() + "-sender"); try { // Delay while sending (i.e. forwarding data) if (sendDelay > 0) { Thread.sleep(sendDelay); } // Programming note: Taking a shortcut assuming the // socket is always ready to write to if (socketChannel.isOpen() && socketChannel.isConnected()) { socketChannel.write(toSend); selector.wakeup(); } } catch (final Exception e) { LOG.error("Failed to send data: " + e.getMessage(), e); } } }); } protected void disconnect() { terminate.set(true); // Closing socket LOG.debug("Socket {} disconnected. Closing channel", getFullAddress(socketChannel, direction)); try { socketChannel.close(); LOG.info("Channel for socket {} closed", socketChannel.socket().getRemoteSocketAddress()); } catch (final Exception e) { LOG.error("Failed to close channel for socket {}: {}", getFullAddress(socketChannel, direction), e.getMessage()); } } private void close() { LOG.debug("Closing..."); try { disconnect(); final SocketForwarder other = unbind(); if (other != null) { other.unbind(); other.disconnect(); } } catch (final Exception e) { LOG.error("Closing SocketForwarder failed: " + e.getMessage(), e); } try { selector.close(); } catch (final IOException e) { LOG.error("Closing selector failed: " + e.getMessage(), e); } mainThread.onSocketForwarderTerminated(this); } } /** * @param config * The {@link TCPForwarder} configuration */ public TCPForwarderImpl(final TCPForwarderConfig config) { configuration = TCPForwarderConfig.copyOf(config); state = new AtomicReference<MyState>(MyState.STOPPED); connectionDelay = new AtomicInteger(0); clientSendDelay = new AtomicInteger(0); serverSendDelay = new AtomicInteger(0); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#start() */ @Override public synchronized void start() { if (state.get() != MyState.STOPPED) { LOG.warn("Invalid state while trying to start: {}", state.get()); return; } LOG.debug("Starting..."); setState(MyState.STARTING); mainThread = new MainThread(); mainThread.start(); LOG.info("Start requested"); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#stop() */ @Override public synchronized void stop() { if (state.get() == MyState.STOPPED || state.get() == MyState.STOPPING) { LOG.warn("Invalid state while trying to stop: {}", state.get()); return; } LOG.debug("Stopping..."); setState(MyState.STOPPING); if (mainThread != null) { mainThread.wakeup(); try { mainThread.join(TERMINATION_TIMEOUT); } catch (final InterruptedException e) { LOG.warn("Interrupted while waiting for {} termination", MAIN_THREAD_NAME); } } setState(MyState.STOPPED); LOG.info("Stopped"); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#freeze() */ @Override public synchronized void freeze() { switch (state.get()) { case FROZEN: LOG.warn("Already frozen"); return; case STOPPED: case STOPPING: LOG.warn("Invalid state while trying to freeze: {}", state.get()); return; default: LOG.info("Freezing..."); setState(MyState.FROZEN); if (mainThread != null) { mainThread.wakeup(); } break; } } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#resume() */ @Override public synchronized void resume() { switch (state.get()) { case FROZEN: LOG.info("Resuming..."); setState(MyState.STARTED); return; default: LOG.warn("Invalid state while trying to resume: {}", state.get()); break; } } protected void setState(final MyState newState) { final MyState oldState = state.getAndSet(newState); LOG.debug("State change: {} -> {}", oldState, newState); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#isActive() */ @Override public boolean isActive() { return state.get() == MyState.STARTED; } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#isFrozen() */ @Override public boolean isFrozen() { return state.get() == MyState.FROZEN; } // // Delay management // /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#resetDelay() */ @Override public void resetDelay() { connectionDelay.set(0); clientSendDelay.set(0); serverSendDelay.set(0); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#getConnectionDelay() */ @Override public int getConnectionDelay() { return connectionDelay.get(); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#setConnectionDelay(int) */ @Override public TCPForwarder setConnectionDelay(final int delay) { assert delay >= 0 : "Invalid value while setting connection delay (must be positive)"; connectionDelay.set(delay); return this; } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#getClientSendDelay() */ @Override public int getClientSendDelay() { return clientSendDelay.get(); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#setClientSendDelay(int) */ @Override public TCPForwarder setClientSendDelay(final int delay) { assert delay >= 0 : "Invalid value while setting client send delay (must be positive)"; clientSendDelay.set(delay); return this; } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#getServerSendDelay() */ @Override public int getServerSendDelay() { return serverSendDelay.get(); } /** * @see com.comcast.viper.flume2storm.utility.forwarder.test.forwarder.TCPForwarder#setServerSendDelay(int) */ @Override public TCPForwarder setServerSendDelay(final int delay) { assert delay >= 0 : "Invalid value while setting server send delay (must be positive)"; serverSendDelay.set(delay); return this; } }