/***************************************************************************
* Copyright (c) 2013-2014 VMware, Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
***************************************************************************/
package com.vmware.vhadoop.vhm.hadoop;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.jcraft.jsch.Channel;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import com.vmware.vhadoop.util.ExternalizedParameters;
public class SshConnectionCache implements SshUtilities
{
/** Logger used by this class and its nested classes. */
private static final Logger _log = Logger.getLogger(SshConnectionCache.class.getName());
/* The tunables below are instance fields read from ExternalizedParameters when the cache is constructed. */
/** Upper bound on the number of connection attempts made by connectSession. */
private final int NUM_SSH_RETRIES = ExternalizedParameters.get().getInt("SSH_INITIAL_CONNECTION_NUMBER_OF_RETRIES");
/** Milliseconds slept between failed connection attempts in connectSession. */
private final int RETRY_DELAY_MILLIS = ExternalizedParameters.get().getInt("SSH_INITIAL_CONNECTION_RETRY_DELAY_MILLIS");
/** Milliseconds that execute waits for a remote command to report an exit status. */
private final int INPUTSTREAM_TIMEOUT_MILLIS = ExternalizedParameters.get().getInt("SSH_REMOTE_EXECUTION_TIMEOUT_MILLIS");
/** Value passed to Session.setTimeout for each session built by createSession. */
private final int SESSION_READ_TIMEOUT = ExternalizedParameters.get().getInt("SSH_SESSION_READ_TIMEOUT");
/** Value passed to Session.setServerAliveCountMax for each session built by createSession. */
private final int NUM_KEEP_ALIVE = ExternalizedParameters.get().getInt("SSH_DROPPED_KEEP_ALIVE_GRACE");
/** Milliseconds RemoteProcess waits between polls of the channel exit status. */
private final int REMOTE_PROC_WAIT_FOR_DELAY = ExternalizedParameters.get().getInt("SSH_REMOTE_PROC_WAIT_FOR_DELAY");
/** Trimmed value used for the JSch "StrictHostKeyChecking" session option. */
private final String STRICT_HOST_KEY_CHECKING = ExternalizedParameters.get().getString("SSH_STRICT_HOST_KEY_CHECKING").trim();
/** Prefix of the remote command issued by copy; the destination path is appended to it. */
private static final String SCP_COMMAND = "scp -t ";
/**
 * Connection to session mapping. Assigned in the constructor to a synchronized, access-ordered
 * LinkedHashMap that evicts its eldest entry once it holds more than {@link #capacity} entries.
 * This object is also the monitor the rest of the class synchronizes on.
 */
private Map<Connection,Session> cache;
/**
 * Channels opened on each session that have not yet been cleaned up. The code in this file
 * touches this map while holding the monitor of {@link #cache}.
 */
private Map<Session,Set<Channel>> channelMap = new HashMap<Session,Set<Channel>>();
/** JSch instance used to create sessions and to register private key identities. */
private final JSch _jsch = new JSch();
/** Number of entries the cache may hold before the eldest one is evicted. */
protected final int capacity;
/** Load factor given to the backing LinkedHashMap and used to derive its initial size. */
private float loadFactor = 0.75f;
/**
 * Cache key describing one SSH endpoint: the host, the port and the credentials used to log in.
 *
 * Two connections are equal only when all three components match. The port takes part in the
 * comparison because sessions are created for a specific port; ignoring it would hand back a
 * session connected to a different port on the same host.
 *
 * NOTE(review): equality of the credentials component relies on Credentials.equals/hashCode,
 * which are declared outside this file - confirm they compare by value.
 */
static class Connection {
    final String hostname;
    final int port;
    final Credentials credentials;

