package org.ovirt.engine.core.uutils.ssh;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.DigestInputStream;
import java.security.DigestOutputStream;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import javax.naming.AuthenticationException;
import javax.naming.TimeLimitExceededException;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.apache.sshd.ClientChannel;
import org.apache.sshd.ClientSession;
import org.apache.sshd.SshClient;
import org.apache.sshd.client.future.AuthFuture;
import org.apache.sshd.client.future.ConnectFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * SSH client built on Apache MINA SSHD.
 *
 * Usage: configure host, user and credentials through the setters, call {@link #connect()},
 * then {@link #authenticate()}, then any of {@link #executeCommand}, {@link #sendFile} or
 * {@link #receiveFile}, and finally {@link #close()}.
 *
 * Network activity is bounded by a soft timeout (maximum inactivity between checks) and an
 * optional hard timeout (maximum duration of a single command).
 *
 * During {@link #connect()} every server key is accepted; the key is recorded and exposed via
 * {@link #getHostKey()} so the caller can verify it.
 */
public class SSHClient implements Closeable {
    /** Remote shell template for download: %1$s is the compressor, %2$s the remote file; digest goes to stderr. */
    private static final String COMMAND_FILE_RECEIVE =
            "test -r '%2$s' && md5sum -b '%2$s' | cut -d ' ' -f 1 >&2 && %1$s < '%2$s'";
    /** Remote shell template for upload: %1$s is the decompressor, %2$s the remote file; digest goes to stderr. */
    private static final String COMMAND_FILE_SEND = "%1$s > '%2$s' && md5sum -b '%2$s' | cut -d ' ' -f 1 >&2";
    private static final int STREAM_BUFFER_SIZE = 8192;
    private static final int CONSTRAINT_BUFFER_SIZE = 1024;
    private static final int THREAD_JOIN_WAIT_TIME = 2000;
    private static final int DEFAULT_SSH_PORT = 22;

    private static final Logger log = LoggerFactory.getLogger(SSHClient.class);

    private SshClient client;
    private ClientSession session;
    private long softTimeout = 10000;
    private long hardTimeout = 0;
    private String user;
    private String password;
    private KeyPair keyPair;
    private String host;
    private int port = DEFAULT_SSH_PORT;
    private PublicKey hostKey;

    /**
     * Create the underlying SSHD client; package-private so tests can substitute a mock
     * (e.g. with org.mockito.Mockito).
     *
     * @return client.
     */
    SshClient createSshClient() {
        return SshClient.setUpDefaultClient();
    }

    /**
     * Check that a remote file name is safe to embed in the shell command templates.
     *
     * The name is placed inside single quotes in a shell command, so a single quote or a line
     * break would escape the quoting; such names are rejected.
     *
     * @param file
     *            remote file name.
     * @throws IllegalArgumentException
     *             if the name contains a single quote, '\n' or '\r'.
     */
    private void remoteFileName(String file) {
        if (file.indexOf('\'') != -1 ||
                file.indexOf('\n') != -1 ||
                file.indexOf('\r') != -1) {
            throw new IllegalArgumentException("File name should not contain \"'\"");
        }
    }

    /**
     * Compare a locally computed digest with a hex-encoded digest reported by the remote side.
     *
     * @param digest
     *            MessageDigest holding the local data; it is finalized by this call.
     * @param actual
     *            hex string digest.
     * @throws IOException
     *             if the digests differ or the remote string is not valid hex.
     */
    private void validateDigest(MessageDigest digest, String actual) throws IOException {
        try {
            if (!Arrays.equals(
                    digest.digest(),
                    Hex.decodeHex(actual.toCharArray()))) {
                throw new IOException("SSH copy failed, invalid localDigest");
            }
        } catch (DecoderException e) {
            // Keep the decoding failure as the cause so the malformed remote output is diagnosable.
            throw new IOException("SSH copy failed, invalid localDigest", e);
        }
    }

    /**
     * Safety net that releases the connection if the caller forgot to call {@link #close()}.
     *
     * Finalization timing is not guaranteed; callers must still call {@link #close()}.
     */
    @Override
    protected void finalize() {
        try {
            close();
        } catch (IOException e) {
            log.error("Finalize exception", e);
        }
    }

    /**
     * Set soft timeout.
     *
     * @param softTimeout
     *            timeout in milliseconds for network activity; default is 10 seconds.
     */
    public void setSoftTimeout(long softTimeout) {
        this.softTimeout = softTimeout;
    }

    /**
     * Set hard timeout.
     *
     * @param hardTimeout
     *            timeout in milliseconds for an entire command; 0 means infinite. The timeout is
     *            evaluated at softTimeout intervals.
     */
    public void setHardTimeout(long hardTimeout) {
        this.hardTimeout = hardTimeout;
    }

    /**
     * Set user.
     *
     * @param user
     *            user.
     */
    public void setUser(String user) {
        this.user = user;
    }

    /**
     * Set password, used when no key pair is set.
     *
     * @param password
     *            password.
     */
    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Set key pair; takes precedence over the password during authentication.
     *
     * @param keyPair
     *            key pair.
     */
    public void setKeyPair(KeyPair keyPair) {
        this.keyPair = keyPair;
    }

    /**
     * Set host and port; clears any previously recorded host key.
     *
     * @param host
     *            host.
     * @param port
     *            port.
     */
    public void setHost(String host, int port) {
        this.host = host;
        this.port = port;
        hostKey = null;
    }

    /**
     * Set host and optional port.
     *
     * @param host
     *            host.
     * @param port
     *            port, or null for the default SSH port.
     */
    public void setHost(String host, Integer port) {
        setHost(host, port == null ? DEFAULT_SSH_PORT : port);
    }

    /**
     * Set host using the default SSH port.
     *
     * @param host
     *            host.
     */
    public void setHost(String host) {
        setHost(host, DEFAULT_SSH_PORT);
    }

    /**
     * Get host.
     *
     * @return host as set by setHost()
     */
    public String getHost() {
        return host;
    }

    /**
     * Get port.
     *
     * @return port.
     */
    public int getPort() {
        return port;
    }

    /**
     * Get hard timeout.
     *
     * @return timeout.
     */
    public long getHardTimeout() {
        return hardTimeout;
    }

    /**
     * Get soft timeout.
     *
     * @return timeout.
     */
    public long getSoftTimeout() {
        return softTimeout;
    }

    /**
     * Get user.
     *
     * @return user.
     */
    public String getUser() {
        return user;
    }

    /**
     * Build a human-readable target description for logs and error messages.
     *
     * @return "N/A" when no host is set, otherwise "[user@]host[:port]", where the port is
     *         shown only when it differs from the default SSH port.
     */
    public String getDisplayHost() {
        StringBuilder ret = new StringBuilder(100);
        if (host == null) {
            ret.append("N/A");
        } else {
            if (user != null) {
                ret.append(user);
                ret.append("@");
            }
            ret.append(host);
            if (port != DEFAULT_SSH_PORT) {
                ret.append(":");
                ret.append(port);
            }
        }
        return ret.toString();
    }

    /**
     * Get the host key presented by the server during {@link #connect()}.
     *
     * @return host key, or null if not connected since the last setHost().
     */
    public PublicKey getHostKey() {
        return hostKey;
    }

    /**
     * Ensure {@link #connect()} established a session before it is used.
     *
     * @return the current session.
     * @throws IllegalStateException
     *             if there is no session.
     */
    private ClientSession requireSession() {
        if (session == null) {
            throw new IllegalStateException(
                    String.format(
                            "SSH session to '%1$s' is not connected",
                            this.getDisplayHost()));
        }
        return session;
    }

    /**
     * Connect to host and wait until the server is ready for authentication.
     *
     * The server key is accepted and recorded; see {@link #getHostKey()}.
     *
     * @throws TimeLimitExceededException
     *             if connecting or reaching the authentication phase exceeds the soft timeout.
     * @throws IOException
     *             if the session closes during the connection.
     */
    public void connect() throws Exception {
        log.debug("Connecting '{}'", this.getDisplayHost());
        try {
            client = createSshClient();
            client.setServerKeyVerifier(
                    (sshClientSession, remoteAddress, serverKey) -> {
                        hostKey = serverKey;
                        return true;
                    });
            client.start();

            ConnectFuture cfuture = client.connect(host, port);
            if (!cfuture.await(softTimeout)) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH connection timed out connecting to '%1$s'",
                                this.getDisplayHost()));
            }

            session = cfuture.getSession();

            /*
             * Wait for authentication phase so we have host key.
             */
            int stat = session.waitFor(
                    ClientSession.CLOSED |
                            ClientSession.WAIT_AUTH |
                            ClientSession.TIMEOUT,
                    softTimeout);
            if ((stat & ClientSession.CLOSED) != 0) {
                throw new IOException(
                        String.format(
                                "SSH session closed during connection '%1$s'",
                                this.getDisplayHost()));
            }
            if ((stat & ClientSession.TIMEOUT) != 0) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH timed out waiting for authentication request '%1$s'",
                                this.getDisplayHost()));
            }
        } catch (Exception e) {
            log.debug("Connect error", e);
            throw e;
        }
        log.debug("Connected: '{}'", this.getDisplayHost());
    }

    /**
     * Authenticate to host using the key pair if set, otherwise the password.
     *
     * @throws IllegalStateException
     *             if {@link #connect()} has not established a session.
     * @throws AuthenticationException
     *             if no credentials are set or the server rejects them.
     * @throws TimeLimitExceededException
     *             if authentication exceeds the soft timeout.
     */
    public void authenticate() throws Exception {
        log.debug("Authenticating: '{}'", this.getDisplayHost());
        try {
            ClientSession activeSession = requireSession();
            AuthFuture afuture;
            if (keyPair != null) {
                afuture = activeSession.authPublicKey(user, keyPair);
            } else if (password != null) {
                afuture = activeSession.authPassword(user, password);
            } else {
                throw new AuthenticationException(
                        String.format(
                                "SSH authentication failure '%1$s', no password or key",
                                this.getDisplayHost()));
            }
            if (!afuture.await(softTimeout)) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH authentication timed out connecting to '%1$s'",
                                this.getDisplayHost()));
            }
            if (!afuture.isSuccess()) {
                throw new AuthenticationException(
                        String.format(
                                "SSH authentication to '%1$s' failed. Please verify provided credentials. %2$s",
                                this.getDisplayHost(),
                                keyPair == null ? "Make sure host is configured for password authentication"
                                        : "Make sure key is authorized at host"));
            }
        } catch (Exception e) {
            log.debug("Authentication error", e);
            throw e;
        }
        log.debug("Authenticated: '{}'", this.getDisplayHost());
    }

    /**
     * Disconnect and cleanup.
     *
     * Must be called when done with client. The session and the client are released
     * independently, so a failure closing the session does not prevent the client from being
     * stopped.
     *
     * @throws IOException
     *             wrapping the first failure; a second failure is attached as suppressed.
     */
    @Override
    public void close() throws IOException {
        Exception failure = null;

        if (session != null) {
            try {
                session.close(true);
            } catch (Exception e) {
                failure = e;
            } finally {
                session = null;
            }
        }

        if (client != null) {
            try {
                client.stop();
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            } finally {
                client = null;
            }
        }

        if (failure != null) {
            log.error("Failed to close session", failure);
            throw new IOException(failure);
        }
    }

    /**
     * Execute generic command.
     *
     * The command runs until the channel reaches EOF or closes. It is aborted when a soft-timeout
     * interval passes with no stream progress, or when the hard timeout (if set) elapses.
     *
     * @param command
     *            command to execute.
     * @param in
     *            stdin, or null for empty input.
     * @param out
     *            stdout, or null to collect into a bounded internal buffer.
     * @param err
     *            stderr, or null to collect into a bounded internal buffer.
     * @throws TimeLimitExceededException
     *             on soft or hard timeout.
     * @throws IOException
     *             if the command is killed by a signal or exits with a non-zero status.
     */
    public void executeCommand(
            String command,
            InputStream in,
            OutputStream out,
            OutputStream err) throws Exception {

        log.debug("Executing: '{}'", command);

        if (in == null) {
            in = new ByteArrayInputStream(new byte[0]);
        }
        if (out == null) {
            out = new ConstraintByteArrayOutputStream(CONSTRAINT_BUFFER_SIZE);
        }
        if (err == null) {
            err = new ConstraintByteArrayOutputStream(CONSTRAINT_BUFFER_SIZE);
        }

        /*
         * Redirect streams into indexed streams so activity can be detected between waits.
         */
        ClientChannel channel = null;
        try (
                final ProgressInputStream iin = new ProgressInputStream(in);
                final ProgressOutputStream iout = new ProgressOutputStream(out);
                final ProgressOutputStream ierr = new ProgressOutputStream(err)) {

            channel = requireSession().createExecChannel(command);
            channel.setIn(iin);
            channel.setOut(iout);
            channel.setErr(ierr);
            channel.open();

            long hardEnd = 0;
            if (hardTimeout != 0) {
                hardEnd = System.currentTimeMillis() + hardTimeout;
            }

            // Named distinctly from the hardTimeout field to avoid shadowing it.
            boolean hardTimeoutExpired = false;
            int stat;
            boolean activity;
            do {
                stat = channel.waitFor(
                        ClientChannel.CLOSED |
                                ClientChannel.EOF |
                                ClientChannel.TIMEOUT,
                        softTimeout);

                hardTimeoutExpired = hardEnd != 0 && System.currentTimeMillis() >= hardEnd;

                /*
                 * Notice that we should visit all so do not cascade statement:
                 * each wasProgress() call is expected to be evaluated on every pass.
                 */
                activity = iin.wasProgress();
                activity = iout.wasProgress() || activity;
                activity = ierr.wasProgress() || activity;
            } while (!hardTimeoutExpired &&
                    (stat & ClientChannel.TIMEOUT) != 0 &&
                    activity);

            if (hardTimeoutExpired) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH session hard timeout host '%1$s'",
                                this.getDisplayHost()));
            }

            if ((stat & ClientChannel.TIMEOUT) != 0) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH session timeout host '%1$s'",
                                this.getDisplayHost()));
            }

            stat = channel.waitFor(
                    ClientChannel.CLOSED |
                            ClientChannel.EXIT_STATUS |
                            ClientChannel.EXIT_SIGNAL |
                            ClientChannel.TIMEOUT,
                    softTimeout);

            if ((stat & ClientChannel.EXIT_SIGNAL) != 0) {
                throw new IOException(
                        String.format(
                                "Signal received during SSH session host '%1$s'",
                                this.getDisplayHost()));
            }

            if ((stat & ClientChannel.EXIT_STATUS) != 0 && channel.getExitStatus() != 0) {
                throw new IOException(
                        String.format(
                                "Command returned failure code %2$d during SSH session '%1$s'",
                                this.getDisplayHost(),
                                channel.getExitStatus()));
            }

            if ((stat & ClientChannel.TIMEOUT) != 0) {
                throw new TimeLimitExceededException(
                        String.format(
                                "SSH session timeout waiting for status host '%1$s'",
                                this.getDisplayHost()));
            }

            // The PipedOutputStream does not flush streams at close, which makes the other side
            // of the pipe miss the last bytes; flush explicitly (FilterOutputStream would flush
            // on close, so the exact cause is unclear).
            out.flush();
            err.flush();
        } catch (Exception e) {
            // Log checked failures (timeouts, I/O) as well as unchecked ones, as the other
            // public operations of this class do; precise rethrow keeps the original type.
            log.debug("Execute failed", e);
            throw e;
        } finally {
            if (channel != null) {
                // On the timeout and error paths the channel may still be open; always close it
                // so it is not held until the session ends.
                channel.close(true);
            }
        }

        log.debug("Executed: '{}'", command);
    }

    /**
     * Send file using compression and digest check.
     *
     * A helper thread reads the local file through a digest, gzips it, and pipes it into the
     * remote command's stdin. The remote side decompresses into the destination and prints the
     * MD5 of the written file on stderr, which is compared with the local digest.
     *
     * @param file1
     *            local source.
     * @param file2
     *            remote destination.
     * @throws IOException
     *             if the transfer fails or digests differ.
     * @throws IllegalStateException
     *             if the compression thread does not finish in time.
     */
    public void sendFile(String file1, String file2) throws Exception {
        log.debug("Sending: '{}' '{}'", file1, file2);

        remoteFileName(file2);

        MessageDigest localDigest = MessageDigest.getInstance("MD5");

        // file1->{}->digest->in->out->pout->pin->stdin
        Thread t = null;
        try (
                final InputStream in = new DigestInputStream(
                        new FileInputStream(file1),
                        localDigest);
                final PipedInputStream pin = new PipedInputStream(STREAM_BUFFER_SIZE);
                final OutputStream pout = new PipedOutputStream(pin);
                final OutputStream dummy = new ConstraintByteArrayOutputStream(CONSTRAINT_BUFFER_SIZE);
                final ByteArrayOutputStream remoteDigest =
                        new ConstraintByteArrayOutputStream(CONSTRAINT_BUFFER_SIZE)) {

            t = new Thread(
                    () -> {
                        // Closing the gzip stream closes pout, signalling EOF to the remote stdin.
                        try (OutputStream out = new GZIPOutputStream(pout)) {
                            byte[] b = new byte[STREAM_BUFFER_SIZE];
                            int n;
                            while ((n = in.read(b)) != -1) {
                                out.write(b, 0, n);
                            }
                        } catch (IOException e) {
                            // Any data loss here surfaces as a digest mismatch below.
                            log.debug("Exceution during stream processing", e);
                        }
                    },
                    "SSHClient.compress " + file1);
            t.start();

            executeCommand(
                    String.format(COMMAND_FILE_SEND, "gunzip -q", file2),
                    pin,
                    dummy,
                    remoteDigest);

            t.join(THREAD_JOIN_WAIT_TIME);
            if (t.getState() != Thread.State.TERMINATED) {
                throw new IllegalStateException("Cannot stop SSH stream thread");
            }

            validateDigest(localDigest, new String(remoteDigest.toByteArray(), StandardCharsets.UTF_8).trim());
        } catch (Exception e) {
            log.debug("Send failed", e);
            throw e;
        } finally {
            if (t != null) {
                t.interrupt();
            }
        }

        log.debug("Sent: '{}' '{}'", file1, file2);
    }

    /**
     * Receive file using compression and digest check.
     *
     * The remote side prints the MD5 of the source on stderr and streams the gzipped content on
     * stdout. A helper thread gunzips that stream into the local destination through a digest,
     * which is then compared with the remote digest.
     *
     * @param file1
     *            remote source.
     * @param file2
     *            local destination.
     * @throws IOException
     *             if the transfer fails or digests differ.
     * @throws IllegalStateException
     *             if the decompression thread does not finish in time.
     */
    public void receiveFile(String file1, String file2) throws Exception {
        log.debug("Receiving: '{}' '{}'", file1, file2);

        remoteFileName(file1);

        MessageDigest localDigest = MessageDigest.getInstance("MD5");

        // stdout->pout->pin->in->out->digest->{}->file2
        Thread t = null;
        try (
                final PipedOutputStream pout = new PipedOutputStream();
                final InputStream pin = new PipedInputStream(pout, STREAM_BUFFER_SIZE);
                final OutputStream out = new DigestOutputStream(
                        new FileOutputStream(file2),
                        localDigest);
                final InputStream empty = new ByteArrayInputStream(new byte[0]);
                final ByteArrayOutputStream remoteDigest =
                        new ConstraintByteArrayOutputStream(CONSTRAINT_BUFFER_SIZE)) {

            t = new Thread(
                    () -> {
                        try (final InputStream in = new GZIPInputStream(pin)) {
                            byte[] b = new byte[STREAM_BUFFER_SIZE];
                            int n;
                            while ((n = in.read(b)) != -1) {
                                out.write(b, 0, n);
                            }
                        } catch (IOException e) {
                            // Any data loss here surfaces as a digest mismatch below.
                            log.debug("Exceution during stream processing", e);
                        }
                    },
                    "SSHClient.decompress " + file2);
            t.start();

            executeCommand(
                    String.format(COMMAND_FILE_RECEIVE, "gzip -q", file1),
                    empty,
                    pout,
                    remoteDigest);

            t.join(THREAD_JOIN_WAIT_TIME);
            if (t.getState() != Thread.State.TERMINATED) {
                throw new IllegalStateException("Cannot stop SSH stream thread");
            }

            validateDigest(localDigest, new String(remoteDigest.toByteArray(), StandardCharsets.UTF_8).trim());
        } catch (Exception e) {
            log.debug("Receive failed", e);
            throw e;
        } finally {
            if (t != null) {
                t.interrupt();
            }
        }

        log.debug("Received: '{}' '{}'", file1, file2);
    }
}