package com.github.sdbg.debug.core.internal.forwarder; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.logging.ConsoleHandler; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogManager; import java.util.logging.Logger; public class DeviceReversePortForwarder extends ReversePortForwarder { public static void main(String[] args) throws IOException { for (Handler handler : LogManager.getLogManager().getLogger("").getHandlers()) { if (handler instanceof ConsoleHandler) { ConsoleHandler ch = (ConsoleHandler) handler; ch.setFormatter(new ShellFormatter()); break; } } int commandPort = Integer.parseInt(args[0]); int[] ports = new int[args.length - 1]; for (int i = 0; i < ports.length; i++) { ports[i] = Integer.parseInt(args[i + 1]); } new DeviceReversePortForwarder(commandPort, ports).run(); } private int commandPort; private int[] ports; private ServerSocketChannel commandServerChannel; private Collection<ServerSocketChannel> serverChannels = new HashSet<ServerSocketChannel>(); private Map<ByteChannel, ByteBuffer> pendingChannels = new HashMap<ByteChannel, ByteBuffer>(); public DeviceReversePortForwarder(int commandPort, int[] ports) { super(Logger.getLogger(DeviceReversePortForwarder.class.getName())); this.commandPort = commandPort; this.ports = ports; } public void run() throws IOException { init(); try { do { // Wait for an event on one of the registered channels if (commandChannel != null) { selector.select(); } else { logger.info("Waiting for command connection on port " + commandServerChannel.socket().getLocalPort() + "..."); // Command channel not established yet. Wait for 10 seconds and then timeout selector.select(10000); } for (Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator(); selectedKeys.hasNext();) { SelectionKey key = selectedKeys.next(); selectedKeys.remove(); if (key.isValid()) { processKey(key); } } } while (commandChannel != null); logger.info("Command connection timed out"); } finally { done(); } } @Override protected void done() { for (ByteChannel channel : new ArrayList<ByteChannel>(pendingChannels.keySet())) { close(channel); } pendingChannels.clear(); if (commandChannel != null) { try { commandChannel.close(); } catch (IOException e) { } commandChannel = null; } if (commandServerChannel != null) { try { commandServerChannel.close(); } catch (IOException e) { } commandServerChannel = null; } super.done(); for (ServerSocketChannel channel : new ArrayList<ServerSocketChannel>(serverChannels)) { close(channel); } } @Override protected void init() throws IOException { super.init(); // Create a new non-blocking server socket channel createCommandServerSocketChannel(); for (int port : ports) { createServerSocketChannel(port); } } @Override protected boolean processCommand(byte cmd, ByteBuffer commandBuffer) throws IOException { if (cmd == CMD_OPEN_CHANNEL_FAIL) { int tunnelId = commandBuffer.getInt(); closeTunnel(tunnelId); logger.info("Opening tunnel " + tunnelId + " failed"); return true; } else { return super.processCommand(cmd, commandBuffer); } } @Override protected void processKey(SelectionKey key) throws IOException { if (key.isAcceptable()) { if (key.channel() == commandServerChannel) { acceptCommand(key); } else { if (commandChannel == null) { throw new IOException("Unexpected"); } accept(key); } } else { if (commandChannel == null) { throw new IOException("Unexpected"); } if (key.isReadable() && pendingChannels.containsKey(key.channel())) { ByteChannel channel = (ByteChannel) key.channel(); try { ByteBuffer readBuffer = pendingChannels.get(channel); int read = channel.read(readBuffer); if (read == -1) { throw new IOException("Pending channel closed"); } if (readBuffer.position() >= 5) { readBuffer.flip(); try { byte cmd = readBuffer.get(); if (cmd == CMD_OPEN_CHANNEL_ACK) { int tunnelId = readBuffer.getInt(); try { registerLeftChannel(tunnelId, channel); pendingChannels.remove(channel); } catch (IOException e) { logger.log(Level.INFO, "Spooling error: " + e.getMessage(), e); closeTunnel(tunnelId); } } else { throw new IOException("Unknown command"); } } finally { readBuffer.compact(); } } } catch (IOException e) { logger.log(Level.SEVERE, "PROTOCOL ERROR: " + e.getMessage(), e); close(channel); } } else { super.processKey(key); } } } private void accept(SelectionKey key) throws IOException { // For an accept to be pending the channel must be a server socket channel. ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel(); int tunnelId = createTunnel(); logger.fine("New incoming connection to device port " + serverSocketChannel.socket().getLocalPort() + "; tunnel " + tunnelId); try { // Accept the connection and make it non-blocking SocketChannel socketChannel = serverSocketChannel.accept(); socketChannel.configureBlocking(false); registerRightChannel(tunnelId, socketChannel); } catch (IOException e) { logger.log(Level.SEVERE, "PROTOCOL ERROR: " + e.getMessage(), e); closeTunnel(tunnelId); } commandWriteBuffer.put(CMD_OPEN_CHANNEL); commandWriteBuffer.putInt(serverSocketChannel.socket().getLocalPort()); commandWriteBuffer.putInt(tunnelId); writeCommand(); } private void acceptCommand(SelectionKey key) throws IOException { // Accept the connection and make it non-blocking SocketChannel socketChannel = commandServerChannel.accept(); socketChannel.configureBlocking(false); if (commandChannel == null) { socketChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE); commandChannel = socketChannel; commandWriteBuffer.put(CMD_HELLO); writeCommand(); // Now that the command channel is opened we can start accepting connections from the other channels for (ServerSocketChannel channel : serverChannels) { channel.register(selector, SelectionKey.OP_ACCEPT); } logger.info("Command connection established"); } else { socketChannel.register(selector, SelectionKey.OP_READ); pendingChannels.put(socketChannel, ByteBuffer.allocate(5)); } } private void close(ByteChannel channel) { pendingChannels.remove(channel); try { channel.close(); } catch (IOException e) { // Best effort } } private void close(ServerSocketChannel channel) { try { channel.close(); } catch (IOException e) { // Best effort } serverChannels.remove(channel); } private void createCommandServerSocketChannel() throws IOException { ServerSocketChannel channel = ServerSocketChannel.open(); channel.configureBlocking(false); channel.socket().bind(new InetSocketAddress(commandPort)); channel.register(selector, SelectionKey.OP_ACCEPT); commandServerChannel = channel; } private void createServerSocketChannel(int port) throws IOException { ServerSocketChannel channel = ServerSocketChannel.open(); serverChannels.add(channel); channel.configureBlocking(false); channel.socket().bind(new InetSocketAddress(port)); logger.info("Opened reverse proxy port: " + port); } }