package org.ovirt.engine.core.uutils.ssh;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.SequenceInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * SSH dialog to be used with SSHClient class.
 *
 * Eases processing of the stdin/stdout of an SSHClient session. Given the limitations of the SSH implementation,
 * this class wraps the session so that a {@link Sink} can drive an interactive exchange with the remote command.
 *
 * The implementation is a wrapper around SSHClient's executeCommand().
 */
public class SSHDialog implements Closeable {

    /** Capacity, in bytes, of each in-memory pipe between the sink and the remote command. */
    private static final int BUFFER_SIZE = 10 * 1024;

    /** Initial size passed to the stderr collector (ConstraintByteArrayOutputStream; presumably a cap -- verify). */
    private static final int STDERR_BUFFER_SIZE = 1024;

    private static final int DEFAULT_SSH_PORT = 22;

    /**
     * Control interface. Callback for the sink.
     */
    public interface Control {
        /**
         * Disconnect session.
         *
         * @throws IOException
         *             if closing the underlying client fails.
         */
        void close() throws IOException;
    }

    /**
     * Dialog sink.
     */
    public interface Sink {

        /**
         * Set control interface.
         *
         * @param control
         *            control.
         */
        void setControl(SSHDialog.Control control);

        /**
         * Set streams to process.
         *
         * @param incoming
         *            incoming stream.
         * @param outgoing
         *            outgoing stream.
         *
         *            Streams are null when sink is removed from session.
         */
        void setStreams(InputStream incoming, OutputStream outgoing);

        /**
         * Start processing. Usually a thread will be created to process streams. This is guaranteed to be called
         * after setStreams().
         */
        void start();

        /**
         * Stop processing. Called before streams are set to null.
         */
        void stop();
    }

    private static final Logger log = LoggerFactory.getLogger(SSHDialog.class);

    private String host;
    private int port;
    private String user = "root";
    private KeyPair keyPair;
    private String password;
    private long softTimeout = 0;
    private long hardTimeout = 0;

    protected SSHClient client;

    /**
     * Get SSH Client. Used for mocking.
     */
    protected SSHClient getSSHClient() {
        return new SSHClient();
    }

    /**
     * Destructor. Last-resort attempt to release the session if the owner forgot to call close().
     */
    @Override
    protected void finalize() {
        try {
            close();
        } catch (IOException e) {
            log.error("Finalize exception", e);
        }
    }

    /**
     * Get session public key.
     *
     * @return public key or null if no key pair was set.
     */
    public PublicKey getPublicKey() {
        return keyPair == null ? null : keyPair.getPublic();
    }

    /**
     * Get host public key.
     *
     * @return the host key reported by the connected client.
     * @throws IOException
     *             if the session is disconnected or the host key is unavailable.
     */
    public PublicKey getHostKey() throws IOException {
        SSHClient current = client;
        if (current == null) {
            throw new IOException("Cannot acquire host key, session is disconnected");
        }

        PublicKey hostKey = current.getHostKey();
        if (hostKey == null) {
            throw new IOException("Unable to retrieve host key");
        }

        return hostKey;
    }

