package com.github.sdbg.debug.core.internal.forwarder;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.spi.SelectorProvider;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Base class for a selector-driven reverse port forwarder.
 * <p>
 * Keeps one command channel plus a set of numbered {@link Tunnel}s. Each tunnel pairs a "left"
 * and a "right" {@link ByteChannel}. Once both sides of a tunnel are registered, both channels
 * are registered with {@link #selector}. Ready keys for them are then handed to
 * {@link Tunnel#spool(SelectionKey)}.
 * <p>
 * The command protocol is supplied by subclasses overriding
 * {@link #processCommand(byte, ByteBuffer)}. Subclasses presumably also run the selector loop
 * and call {@link #processKey(SelectionKey)}, since this class never selects itself.
 * <p>
 * NOTE(review): instance state is unsynchronized; it appears to assume a single selector
 * thread per instance — confirm with subclasses.
 */
public abstract class ReversePortForwarder {
  // Command opcodes for the command channel. Only CMD_UNKNOWN_OR_UNEXPECTED_COMMAND is sent by
  // this class (see processCommand); the rest are presumably interpreted by subclasses.
  protected static final byte CMD_HELLO = (byte) 0;
  protected static final byte CMD_UNKNOWN_OR_UNEXPECTED_COMMAND = (byte) 1;
  protected static final byte CMD_OPEN_CHANNEL = (byte) 2;
  protected static final byte CMD_OPEN_CHANNEL_ACK = (byte) 3;
  protected static final byte CMD_OPEN_CHANNEL_FAIL = (byte) 4;

  /**
   * Minimum free space that must remain in {@link #commandReadBuffer} for OP_READ to stay
   * enabled on the command channel. Presumably the longest encoded command in bytes; confirm
   * against subclasses.
   */
  private static final int CMD_MAX_LENGTH = 10;

  /**
   * Source of IDs for {@link #createTunnel()}. The counter is static, so it is shared by all
   * instances and must be updated atomically.
   */
  private static final AtomicInteger nextTunnelId = new AtomicInteger();

  /** Selector for the command channel and tunnel channels; created by {@link #init()}. */
  protected Selector selector;

  /** Channel carrying protocol commands; presumably assigned and registered by subclasses. */
  protected ByteChannel commandChannel;

  /** Pending inbound (read) and outbound (write) command bytes; allocated by {@link #init()}. */
  protected ByteBuffer commandReadBuffer, commandWriteBuffer;

  /** Open tunnels by ID. */
  private final Map<Integer, Tunnel> tunnels = new HashMap<Integer, Tunnel>();

  /** Reverse index: each registered tunnel channel mapped to the ID of its tunnel. */
  private final Map<ByteChannel, Integer> channels = new HashMap<ByteChannel, Integer>();

  protected String tracePrefix;

  protected final Logger logger;

  /**
   * @param logger receives tunnel lifecycle and error messages
   */
  public ReversePortForwarder(Logger logger) {
    this.logger = logger;
  }

  /**
   * Removes the tunnel with the given ID, drops its channels from the channel index, and closes
   * it. Unknown IDs are ignored apart from a log message.
   *
   * @param tunnelId ID of the tunnel to close
   */
  protected void closeTunnel(int tunnelId) {
    Tunnel tunnel = tunnels.remove(tunnelId);
    if (tunnel == null) {
      // Nothing to close; don't claim a close that never happened.
      logger.fine("Tunnel " + tunnelId + " is unknown or already closed");
      return;
    }
    channels.remove(tunnel.getLeftChannel());
    channels.remove(tunnel.getRightChannel());
    tunnel.close();
    logger.fine("Tunnel " + tunnelId + " closed");
  }

  /**
   * Registers a new tunnel under a freshly allocated ID.
   *
   * @return the ID the new tunnel was registered under
   * @throws IOException if a tunnel with the allocated ID already exists in this instance
   */
  protected int createTunnel() throws IOException {
    // Reserve the ID exactly once and use that same value for both registration and the return
    // value. The counter is shared across instances, so reading it twice could return an ID
    // other than the one registered. Consuming the ID even when registration fails also keeps a
    // colliding ID from blocking every later call.
    int tunnelId = nextTunnelId.getAndIncrement();
    createTunnel(tunnelId);
    return tunnelId;
  }

  /**
   * Registers a new tunnel with no channels attached yet.
   *
   * @param tunnelId ID to register the tunnel under
   * @return the newly registered tunnel
   * @throws IOException if a tunnel with that ID is already registered
   */
  protected Tunnel createTunnel(int tunnelId) throws IOException {
    if (tunnels.containsKey(tunnelId)) {
      throw new IOException("Tunnel with ID " + tunnelId + " already exists");
    }
    Tunnel tunnel = new Tunnel(logger, Integer.toString(tunnelId));
    tunnels.put(tunnelId, tunnel);
    logger.fine("Tunnel " + tunnelId + " registered");
    return tunnel;
  }

  /**
   * Tears down state created by {@link #init()} and the tunnels. It closes every tunnel,
   * clears the tunnel and channel indexes, closes the selector on a best-effort basis, and
   * releases the command buffers. The command channel itself is not closed here.
   */
  protected void done() {
    for (Tunnel tunnel : tunnels.values()) {
      tunnel.close();
    }
    tunnels.clear();
    channels.clear();
    if (selector != null) {
      try {
        selector.close();
      } catch (IOException e) {
        // Best effort: shutdown continues regardless, but keep a trace of the failure.
        logger.log(Level.FINE, "Failed to close selector", e);
      }
      selector = null;
    }
    commandReadBuffer = commandWriteBuffer = null;
  }

  /**
   * @param tunnelId ID of the tunnel to look up
   * @return the registered tunnel
   * @throws IOException if no tunnel is registered under that ID
   */
  protected Tunnel getTunnel(int tunnelId) throws IOException {
    Tunnel tunnel = tunnels.get(tunnelId);
    if (tunnel != null) {
      return tunnel;
    } else {
      throw new IOException("Unknown tunnel: " + tunnelId);
    }
  }

  /**
   * Maps a selection key back to the tunnel that owns its channel.
   *
   * @param key a key whose channel was registered through register{Left,Right}Channel
   * @return the owning tunnel's ID
   * @throws IOException if the key's channel belongs to no tunnel
   */
  protected int getTunnelId(SelectionKey key) throws IOException {
    Integer tunnelId = channels.get(key.channel());
    if (tunnelId != null) {
      return tunnelId;
    } else {
      throw new IOException("Unknown channel: " + key.channel());
    }
  }

  /**
   * Opens a selector from the default provider and allocates 8192-byte command read and write
   * buffers.
   *
   * @throws IOException if the selector cannot be opened
   */
  protected void init() throws IOException {
    // Create a new selector
    selector = SelectorProvider.provider().openSelector();
    commandReadBuffer = ByteBuffer.allocate(8192);
    commandWriteBuffer = ByteBuffer.allocate(8192);
  }

  /**
   * Handles one inbound command. The default implementation replies with
   * {@link #CMD_UNKNOWN_OR_UNEXPECTED_COMMAND} followed by the offending command byte, and
   * treats the command as consumed.
   *
   * @param cmd the command byte already taken from the buffer
   * @param commandBuffer read view positioned just after {@code cmd}; overrides read any
   *     command arguments from it (unused by the default implementation)
   * @return {@code true} if the command was consumed, so the bytes read from
   *     {@code commandBuffer} are discarded. {@code false} leaves the read buffer untouched,
   *     presumably when the command is not yet complete.
   * @throws IOException if writing the reply fails
   */
  protected boolean processCommand(byte cmd, ByteBuffer commandBuffer) throws IOException {
    commandWriteBuffer.put(CMD_UNKNOWN_OR_UNEXPECTED_COMMAND);
    commandWriteBuffer.put(cmd);
    writeCommand();
    return true;
  }

  /**
   * Dispatches one ready selection key.
   * <p>
   * For the command channel, a readable key is handled by {@link #readCommand()}; otherwise a
   * writable key is handled by {@link #writeCommand()}. For tunnel channels, the owning
   * tunnel's {@code spool} is invoked, and the tunnel is closed if spooling returns
   * {@code false} or throws. A key whose channel belongs to no tunnel is logged as a protocol
   * error and not rethrown.
   *
   * @param key the ready key
   * @throws IOException from command channel reads or writes
   */
  protected void processKey(SelectionKey key) throws IOException {
    // Check what event is available and deal with it
    if (key.channel() == commandChannel) {
      if (key.isReadable()) {
        readCommand();
      } else if (key.isWritable()) {
        writeCommand();
      }
    } else if (key.isReadable() || key.isWritable()) {
      try {
        int tunnelId = getTunnelId(key);
        try {
          if (!getTunnel(tunnelId).spool(key)) {
            closeTunnel(tunnelId);
          }
        } catch (IOException e) {
          // A failing tunnel is isolated: close just that tunnel and keep going.
          logger.log(Level.INFO, "Spooling error for tunnel " + tunnelId + ": " + e.getMessage(), e);
          closeTunnel(tunnelId);
        }
      } catch (IOException e) {
        logger.log(Level.SEVERE, "PROTOCOL ERROR: " + e.getMessage(), e);
      }
    }
  }

  /**
   * Reads available bytes from the command channel and feeds complete commands to
   * {@link #processCommand(byte, ByteBuffer)}. It loops until no more bytes arrive and no
   * further command is consumed.
   * <p>
   * Afterwards, OP_READ interest is cleared when fewer than {@link #CMD_MAX_LENGTH} bytes of
   * buffer space remain, and set otherwise.
   *
   * @throws IOException if the remote end has closed the command channel, or on a read error
   */
  protected void readCommand() throws IOException {
    int read = 0;
    boolean commandProcessed = false;
    do {
      read = commandChannel.read(commandReadBuffer);
      if (read == -1) {
        // The remote end shut the command channel down. Nothing is closed here; the exception
        // leaves cleanup to the caller.
        throw new IOException("Command channel closed");
      }
      commandProcessed = false;
      if (commandReadBuffer.position() > 0) {
        // Work on a flipped duplicate so that, if processCommand declines the command, the
        // original buffer keeps its position and the bytes stay buffered.
        ByteBuffer readBuffer = (ByteBuffer) commandReadBuffer.duplicate().flip();
        byte cmd = readBuffer.get();
        if (processCommand(cmd, readBuffer)) {
          // Drop the consumed bytes and make the compacted view the live read buffer.
          readBuffer.compact();
          commandReadBuffer = readBuffer;
          commandProcessed = true;
        }
      }
    } while (read > 0 || commandProcessed);
    // NOTE(review): within this class, readCommand() is only reached through a readable key.
    // Once OP_READ is cleared below, nothing in this class sets it again; confirm that
    // subclasses handle that case.
    SelectionKey key = ((SelectableChannel) commandChannel).keyFor(selector);
    if (commandReadBuffer.remaining() < CMD_MAX_LENGTH) {
      key.interestOps(key.interestOps() & ~SelectionKey.OP_READ);
    } else {
      key.interestOps(key.interestOps() | SelectionKey.OP_READ);
    }
  }

  /**
   * Attaches the left channel to an existing tunnel. If the right channel is already present,
   * both channels are registered with the selector.
   *
   * @param tunnelId ID of an existing tunnel
   * @param leftChannel channel to attach; must also be a {@link SelectableChannel}
   * @throws IOException if the tunnel is unknown or already has a left channel
   */
  protected void registerLeftChannel(int tunnelId, ByteChannel leftChannel) throws IOException {
    Tunnel tunnel = getTunnel(tunnelId);
    if (tunnel.getLeftChannel() != null) {
      throw new IOException("Left channel of tunnel " + tunnelId + " is already registered");
    } else {
      logger.fine("Left channel of tunnel " + tunnelId + " registered: " + leftChannel);
      tunnel.setLeftChannel(leftChannel);
      channels.put(leftChannel, tunnelId);
      channelRegistered(tunnelId, tunnel);
    }
  }

  /**
   * Attaches the right channel to an existing tunnel. If the left channel is already present,
   * both channels are registered with the selector.
   *
   * @param tunnelId ID of an existing tunnel
   * @param rightChannel channel to attach; must also be a {@link SelectableChannel}
   * @throws IOException if the tunnel is unknown or already has a right channel
   */
  protected void registerRightChannel(int tunnelId, ByteChannel rightChannel) throws IOException {
    Tunnel tunnel = getTunnel(tunnelId);
    if (tunnel.getRightChannel() != null) {
      throw new IOException("Right channel of tunnel " + tunnelId + " is already registered");
    } else {
      logger.fine("Right channel of tunnel " + tunnelId + " registered: " + rightChannel);
      tunnel.setRightChannel(rightChannel);
      channels.put(rightChannel, tunnelId);
      channelRegistered(tunnelId, tunnel);
    }
  }

  /**
   * Flushes as much of {@link #commandWriteBuffer} as the command channel currently accepts.
   * OP_WRITE interest is kept while unsent bytes remain, and cleared once the buffer is empty.
   *
   * @throws IOException on a write error
   */
  protected void writeCommand() throws IOException {
    while (commandWriteBuffer.position() > 0) {
      commandWriteBuffer.flip();
      try {
        if (commandChannel.write(commandWriteBuffer) <= 0) {
          // Channel is full for now; wait for the next writable event.
          break;
        }
      } finally {
        // Always return the buffer to fill mode, keeping any unsent bytes at the front.
        commandWriteBuffer.compact();
      }
    }
    SelectionKey key = ((SelectableChannel) commandChannel).keyFor(selector);
    if (commandWriteBuffer.position() > 0) {
      key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
    } else {
      key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
    }
  }

  /**
   * Once both channels of {@code tunnel} are set, registers each with {@link #selector} for
   * read and write readiness. Does nothing while either side is still missing.
   *
   * @throws ClosedChannelException if either channel is already closed
   */
  private void channelRegistered(int tunnelId, Tunnel tunnel) throws ClosedChannelException {
    if (tunnel.getLeftChannel() != null && tunnel.getRightChannel() != null) {
      ((SelectableChannel) tunnel.getLeftChannel()).register(selector, SelectionKey.OP_READ
          | SelectionKey.OP_WRITE);
      ((SelectableChannel) tunnel.getRightChannel()).register(selector, SelectionKey.OP_READ
          | SelectionKey.OP_WRITE);
      logger.fine("Tunnel " + tunnelId + " ready for work");
    }
  }
}