    Connection(String hostname, int port, Credentials credentials) {
        this.hostname = hostname;
        this.port = port;
        this.credentials = credentials;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((credentials == null) ? 0 : credentials.hashCode());
        result = prime * result + ((hostname == null) ? 0 : hostname.hashCode());
        result = prime * result + port;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Connection other = (Connection) obj;
        if (port != other.port) {
            return false;
        }
        if (credentials == null ? other.credentials != null : !credentials.equals(other.credentials)) {
            return false;
        }
        return hostname == null ? other.hostname == null : hostname.equals(other.hostname);
    }
}
/**
 * {@link Process} view of a command running on a JSch exec channel.
 *
 * The wrapper owns the channel and its three streams. Once the remote command has reported an
 * exit status, or the process is destroyed, {@link #cleanup()} disconnects the channel, removes
 * it from the enclosing cache's channel bookkeeping and closes the streams. After cleanup the
 * channel reference is null and the last observed exit status is returned by all queries.
 */
class RemoteProcess extends Process {
    /** Exit status reported while the remote command has not finished. */
    public final static int UNDEFINED_EXIT_STATUS = -1;

    Session session;
    ChannelExec channel;
    InputStream stdout;
    InputStream stderr;
    OutputStream stdin;
    private int exitStatus = UNDEFINED_EXIT_STATUS;

    /**
     * Wraps the given exec channel, capturing its streams and owning session.
     * @param channel the exec channel the command will run on
     * @throws IOException if the channel's streams cannot be obtained
     * @throws JSchException if the channel's session cannot be obtained
     */
    public RemoteProcess(ChannelExec channel) throws IOException, JSchException {
        this.channel = channel;
        this.stdin = channel.getOutputStream();
        this.stderr = channel.getErrStream();
        this.stdout = channel.getInputStream();
        this.session = channel.getSession();
    }

    @Override
    protected void finalize() throws Throwable {
        /* release our resources even if the superclass finalizer fails */
        try {
            cleanup();
        } finally {
            super.finalize();
        }
    }

    @Override
    public OutputStream getOutputStream() {
        return stdin;
    }

    @Override
    public InputStream getInputStream() {
        return stdout;
    }

    @Override
    public InputStream getErrorStream() {
        return stderr;
    }

    /**
     * Polls the channel until the remote command reports an exit status.
     * @return the exit status, or the last known status if the process was cleaned up
     *         (for example by {@link #destroy()}) while this thread was waiting
     * @throws InterruptedException if the waiting thread is interrupted
     */
    @Override
    public synchronized int waitFor() throws InterruptedException {
        /* wait() releases the monitor, so the channel may have been cleaned up by the time we wake */
        while (channel != null) {
            exitStatus = channel.getExitStatus();
            if (exitStatus != UNDEFINED_EXIT_STATUS) {
                cleanup();
                return exitStatus;
            }
            this.wait(REMOTE_PROC_WAIT_FOR_DELAY);
        }
        return exitStatus;
    }

    /**
     * Waits for the remote process to complete
     * @param timeout timeout in milliseconds
     * @return the return code for the process or -1 if the wait times out
     * @throws InterruptedException
     */
    public synchronized int waitFor(int timeout) throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeout;
        /* wait() releases the monitor, so the channel may have been cleaned up by the time we wake */
        while (channel != null) {
            exitStatus = channel.getExitStatus();
            if (exitStatus != UNDEFINED_EXIT_STATUS) {
                cleanup();
                return exitStatus;
            }
            /* the status is always checked once more after the final wait, and the final wait
             * is trimmed so that we do not sleep past the deadline */
            long remaining = deadline - System.currentTimeMillis();
            if (remaining <= 0) {
                break;
            }
            this.wait(Math.min(remaining, REMOTE_PROC_WAIT_FOR_DELAY));
        }
        return exitStatus;
    }

    /**
     * Non-blocking query of the exit status.
     * @return the exit status, or -1 if the remote command has not completed
     */
    @Override
    public synchronized int exitValue() {
        if (channel == null) {
            return exitStatus;
        }
        exitStatus = channel.getExitStatus();
        if (exitStatus != UNDEFINED_EXIT_STATUS) {
            cleanup();
        }
        return exitStatus;
    }

