package io.fathom.cloud.ssh;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketImpl;
import java.net.UnknownHostException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.net.InetAddresses;
/**
* A Socket implementation that uses SSH 'direct-tcpip' forwarding.
*
* By implementing Socket, we avoid opening a local port for forwarding.
*
*/
public class SshTunnelSocket extends Socket {
private static final Logger log = LoggerFactory.getLogger(SshTunnelSocket.class);
private final SshTunnelSocketImpl impl;
private SshTunnelSocket(SshTunnelSocketImpl impl) throws IOException {
super(impl);
this.impl = impl;
}
public SshTunnelSocket(SshContext sshContext, InetSocketAddress localSocketAddress) throws IOException {
this(new SshTunnelSocketImpl(sshContext, localSocketAddress));
}
public SshTunnelSocket(SshContext sshContext) throws IOException {
this(sshContext, generateLocalSocketAddress());
}
private static InetSocketAddress generateLocalSocketAddress() throws UnknownHostException {
int port = 30000;
InetAddress localAddress = InetAddress.getLocalHost();
return new InetSocketAddress(localAddress, port);
}
static class SshTunnelSocketImpl extends SocketImpl {
private final SshContext sshContext;
private final InetSocketAddress localSocketAddress;
private SshDirectTcpipChannel channel;
SshConfig sshConfig;
private int timeoutOption = 60000;
private TimeoutInputStream in;
public SshTunnelSocketImpl(SshContext sshContext, InetSocketAddress localSocketAddress) {
this.sshContext = sshContext;
this.localSocketAddress = localSocketAddress;
}
@Override
public void setOption(int opt, Object value) throws SocketException {
log.warn("Ignoring setOption {} {}", opt, value);
if (opt == SO_TIMEOUT) {
this.timeoutOption = (int) value;
if (in != null) {
in.setTimeout(timeoutOption);
}
}
}
@Override
public Object getOption(int opt) throws SocketException {
if (opt == SO_TIMEOUT) {
return new Integer(timeoutOption);
}
// int ret = 0;
// /*
// * The native socketGetOption() knows about 3 options.
// * The 32 bit value it returns will be interpreted according
// * to what we're asking. A return of -1 means it understands
// * the option but its turned off. It will raise a SocketException
// * if "opt" isn't one it understands.
// */
//
// switch (opt) {
// case TCP_NODELAY:
// ret = socketGetOption(opt, null);
// return Boolean.valueOf(ret != -1);
// case SO_OOBINLINE:
// ret = socketGetOption(opt, null);
// return Boolean.valueOf(ret != -1);
// case SO_LINGER:
// ret = socketGetOption(opt, null);
// return (ret == -1) ? Boolean.FALSE: (Object)(new Integer(ret));
// case SO_REUSEADDR:
// ret = socketGetOption(opt, null);
// return Boolean.valueOf(ret != -1);
// case SO_BINDADDR:
// InetAddressContainer in = new InetAddressContainer();
// ret = socketGetOption(opt, in);
// return in.addr;
// case SO_SNDBUF:
// case SO_RCVBUF:
// ret = socketGetOption(opt, null);
// return new Integer(ret);
// case IP_TOS:
// ret = socketGetOption(opt, null);
// if (ret == -1) { // ipv6 tos
// return new Integer(trafficClass);
// } else {
// return new Integer(ret);
// }
// case SO_KEEPALIVE:
// ret = socketGetOption(opt, null);
// return Boolean.valueOf(ret != -1);
// // should never get here
// default:
// return null;
// }
throw new UnsupportedOperationException();
}
@Override
protected void create(boolean stream) throws IOException {
}
@Override
protected void connect(String host, int port) throws IOException {
InetAddress remoteAddr = InetAddress.getByName(host);
connect(remoteAddr, port);
}
@Override
protected void connect(InetAddress address, int port) throws IOException {
InetSocketAddress remote = new InetSocketAddress(address, port);
connect(remote, 0);
}
@Override
protected void connect(SocketAddress address, int timeout) throws IOException {
InetSocketAddress httpAddress = (InetSocketAddress) address;
InetSocketAddress sshServer = sshContext.getRemoteSshAddress(httpAddress);
SshDirectTcpipChannel channel;
try {
if (this.sshConfig != null) {
throw new IllegalStateException();
}
this.sshConfig = sshContext.buildConfig(sshServer);
InetSocketAddress tunnelRemote = new InetSocketAddress(InetAddresses.forString("127.0.0.1"),
httpAddress.getPort());
channel = sshConfig.getDirectTcpipConnection(localSocketAddress, tunnelRemote);
} catch (Exception e) {
throw new IOException("Error connecting channel", e);
}
this.channel = channel;
}
@Override
protected void bind(InetAddress host, int port) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected void listen(int backlog) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected void accept(SocketImpl s) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected InputStream getInputStream() throws IOException {
synchronized (this) {
if (in == null) {
in = new TimeoutInputStream(channel.getInputStream());
in.setTimeout(timeoutOption);
}
return in;
}
}
@Override
protected OutputStream getOutputStream() throws IOException {
return channel.getOutputStream();
}
@Override
protected int available() throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected void close() throws IOException {
if (channel != null) {
channel.close();
}
}
@Override
protected void sendUrgentData(int data) throws IOException {
throw new UnsupportedOperationException();
}
}
}