package io.fathom.cloud.ssh.jsch; import io.fathom.cloud.ssh.SftpChannel; import io.fathom.cloud.ssh.SftpChannelBase; import io.fathom.cloud.ssh.SftpStat; import io.fathom.cloud.ssh.SshConfig; import io.fathom.cloud.ssh.SshContext; import io.fathom.cloud.ssh.SshDirectTcpipChannel; import io.fathom.cloud.ssh.SshForwardChannel; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.security.KeyPair; import java.security.PublicKey; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fathomdb.crypto.bouncycastle.KeyPairs; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.net.InetAddresses; import com.jcraft.jsch.ChannelDirectTCPIP; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelSftp; import com.jcraft.jsch.ChannelSftp.LsEntry; import com.jcraft.jsch.ChannelSftp.LsEntrySelector; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; import com.jcraft.jsch.SftpATTRS; import com.jcraft.jsch.SftpException; public class SshContextImpl implements SshContext { private static final Logger log = LoggerFactory.getLogger(SshContextImpl.class); final JSch jsch; final String sshUsername; final KeyPair sshKeyPair; public SshContextImpl(String sshUsername, File privateKeyPath) { super(); this.jsch = new JSch(); this.sshUsername = sshUsername; // this.sshKey = sshKey; try { this.jsch.addIdentity(privateKeyPath.getAbsolutePath()); } catch (JSchException e) { throw new IllegalArgumentException("Error loading ssh key", e); } try { this.sshKeyPair = KeyPairs.fromPem(privateKeyPath); } catch (IOException e) { throw new IllegalArgumentException("Error loading ssh key", e); } } class ConnectionState implements SshConfig { final InetSocketAddress remote; Session session; int channelCount; @Override public String toString() { return remote.toString(); } public ConnectionState(InetSocketAddress remote) { super(); this.remote = remote; } String getInfo() { return sshUsername + "@" + remote; } synchronized Session getSession() throws IOException { if (session == null) { Session session; try { session = jsch.getSession(sshUsername, InetAddresses.toAddrString(remote.getAddress()), remote.getPort()); // TODO: Strict host key checking session.setConfig("StrictHostKeyChecking", "no"); // session.setPassword("super_secre_password"); session.connect(); } catch (JSchException e) { session = null; throw new IOException("Error connecting to SSH (" + getInfo() + ")", e); } // ConnectFuture connectFuture; // try { // connectFuture = sshClient.connect(remote); // } catch (Exception e) { // throw new IOException( // "Error connecting to SSH server: " + remote, e); // } // long connectTimeoutMillis = 60000; // if (!connectFuture.await(connectTimeoutMillis)) { // connectFuture.cancel(); // throw new IOException( // "Timeout connecting to SSH server: " + remote); // } // // ClientSession clientSession = connectFuture.getSession(); // // int ret = ClientSession.WAIT_AUTH; // while ((ret & ClientSession.WAIT_AUTH) != 0) { // int authTimeout = 30000; // // clientSession.authPublicKey(sshUsername, sshKey); // ret = clientSession.waitFor(ClientSession.WAIT_AUTH // | ClientSession.CLOSED | ClientSession.AUTHED, // authTimeout); // } // // if ((ret & ClientSession.CLOSED) != 0) { // throw new IOException( // "Unable to authenticate with SSH server: " // + remote); // } this.session = session; // this.session = new PooledClientSession(this, session); } // useCount++; return this.session; } @Override public synchronized SftpChannel getSftpChannel() throws IOException { Session session = getSession(); ChannelSftp sftpChannel; try { sftpChannel = (ChannelSftp) session.openChannel("sftp"); channelCount++; sftpChannel.connect(); } catch (JSchException e) { // TODO: Close session if it's failed?? channelCount--; throw new IOException("Error opening sftp channel (" + getInfo() + ")", e); } return new JschSftpChannel(sftpChannel); } public class JschSftpChannel extends SftpChannelBase { private final ChannelSftp channel; public JschSftpChannel(ChannelSftp channel) { this.channel = channel; } @Override public synchronized void close() throws IOException { try { channel.exit(); } finally { channelCount--; } } @Override public InputStream open(File file) throws IOException { try { return channel.get(file.getPath()); } catch (SftpException e) { if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) { throw new FileNotFoundException(); } log.info("SFTP Error reading file: {} {}", file, e.id); throw new IOException("Error reading file: " + file, e); } } @Override public OutputStream writeFile(File file, WriteMode mode) throws IOException { try { int jschMode; switch (mode) { case Append: jschMode = ChannelSftp.APPEND; break; case Overwrite: jschMode = ChannelSftp.OVERWRITE; break; default: throw new IllegalArgumentException(); } return channel.put(file.getPath(), jschMode); } catch (SftpException e) { throw new IOException("Error writing file: " + file, e); } } @Override public void delete(File file) throws IOException { try { channel.rm(file.getPath()); } catch (SftpException e) { throw new IOException("Error deleting file: " + file, e); } } @Override public SftpStat stat(File file) throws IOException { try { SftpATTRS attrs = channel.stat(file.getPath()); return new JschSftpStat(attrs); } catch (SftpException e) { if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) { return null; } log.info("SFTP Error doing stat on file: {} {}", file, e.id); throw new IOException("Error getting file stat: " + file, e); } } public SftpATTRS lstat(File file) throws IOException { SftpATTRS lstat = null; try { lstat = channel.lstat(file.getAbsolutePath()); } catch (SftpException e) { if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) { throw new FileNotFoundException(); } throw new IOException("Error during sftp lstat: " + file, e); } return lstat; } @Override public boolean exists(File file) throws IOException { try { channel.lstat(file.getAbsolutePath()); return true; } catch (SftpException e) { if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) { return false; } throw new IOException("Error during sftp lstat: " + file, e); } } @Override public boolean mkdir(File file) throws IOException { try { if (exists(file)) { return false; } channel.mkdir(file.getAbsolutePath()); return true; } catch (SftpException e) { if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) { throw new FileNotFoundException(); } throw new IOException("Error during sftp mkdir: " + file, e); } } @Override public void mv(File from, File to) throws IOException { try { channel.rename(from.getPath(), to.getPath()); } catch (SftpException e) { throw new IOException("Error during sftp rename: " + from + " to " + to, e); } } @Override public List<String> ls(File file) throws IOException { try { final List<String> names = Lists.newArrayList(); LsEntrySelector selector = new LsEntrySelector() { @Override public int select(LsEntry entry) { // TODO: Filter out directory etc through // attributes... names.add(entry.getFilename()); return CONTINUE; } }; // TODO: The source code for the ls function is not // confidence-inspiring // TODO: Don't get attributes?? channel.ls(file.getPath(), selector); return names; } catch (SftpException e) { throw new IOException("Error during sftp ls: " + file, e); } } @Override public void chmod(File file, int mode) throws IOException { try { channel.chmod(mode, file.getPath()); } catch (SftpException e) { throw new IOException("Error during sftp chmod: " + file, e); } } @Override public void chown(File file, int uid) throws IOException { try { channel.chown(uid, file.getPath()); } catch (SftpException e) { throw new IOException("Error during sftp chown: " + file, e); } } @Override public void chgrp(File file, int gid) throws IOException { try { channel.chgrp(gid, file.getPath()); } catch (SftpException e) { throw new IOException("Error during sftp chgrp: " + file, e); } } } @Override public SshDirectTcpipChannel getDirectTcpipConnection(InetSocketAddress local, InetSocketAddress remote) throws IOException { Session session = getSession(); ChannelDirectTCPIP directChannel; try { directChannel = (ChannelDirectTCPIP) session.openChannel("direct-tcpip"); directChannel.setHost(InetAddresses.toAddrString(remote.getAddress())); directChannel.setPort(remote.getPort()); directChannel.setOrgIPAddress(InetAddresses.toAddrString(local.getAddress())); directChannel.setOrgPort(local.getPort()); channelCount++; directChannel.connect(); } catch (JSchException e) { // TODO: Close session if it's failed?? channelCount--; throw new IOException("Error opening direct-tcpip channel", e); } return new JschDirectTcpipChannel(directChannel); } @Override public SshForwardChannel forwardLocalPort(InetAddress localAddress, InetSocketAddress remoteSocketAddress) throws IOException { Session session = getSession(); String bindAddress = InetAddresses.toAddrString(localAddress); int assignedPort; ChannelDirectTCPIP directChannel; try { int port = 0; // Auto asssign assignedPort = session.setPortForwardingL(bindAddress, port, remoteSocketAddress.getHostString(), remoteSocketAddress.getPort()); channelCount++; } catch (JSchException e) { // TODO: Close session if it's failed?? channelCount--; throw new IOException("Error opening direct-tcpip channel", e); } return new JschSshForwardChannel(session, bindAddress, assignedPort); } public class JschSshForwardChannel implements SshForwardChannel { private final Session session; private final int port; private final String bindAddress; public JschSshForwardChannel(Session session, String bindAddress, int port) { this.session = session; this.bindAddress = bindAddress; this.port = port; } @Override public synchronized void close() throws IOException { try { session.delPortForwardingL(bindAddress, port); } catch (JSchException e) { throw new IOException("Error deleting port binding", e); } finally { channelCount--; } } @Override public InetSocketAddress getLocalSocketAddress() { return new InetSocketAddress(bindAddress, port); } } public class JschDirectTcpipChannel implements SshDirectTcpipChannel { private final ChannelDirectTCPIP channel; public JschDirectTcpipChannel(ChannelDirectTCPIP channel) { this.channel = channel; } @Override public synchronized void close() throws IOException { try { channel.disconnect(); } finally { channelCount--; } } @Override public InputStream getInputStream() throws IOException { return channel.getInputStream(); } @Override public OutputStream getOutputStream() throws IOException { return channel.getOutputStream(); } } @Override public int execute(String command, OutputStream stdout, OutputStream stderr) throws IOException { Session session = getSession(); ChannelExec channel; try { channel = (ChannelExec) session.openChannel("exec"); channel.setInputStream(null); channel.setOutputStream(stdout); channel.setErrStream(stderr); channel.setCommand(command); channelCount++; channel.connect(); } catch (JSchException e) { // TODO: Close session if it's failed?? channelCount--; throw new IOException("Error opening direct-tcpip channel", e); } try { while (true) { int exitStatus = channel.getExitStatus(); if (exitStatus != -1) { return exitStatus; } try { Thread.sleep(50); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted while waiting for SSH command execution", e); } } } finally { channel.disconnect(); channelCount--; } } @Override public String getUser() { return sshUsername; } } final Map<String, ConnectionState> connections = Maps.newHashMap(); ConnectionState getConnectionState(InetSocketAddress server) { String key = InetAddresses.toAddrString(server.getAddress()) + ":" + server.getPort(); ConnectionState connectionState = connections.get(key); if (connectionState == null) { connectionState = new ConnectionState(server); connections.put(key, connectionState); } return connectionState; } @Override public InetSocketAddress getRemoteSshAddress(InetSocketAddress address) { int sshPort = 22; return new InetSocketAddress(address.getAddress(), sshPort); } @Override public SshConfig buildConfig(InetSocketAddress server) { ConnectionState connectionState = getConnectionState(server); return connectionState; } @Override public PublicKey getPublicKey() { return sshKeyPair.getPublic(); // Vector identities = jsch.getIdentityRepository().getIdentities(); // Identity identity = (Identity) identities.get(0); // byte[] publicKeyBytes = identity.getPublicKeyBlob(); // String s = new String(publicKeyBytes); // // return PublicKeys.fromPem(s); } public KeyPair getKeypair() { return sshKeyPair; } }