    /**
     * Sends SIGKILL to the remote command on a best-effort basis, records whatever exit status
     * the channel has at that point and releases all resources.
     */
    @Override
    public synchronized void destroy() {
        if (channel == null) {
            return;
        }
        try {
            channel.sendSignal("SIGKILL");
        } catch (Exception e) {
            _log.log(Level.INFO, "VHM: unable to send kill signal to remote process", e);
        }
        exitStatus = channel.getExitStatus();
        cleanup();
    }

    /**
     * Disconnects the channel, drops it from the channel bookkeeping and closes the streams.
     * If this was the last channel of a session that is no longer in the cache, the session is
     * disconnected as well. Calling this again after the first call is a no-op.
     */
    private void cleanup() {
        if (channel == null) {
            return;
        }
        channel.disconnect();
        synchronized (cache) {
            Set<Channel> channels = channelMap.get(session);
            if (channels != null) {
                channels.remove(channel);
                /* a session evicted from the cache is kept alive only by its open channels */
                if (channels.isEmpty() && !cache.containsValue(session)) {
                    channelMap.remove(session);
                    _log.fine("Disconnecting session during RemoteProcess cleanup for "+session.getUserName()+"@"+session.getHost());
                    session.disconnect();
                }
            }
        }
        channel = null;
        closeQuietly(stdin);
        closeQuietly(stdout);
        closeQuietly(stderr);
    }

    /** Closes the stream if there is one, ignoring any failure to do so. */
    private void closeQuietly(java.io.Closeable stream) {
        if (stream != null) {
            try {
                stream.close();
            } catch (IOException e) { /* squash */ }
        }
    }
}
/**
 * This does NOT leave channels and sessions connected if they're in use.
 * It explicitly disconnects everything then clears the cache.
 *
 * Both the sessions still held in the cache and the sessions that have already been evicted
 * but are still tracked because of open channels are disconnected, and the channel
 * bookkeeping is emptied along with the cache.
 *
 * Should be used only for explicit shutdown of all sessions and connections that
 * this cache has served.
 */
protected void clearCache() {
    synchronized (cache) {
        for (Set<Channel> channelSet : channelMap.values()) {
            for (Channel channel : channelSet) {
                channel.disconnect();
            }
        }
        /* evicted sessions with open channels live only in channelMap, so cover both maps */
        Set<Session> sessions = new HashSet<Session>(cache.values());
        sessions.addAll(channelMap.keySet());
        for (Session session : sessions) {
            if (session != null && session.isConnected()) {
                session.disconnect();
            }
        }
        channelMap.clear();
        cache.clear();
    }
}
/**
 * Creates a cache that keeps at most {@code capacity} sessions.
 *
 * The backing map is an access-ordered LinkedHashMap wrapped in a synchronized view. When an
 * insertion pushes the size above the capacity, the eldest entry is dropped from the map. Its
 * session is disconnected immediately only if no channels are recorded for it; otherwise the
 * session stays tracked in the channel bookkeeping and is disconnected by RemoteProcess cleanup
 * once its last channel is released.
 *
 * @param capacity maximum number of sessions retained in the cache
 */
public SshConnectionCache(int capacity) {
    this.capacity = capacity;
    int baseSize = (int)Math.ceil(capacity / loadFactor) + 2;
    cache = Collections.synchronizedMap(new LinkedHashMap<Connection,Session>(baseSize, loadFactor, true) {
        private static final long serialVersionUID = 1328753943644428132L;

        @Override
        protected boolean removeEldestEntry (Map.Entry<Connection,Session> eldest) {
            boolean remove = size() > SshConnectionCache.this.capacity;
            /* if we're removing this session it has to be disconnected to avoid leaking sockets */
            if (remove) {
                Session session = eldest.getValue();
                if (session != null) {
                    Set<Channel> channels = channelMap.get(session);
                    /* if there are no incomplete channels associated with this session then evict it */
                    if (channels == null || channels.isEmpty()) {
                        _log.fine("Disconnecting session during cache eviction for "+session.getUserName()+"@"+session.getHost());
                        channelMap.remove(session);
                        session.disconnect();
                    } else {
                        /* only claim a disconnect in the log when one actually happens */
                        _log.fine("Deferring disconnect of evicted session for "+session.getUserName()+"@"+session.getHost()+" until its open channels complete");
                    }
                }
            }
            return remove;
        }
    });
}
/**
 * Hook for subclasses that need to inspect the cached sessions.
 * The returned map cannot be modified, but the sessions it exposes are the live objects held
 * by the cache, so any change to their state affects the cache's users.
 * @return a read-only view of the connection to session mapping
 */