    /**
     * Set host to connect to.
     *
     * @param host
     *            host.
     * @param port
     *            port.
     */
    public void setHost(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Set host to connect to, using the default SSH port (22).
     *
     * @param host
     *            host.
     */
    public void setHost(String host) {
        setHost(host, DEFAULT_SSH_PORT);
    }

    /**
     * Set user to use. Defaults to "root".
     *
     * @param user
     *            user.
     */
    public void setUser(String user) {
        this.user = user;
    }

    /**
     * Set password to use. If both password and key pair are set, the key pair is used.
     *
     * @param password
     *            password.
     */
    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Set key pair. If both password and key pair are set, the key pair is used.
     *
     * @param keyPair
     *            key pair.
     */
    public void setKeyPair(KeyPair keyPair) {
        this.keyPair = keyPair;
    }

    /**
     * Set soft timeout. Soft timeout is reset when there is session activity. Zero means "not set".
     *
     * @param timeout
     *            timeout in milliseconds.
     */
    public void setSoftTimeout(long timeout) {
        softTimeout = timeout;
    }

    /**
     * Set hard timeout. Hard timeout is maximum duration of session. Zero means "not set".
     *
     * @param timeout
     *            timeout in milliseconds.
     */
    public void setHardTimeout(long timeout) {
        hardTimeout = timeout;
    }

    /**
     * Disconnect session. Safe to call more than once.
     *
     * @throws IOException
     *             if closing the underlying client fails.
     */
    @Override
    public void close() throws IOException {
        SSHClient current = client;
        // Clear the reference first so a failing close() is not retried by later calls or by finalize().
        client = null;
        if (current != null) {
            current.close();
        }
    }

    /**
     * Connect to host. After connection host fingerprint can be acquired.
     *
     * On failure the client created here is kept in {@link #client}; call close() to release it.
     *
     * @throws Exception
     *             if already connected or the connection fails.
     */
    public void connect() throws Exception {
        log.debug(
                "connect enter ({}:{}, {}, {})",
                host,
                port,
                hardTimeout,
                softTimeout);

        try {
            if (client != null) {
                throw new IOException("Already connected");
            }

            client = getSSHClient();
            // Zero means "not configured": keep the client's own defaults.
            if (hardTimeout != 0) {
                client.setHardTimeout(hardTimeout);
            }
            if (softTimeout != 0) {
                client.setSoftTimeout(softTimeout);
            }
            client.setHost(host, port);

            log.debug("connecting");
            client.setUser(user);
            client.connect();
        } catch (Exception e) {
            if (client != null) {
                log.debug(
                        "Could not connect to host '{}'",
                        client.getDisplayHost());
            } else {
                log.debug(
                        "Could not connect to host");
            }
            log.debug("Exception", e);
            throw e;
        }
    }

    /**
     * Authenticate using the configured password and/or key pair.
     *
     * @throws IOException
     *             if the session is not connected.
     * @throws Exception
     *             if authentication fails.
     */
    public void authenticate() throws Exception {
        SSHClient current = requireConnected();
        current.setPassword(password);
        current.setKeyPair(keyPair);
        current.authenticate();
    }

    /**
     * Execute command.
     *
     * Any output the remote command writes to stderr is treated as a failure and reported as a RuntimeException
     * whose message contains that output; an exception raised by the client in that case is attached as the cause.
     *
     * @param sink
     *            sink to use.
     * @param command
     *            command to execute.
     * @param initial
     *            initial input streams to send to host before dialog begins; may be null.
     * @throws IOException
     *             if the session is not connected.
     * @throws Exception
     *             if execution fails.
     */
    public void executeCommand(
            Sink sink,
            String command,
            InputStream[] initial) throws Exception {

        // Capture once: a concurrent close() must not turn the error path into a NullPointerException.
        SSHClient activeClient = requireConnected();

        log.info("SSH execute '{}' '{}'", activeClient.getDisplayHost(), command);

        try (
                final PipedInputStream pinStdin = new PipedInputStream(BUFFER_SIZE);
                final OutputStream poutStdin = new PipedOutputStream(pinStdin);
                final PipedInputStream pinStdout = new PipedInputStream(BUFFER_SIZE);
                final OutputStream poutStdout = new PipedOutputStream(pinStdout);
                final ByteArrayOutputStream stderr = new ConstraintByteArrayOutputStream(STDERR_BUFFER_SIZE)) {
            try {
                List<InputStream> stdinList = buildStdin(initial, pinStdin);

                sink.setControl(
                        () -> {
                            SSHClient current = client;
                            if (current != null) {
                                current.close();
                            }
                        });
                sink.setStreams(pinStdout, poutStdin);
                sink.start();

                runPreferringStderr(activeClient, command, stdinList, poutStdout, stderr);
            } catch (Exception e) {
                log.error(
                        "SSH error running command {}:'{}': {}",
                        activeClient.getDisplayHost(),
                        command,
                        e.getMessage());
                log.error("Exception", e);
                throw e;
            } finally {
                sink.stop();
                sink.setStreams(null, null);
            }
        }

        log.debug("execute leave");
    }

    /**
     * Send file. Send file using the embedded SSHClient.
     *
     * @throws IOException
     *             if the session is not connected.
     * @throws Exception
     *             if the transfer fails.
     */
    public void sendFile(
            String file1,
            String file2) throws Exception {
        requireConnected().sendFile(file1, file2);
    }

    /**
     * Receive file. Receive file using the embedded SSHClient.
     *
     * @throws IOException
     *             if the session is not connected.
     * @throws Exception
     *             if the transfer fails.
     */
    public void receiveFile(
            String file1,
            String file2) throws Exception {
        requireConnected().receiveFile(file1, file2);
    }

    /**
     * Return the current client, failing with a descriptive error instead of a NullPointerException.
     */
    private SSHClient requireConnected() throws IOException {
        SSHClient current = client;
        if (current == null) {
            throw new IOException("Session is not connected");
        }
        return current;
    }

    /**
     * Build the remote stdin sequence: the optional initial streams followed by the sink-fed pipe.
     */
    private static List<InputStream> buildStdin(InputStream[] initial, InputStream tail) {
        List<InputStream> streams = new LinkedList<>();
        if (initial != null) {
            streams.addAll(Arrays.asList(initial));
        }
        streams.add(tail);
        return streams;
    }

    /**
     * Run the command; any stderr output wins over a client exception and is raised as a RuntimeException.
     */
    private static void runPreferringStderr(
            SSHClient activeClient,
            String command,
            List<InputStream> stdinList,
            OutputStream stdout,
            ByteArrayOutputStream stderr) throws Exception {
        Exception swallowed = null;
        try {
            activeClient.executeCommand(
                    command,
                    new SequenceInputStream(Collections.enumeration(stdinList)),
                    stdout,
                    stderr);
        } catch (Exception e) {
            if (stderr.size() == 0) {
                throw e;
            }
            log.error(
                    "Swallowing exception as preferring stderr",
                    e);
            // Keep it so it can be chained as the cause instead of being lost.
            swallowed = e;
        }

        // Checked after (not in a finally) so Errors propagate untouched and the cause is preserved.
        if (stderr.size() > 0) {
            throw new RuntimeException(
                    String.format(
                            "Unexpected error during execution: %1$s",
                            new String(stderr.toByteArray(), StandardCharsets.UTF_8)),
                    swallowed);
        }
    }
}