package com.github.sdbg.debug.core.internal.forwarder; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; public class HostReversePortForwarder extends ReversePortForwarder { public static class Forward { private String host; private int port; private int devicePort; public Forward(String host, int port, int devicePort) { this.host = host; this.port = port; this.devicePort = devicePort; } public int getDevicePort() { return devicePort; } public String getHost() { return host; } public int getPort() { return port; } } private Map<Integer, Forward> forwards = new HashMap<Integer, Forward>(); private Object mainMonitor = new Object(); private boolean stopRequest; private Thread thread; public HostReversePortForwarder(Forward... forwards) { this(Arrays.asList(forwards)); } public HostReversePortForwarder(List<Forward> forwards) { super(Logger.getLogger(HostReversePortForwarder.class.getName())); for (Forward forward : forwards) { this.forwards.put(forward.getDevicePort(), forward); } } public void connect(String commandHost, int commandPort) throws IOException { // Create a new non-blocking socket channel commandChannel = SocketChannel.open(commandHost != null ? new InetSocketAddress( commandHost, commandPort) : new InetSocketAddress(commandPort)); try { ByteBuffer helloCmdBuff = ByteBuffer.allocate(1); int read = commandChannel.read(helloCmdBuff); if (read == -1 || helloCmdBuff.hasRemaining()) { throw new IOException("Unexpected"); } else { helloCmdBuff.flip(); if (helloCmdBuff.get() != CMD_HELLO) { throw new IOException("Unexpected"); } } } catch (IOException e) { try { commandChannel.close(); } catch (IOException e2) { } commandChannel = null; throw e; } ((SelectableChannel) commandChannel).configureBlocking(false); } public boolean isConnected() { return commandChannel != null; } public void start() throws IOException { if (thread != null) { throw new IOException("Already started"); } try { init(); } catch (IOException e) { done(); throw e; } thread = new Thread(new Runnable() { @Override public void run() { try { HostReversePortForwarder.this.run(); } catch (IOException e) { logger.log(Level.SEVERE, "PROTOCOL ERROR: " + e.getMessage(), e); } } }, "Host Reverse Port Forwarder"); thread.start(); } public void stop() { if (thread != null) { try { synchronized (mainMonitor) { if (!stopRequest) { stopRequest = true; selector.wakeup(); try { mainMonitor.wait(); } catch (InterruptedException e) { } } } try { thread.join(); } catch (InterruptedException e) { } thread = null; } finally { done(); } } else if (commandChannel != null) { try { commandChannel.close(); } catch (IOException e) { } commandChannel = null; } } @Override protected void done() { if (commandChannel != null) { try { commandChannel.close(); } catch (IOException e) { } commandChannel = null; } super.done(); forwards.clear(); } @Override protected void init() throws IOException { if (commandChannel == null) { throw new IOException("Command channel is closed"); } super.init(); ((SelectableChannel) commandChannel).register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE); } @Override protected boolean processCommand(byte cmd, ByteBuffer commandBuffer) throws IOException { if (cmd == CMD_OPEN_CHANNEL) { if (commandBuffer.remaining() >= 8) { int devicePort = commandBuffer.getInt(); Forward forward = forwards.get(devicePort); int tunnelId = commandBuffer.getInt(); Tunnel tunnel = createTunnel(tunnelId); try { registerLeftChannel(tunnelId, openChannel(forward.getHost(), forward.getPort())); registerRightChannel( tunnelId, openChannel("localhost", ((SocketChannel) commandChannel).socket().getPort())); // TODO XXX FIXME tunnel.getLeftToRight().put(CMD_OPEN_CHANNEL_ACK); tunnel.getLeftToRight().putInt(tunnelId); if (!tunnel.spoolLeftToRight(selector)) { closeTunnel(tunnelId); } } catch (IOException e) { logger.log(Level.INFO, "Spooling error: " + e.getMessage()); closeTunnel(tunnelId); commandWriteBuffer.put(CMD_OPEN_CHANNEL_FAIL); commandWriteBuffer.putInt(tunnelId); writeCommand(); } return true; } else { return false; } } else { return super.processCommand(cmd, commandBuffer); } } private ByteChannel openChannel(String host, int port) throws IOException { SocketChannel channel = SocketChannel.open(new InetSocketAddress(host, port)); channel.configureBlocking(false); return channel; } private void run() throws IOException { try { while (true) { // Wait for an event one of the registered channels selector.select(); synchronized (mainMonitor) { if (stopRequest) { break; } } for (Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator(); selectedKeys.hasNext();) { SelectionKey key = selectedKeys.next(); selectedKeys.remove(); if (key.isValid()) { processKey(key); } } } } finally { synchronized (mainMonitor) { stopRequest = true; mainMonitor.notify(); } } } }