protected Map<Connection,Session> getCache() {
    final Map<Connection,Session> readOnlyView = Collections.unmodifiableMap(cache);
    return readOnlyView;
}
/**
 * Builds an unconnected JSch session for the connection's user, host and port and applies the
 * configured host key checking mode, timeout and keep-alive count to it.
 * @param connection the endpoint and credentials to build the session for
 * @return the configured session, or null if JSch could not create or configure it
 */
protected Session createSession(Connection connection) {
    final Session created;
    try {
        created = _jsch.getSession(connection.credentials.username, connection.hostname, connection.port);

        java.util.Properties options = new java.util.Properties();
        options.put("StrictHostKeyChecking", STRICT_HOST_KEY_CHECKING);
        created.setConfig(options);

        created.setTimeout(SESSION_READ_TIMEOUT);
        created.setServerAliveCountMax(NUM_KEEP_ALIVE);
    } catch (JSchException failure) {
        /* short warning for operators, full stack trace at a lower level */
        String msg = "VHM: "+connection.hostname+" - could not create ssh session container";
        _log.warning(msg + " - "+ failure.getMessage());
        _log.log(Level.INFO, msg, failure);
        return null;
    }
    return created;
}
/**
 * Connect the provided session object using the credentials supplied.
 *
 * An already connected session is accepted as is. Otherwise up to NUM_SSH_RETRIES attempts are
 * made, sleeping RETRY_DELAY_MILLIS between them. A "Packet corrupt" failure is treated as a
 * dropped connection and is not retried. If the thread is interrupted while waiting to retry,
 * the interrupt status is restored and no further attempts are made.
 *
 * @param session the session to connect
 * @param credentials supplies the optional private key file and password
 * @return true if the session is connected on return, false otherwise
 */
protected boolean connectSession(Session session, Credentials credentials) {
    if (session.isConnected()) {
        _log.finer("VHM: "+session.getHost()+" - using cached connection");
        return true;
    }
    for (int i = 0; i < NUM_SSH_RETRIES; i++) {
        try {
            // If private key file is specified and not already added, use that as identity
            String prvkeyFile = credentials.privateKeyFile;
            if (prvkeyFile != null && !_jsch.getIdentityNames().contains(prvkeyFile)) {
                _jsch.addIdentity(prvkeyFile);
            }
            if (credentials.password != null) {
                session.setPassword(credentials.password);
            }
            _log.finer("VHM: "+session.getHost()+" - establishing ssh connection");
            session.connect();
            return true;
        } catch (JSchException e) {
            /* constant-first comparison: the exception message may be null */
            if ("Packet corrupt".equals(e.getMessage())) {
                /* pretty log message if we're trying to reconnect a session that's been previously connected and now needs discarding */
                _log.info("VHM: "+session.getHost()+" - connection to host dropped");
                return false;
            }
            _log.info("VHM: "+session.getHost()+" - could not create ssh channel to host - " + e.getMessage());
            if (i < NUM_SSH_RETRIES - 1) {
                try {
                    _log.info("VHM: "+session.getHost()+" - retrying ssh connection to host after delay");
                    Thread.sleep(RETRY_DELAY_MILLIS);
                } catch (InterruptedException e1) {
                    _log.info("VHM: unexpected interruption while waiting to retry ssh connection");
                    /* keep the interruption visible to callers and stop retrying */
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
    }
    _log.warning("VHM: "+session.getHost()+" - unable to establish ssh session");
    return false;
}
/**
 * Looks up the session for a connection, creating, caching and connecting a new one if needed.
 * @param connection the endpoint and credentials the session is for
 * @return a connected session, or null if one could not be created or connected
 */
protected Session getSession(Connection connection) {
    synchronized (cache) {
        Session candidate = cache.get(connection);
        /* two passes: a cached session whose connection has dropped is thrown away on the
         * first pass and replaced with a freshly created one on the second */
        for (int attempt = 0; attempt < 2; attempt++) {
            if (candidate == null) {
                candidate = createSession(connection);
                if (candidate == null) {
                    return null;
                }
                cache.put(connection, candidate);
                channelMap.put(candidate, new HashSet<Channel>());
            }
            if (connectSession(candidate, connection.credentials)) {
                /* usable, connected session */
                return candidate;
            }
            /* forget the unusable session; disconnect in case it is in some half-open state */
            cache.remove(connection);
            channelMap.remove(candidate);
            candidate.disconnect();
            candidate = null;
        }
        return null;
    }
}
/**
 * Reads the next acknowledgement byte from the remote scp process.
 * A zero byte means the remote side is ready. Anything else is logged: either the stream has
 * ended, or unexpected data arrived, in which case the rest of the line is read and included
 * in the log message.
 * @param in the stream carrying the remote process's output
 * @return true only if the byte read was zero
 * @throws IOException if reading from the stream fails
 */
private boolean assertRemoteScpReady(InputStream in) throws IOException {
    final int first = in.read();
    if (first == 0) {
        return true;
    }
    if (first < 0) {
        _log.log(Level.INFO, "VHM: expected byte 0x0 but end of stream received");
        return false;
    }
    /* unexpected payload: collect it up to the end of the line (or stream) for the log */
    StringBuilder unexpected = new StringBuilder();
    unexpected.append((char) first);
    int next;
    while ((next = in.read()) != '\n' && next >= 0) {
        unexpected.append((char) next);
    }
    _log.log(Level.INFO, "VHM: expected byte 0x0 but saw the following data: " + unexpected.toString());
    return false;
}
/**
 * Writes the given bytes to a file on the remote host by driving a remote "scp -t" process.
 *
 * The exchange is: wait for the ready byte, send the "C" header with permissions, size and
 * file name, wait for its acknowledgement, send the data followed by a zero byte, and wait
 * for the final acknowledgement. Success is reported only if all three acknowledgements are
 * received.
 *
 * @param connection the endpoint and credentials to copy to
 * @param data the file content
 * @param remoteDirectory directory on the remote host to write into
 * @param remoteName name of the file to create
 * @param permissions permission digits appended to "C0" in the scp header
 * @return 0 if the copy was acknowledged, RemoteProcess.UNDEFINED_EXIT_STATUS otherwise
 */
protected int copy(Connection connection, byte[] data, String remoteDirectory, String remoteName, String permissions) {
    int exitCode = RemoteProcess.UNDEFINED_EXIT_STATUS;
    /* the path is interpreted on the remote host, so use '/' rather than the local platform's separator */
    final String sep = "/";
    StringBuilder command = new StringBuilder(SCP_COMMAND).append(remoteDirectory);
    /* ensure there's a path separator between directory and name */
    if (!remoteDirectory.endsWith(sep)) {
        command.append(sep);
    }
    command.append(remoteName);
    RemoteProcess proc = null;
    try {
        proc = invoke(connection, command.toString(), null, null);
        OutputStream out = proc.getOutputStream();
        InputStream in = proc.getInputStream();
        if (!assertRemoteScpReady(in)) {
            _log.info("VHM: scp protocol error while preparing channel to remote host");
            return exitCode;
        }
        // send "C$perms filesize filename", where filename should not include '/'
        StringBuilder params = new StringBuilder("C0").append(permissions);
        params.append(" ").append(data.length).append(" ");
        params.append(remoteName).append("\n");
        out.write(params.toString().getBytes());
        out.flush();
        if (!assertRemoteScpReady(in)) {
            _log.info("VHM: scp protocol error while waiting for confirmation of specified permissions for remote file");
            return exitCode;
        }
        out.write(data);
        out.write(new byte[] { 0 }, 0, 1);
        out.flush();
        if (!assertRemoteScpReady(in)) {
            _log.info("VHM: scp protocol error waiting for confirmation of data transfer");
            /* the transfer was not acknowledged, so it must not be reported as a success */
            return exitCode;
        }
        out.close();
        /* set this explicitly here as that last assert provided us with the return code for the copy */
        exitCode = 0;
    } catch (Exception e) {
        String msg = "VHM: "+connection.hostname+" - exception copying data to remote host";
        _log.log(Level.WARNING, msg+" - "+e.getMessage());
        _log.log(Level.INFO, msg, e);
    } finally {
        /* also closes the streams on the early-return paths above */
        if (proc != null) {
            proc.cleanup();
        }
    }
    return exitCode;
}
/**
 * Convenience form of copy that builds the connection key from its host, port and credentials
 * and delegates to the connection-based overload.
 */
@Override
public int copy(String remote, int port, Credentials credentials, byte[] data, String remoteDirectory, String remoteName, String permissions) {
    return copy(new Connection(remote, port, credentials), data, remoteDirectory, remoteName, permissions);
}
/**
 * Runs a command on the remote host and waits, for at most INPUTSTREAM_TIMEOUT_MILLIS in
 * total, for it to finish.
 *
 * Failures to start or run the command are logged rather than thrown. If the calling thread
 * is interrupted while waiting, the wait continues until the command completes or the
 * deadline passes, and the thread's interrupt status is restored before returning.
 *
 * @param connection the endpoint and credentials to run the command with
 * @param command the command line to execute remotely
 * @param stdout sink for the command's standard output, may be null; flushed but not closed
 * @return the command's exit status, or RemoteProcess.UNDEFINED_EXIT_STATUS if it could not
 *         be started or did not complete in time
 * @throws IOException declared for subclasses; this implementation logs failures instead
 */
protected int execute(Connection connection, String command, OutputStream stdout) throws IOException {
    int exitCode = RemoteProcess.UNDEFINED_EXIT_STATUS;
    RemoteProcess proc = null;
    boolean interrupted = false;
    try {
        proc = invoke(connection, command, stdout, null);
        long deadline = System.currentTimeMillis() + INPUTSTREAM_TIMEOUT_MILLIS;
        do {
            /* after an interruption only wait for what is left of the overall budget */
            int remaining = (int) Math.max(0L, deadline - System.currentTimeMillis());
            try {
                exitCode = proc.waitFor(remaining);
                if (exitCode != RemoteProcess.UNDEFINED_EXIT_STATUS) {
                    /* we only loop if the command hasn't completed */
                    break;
                }
            } catch (InterruptedException e) {
                /* remember it; re-asserting now would make the next wait fail immediately */
                interrupted = true;
                _log.info("VHM: unexpected interruption while waiting for remote command to complete");
            }
        } while (deadline > System.currentTimeMillis());
        /* Caller is responsible for cleaning up resources passed in, but make sure all the data's been passed along */
        if (stdout != null) {
            try {
                stdout.flush();
            } catch (IOException e) { /* squash */ }
        }
    } catch (Exception e) {
        String msg = "VHM: "+connection.hostname+" - exception executing command on remote host";
        _log.log(Level.WARNING, msg+" - "+e.getMessage());
        _log.log(Level.INFO, msg, e);
    } finally {
        if (proc != null) {
            proc.cleanup();
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
    _log.log(Level.FINE, "Exit status from exec is: " + exitCode);
    return exitCode;
}
/**
 * Convenience form of execute that builds the connection key from its host, port and
 * credentials and delegates to the connection-based overload.
 */
@Override
public int execute(String remote, int port, Credentials credentials, String command, OutputStream stdout) throws IOException {
    return execute(new Connection(remote, port, credentials), command, stdout);
}
/**
 * Starts a command on the remote host and returns a handle to it without waiting for it to
 * finish.
 *
 * A session is taken from the cache (or created), an exec channel is opened on it and recorded
 * in the channel bookkeeping, and the channel is connected, which starts the command. A pty is
 * requested for commands starting with "sudo". When a sink or source stream is supplied, the
 * corresponding stream on the returned process is null because the data is routed directly.
 *
 * If anything fails after the channel has been opened, the channel is disconnected and removed
 * from the bookkeeping before the exception is thrown.
 *
 * @param connection the endpoint and credentials to run the command with
 * @param command the command line to execute remotely
 * @param stdout sink for the command's standard output, or null to read it from the process
 * @param stdin source for the command's standard input, or null to write it via the process
 * @return the handle to the running remote command
 * @throws IOException if no session could be established or the command could not be started;
 *         the underlying failure, when there is one, is attached as the cause
 */
public RemoteProcess invoke(Connection connection, String command, OutputStream stdout, InputStream stdin) throws IOException {
    /* get the cached session for the remote user/host or create a new one */
    Session session;
    ChannelExec channel;
    /* we synchronize on cache so that we don't have the potential to evict and disconnect a session in between confirming that
     * it's functional and recording an opening in the channel map */
    synchronized (cache) {
        session = getSession(connection);
        if (session == null) {
            throw new IOException("unable to establish session to remote host "+connection.hostname);
        }
        /* open a new exec channel - this is tightly coupled to the execution of the command and will be closed on command completion */
        try {
            channel = (ChannelExec) session.openChannel("exec");
            Set<Channel> channels = channelMap.get(session);
            if (channels == null) {
                /* defensive: getSession is overridable, so do not rely on it having registered the session */
                channels = new HashSet<Channel>();
                channelMap.put(session, channels);
            }
            channels.add(channel);
        } catch (JSchException e) {
            String msg = "VHM: "+connection.hostname+" - exception opening SSH execution channel to host";
            _log.log(Level.INFO, msg, e);
            throw new IOException(msg, e);
        }
    }
    /* execute the remote command and set up our remote process wrapper */
    RemoteProcess proc = null;
    try {
        _log.log(Level.FINE, "About to execute: " + command);
        if (command.startsWith("sudo")) {
            /* sudo requires an allocated pty */
            channel.setPty(true);
        }
        channel.setCommand(command);
        /* this calls getOutput/Error/InputStream which seems to overwrite anything set by setOutputStream, so needs to be done first */
        proc = new RemoteProcess(channel);
        /* if we have sink and source already, set the channels up */
        if (stdout != null) {
            channel.setOutputStream(stdout);
        }
        if (stdin != null) {
            channel.setInputStream(stdin);
        }
        channel.connect();
        /* if we have sink and source then the corresponding streams are already linked so shouldn't be read directly */
        if (stdout != null) {
            proc.stdout = null;
        }
        if (stdin != null) {
            proc.stdin = null;
        }
        _log.log(Level.FINE, "Finished channel connection in exec");
        return proc;
    } catch (Exception e) {
        String msg = "VHM: "+connection.hostname+" - exception invoking remote command on host";
        _log.log(Level.WARNING, msg+": "+e.getMessage());
        _log.log(Level.INFO, msg, e);
        channel.disconnect();
        if (proc != null) {
            proc.cleanup();
        } else {
            /* the wrapper was never created, so nothing else will ever release this channel's
             * bookkeeping entry; mirror what RemoteProcess cleanup would have done */
            synchronized (cache) {
                Set<Channel> channels = channelMap.get(session);
                if (channels != null) {
                    channels.remove(channel);
                    if (channels.isEmpty() && !cache.containsValue(session)) {
                        channelMap.remove(session);
                        session.disconnect();
                    }
                }
            }
        }
        throw new IOException(msg, e);
    }
}
/**
 * Convenience form of invoke that builds the connection key from its host, port and
 * credentials and delegates to the connection-based overload with no input source.
 */
@Override
public RemoteProcess invoke(String remote, int port, Credentials credentials, String command, OutputStream stdout) throws IOException {
    return invoke(new Connection(remote, port, credentials), command, stdout, null);
}
}