/**
* Copyright 2010 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package datameer.awstasks.ssh;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.NoRouteToHostException;
import java.net.Socket;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import org.apache.log4j.Logger;
import awstasks.com.jcraft.jsch.CachedSession;
import awstasks.com.jcraft.jsch.ChannelExec;
import awstasks.com.jcraft.jsch.Identity;
import awstasks.com.jcraft.jsch.IdentityKeyString;
import awstasks.com.jcraft.jsch.JSch;
import awstasks.com.jcraft.jsch.JSchException;
import awstasks.com.jcraft.jsch.Proxy;
import awstasks.com.jcraft.jsch.Session;
import awstasks.com.jcraft.jsch.SocketFactory;
import awstasks.com.jcraft.jsch.UIKeyboardInteractive;
import awstasks.com.jcraft.jsch.UserInfo;
import datameer.awstasks.exec.ExecOutputHandler;
import datameer.awstasks.exec.ShellCommand;
import datameer.awstasks.exec.ShellExecutor;
import datameer.awstasks.util.ExceptionUtil;
import datameer.awstasks.util.Retry;
import datameer.com.google.common.base.Preconditions;
import datameer.com.google.common.base.Throwables;
import datameer.com.google.common.hash.Hashing;
import datameer.com.google.common.io.Files;
public class JschRunner extends ShellExecutor {
// Error message used by the setters that refuse changes while the cached session is connected.
private static final String CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE = "This instance of jsch is already connected please disconnect first.";
protected static final Logger LOG = Logger.getLogger(JschRunner.class);
// Session caching is off unless requested through the three-argument constructor.
private static final boolean DEFAULT_SESSION_CACHING_ENABLED = false;
// Remote login name and host; fixed for the lifetime of the runner.
private final String _user;
private final String _host;
private int _port = 22;
// Credentials: the setters allow only one of key file, key file content or password to be set.
private File _keyFile;
private String _keyFileContent;
private String _password;
// Path of the known-hosts file; only handed to JSch when _trust is false and the file exists.
private String _knownHosts = System.getProperty("user.home") + "/.ssh/known_hosts";
private boolean _trust;
// Timeout in milliseconds for establishing the TCP connection (used by the socket factory).
protected int _connectTimeout = (int) TimeUnit.SECONDS.toMillis(80);
// Value passed to Session.setTimeout for every created session.
private int _timeout = 0;
// When true, DEBUG_LOGGER is installed as the JSch logger on session creation.
private boolean _debug;
// When true, connecting is retried on NoRouteToHostException (marked experimental in the code).
private boolean _enableConnectionRetries;
// Number of sessions that were successfully connected by this runner.
private int _createdSessions;
// MD5 hash of the configured credential; part of the target URL and of the cached session.
private String _credentialHash;
// Configuration handed to Session.setConfig for every created session.
private Properties _config = new Properties();
// Optional proxy; applied to created sessions when not null.
private Proxy _proxy = null;
// The session reused across calls when session caching is enabled; null until first use.
private CachedSession _cachedSession = null;
private boolean _sessionCachingEnabled;
/**
 * Creates a runner for the given user and host using the default session caching setting
 * ({@code DEFAULT_SESSION_CACHING_ENABLED}, i.e. caching disabled).
 *
 * @param user the remote login name
 * @param host the remote host
 */
public JschRunner(String user, String host) {
this(user, host, DEFAULT_SESSION_CACHING_ENABLED);
}
/**
 * Creates a runner for the given user and host.
 *
 * @param user the remote login name
 * @param host the remote host
 * @param sessionCachingEnabled whether {@code openSession()} should reuse a single cached
 *            session instead of creating a new one per call
 */
public JschRunner(String user, String host, boolean sessionCachingEnabled) {
    this._user = user;
    this._host = host;
    this._sessionCachingEnabled = sessionCachingEnabled;
}
/**
 * @return the remote host this runner was created for
 */
public String getHost() {
    return this._host;
}
/**
 * Configures authentication with a private key file and records the MD5 hash of the file as
 * credential hash.
 *
 * @param keyfile the private key file
 * @throws IllegalStateException if the cached session is currently connected, or if a
 *             password or key file content has already been configured
 * @throws RuntimeException if the key file cannot be hashed (the failure is propagated via
 *             {@code Throwables.propagate})
 */
public void setKeyfile(File keyfile) {
    Preconditions.checkState(!isConnected(_cachedSession), CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    if (_password != null || _keyFileContent != null) {
        throwAuthenticationAlreadySetException();
    }
    // Hash before touching any field: if the file cannot be read the runner must not end up
    // with a key file configured but a missing or stale credential hash.
    final String credentialHash;
    try {
        credentialHash = Files.hash(keyfile, Hashing.md5()).toString();
    } catch (Exception e) {
        throw Throwables.propagate(e);
    }
    _keyFile = keyfile;
    _credentialHash = credentialHash;
}
/**
 * Configures authentication with the textual content of a private key and records its MD5
 * hash (computed over the platform default charset) as credential hash.
 *
 * @param keyFileContent the private key content
 * @throws IllegalStateException if the cached session is currently connected, or if a
 *             password or key file has already been configured
 */
public void setKeyfileContent(String keyFileContent) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    final boolean otherCredentialConfigured = _keyFile != null || _password != null;
    if (otherCredentialConfigured) {
        throwAuthenticationAlreadySetException();
    }
    this._keyFileContent = keyFileContent;
    this._credentialHash = Hashing.md5().hashString(keyFileContent, Charset.defaultCharset()).toString();
}
/**
 * Configures password authentication and records the MD5 hash of the password (computed over
 * the platform default charset) as credential hash.
 *
 * @param password the password to authenticate with
 * @throws IllegalStateException if the cached session is currently connected, or if a key
 *             file or key file content has already been configured
 */
public void setPassword(String password) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    final boolean keyConfigured = _keyFileContent != null || _keyFile != null;
    if (keyConfigured) {
        throwAuthenticationAlreadySetException();
    }
    this._password = password;
    this._credentialHash = Hashing.md5().hashString(password, Charset.defaultCharset()).toString();
}
/**
 * Replaces the configuration that is handed to every session created by this runner.
 *
 * @param config the session configuration
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setConfig(Properties config) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._config = config;
}
/**
 * Signals that a second kind of credential was configured although one is already set.
 *
 * @throws IllegalStateException always
 */
private static void throwAuthenticationAlreadySetException() {
    final String message = "set either password OR keyfile OR keyfile-content";
    throw new IllegalStateException(message);
}
/**
 * Sets the path of the known-hosts file. The file is only used for new sessions when trust
 * is disabled and the file exists.
 *
 * @param knownHosts path of the known-hosts file
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setKnownHosts(String knownHosts) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._knownHosts = knownHosts;
}
/**
 * Sets the trust flag. When {@code true}, the known-hosts file is not loaded for new sessions.
 *
 * @param trust whether to skip loading the known-hosts file
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setTrust(boolean trust) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._trust = trust;
}
/**
 * Sets the remote port used for new sessions (22 unless changed).
 *
 * @param port the remote ssh port
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setPort(int port) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._port = port;
}
/**
 * @return the remote port used for new sessions
 */
public int getPort() {
    return this._port;
}
/**
 * Sets the timeout, in milliseconds, used when the socket for a new session is connected.
 * Unlike most other setters this one does not check whether a cached session is connected.
 *
 * @param connectTimeout the socket connect timeout in milliseconds
 */
public void setConnectTimeout(int connectTimeout) {
    this._connectTimeout = connectTimeout;
}
/**
 * @return the socket connect timeout in milliseconds
 */
public int getConnectTimeout() {
    return this._connectTimeout;
}
/**
 * Sets the value passed to {@code Session.setTimeout} for new sessions (presumably
 * milliseconds; see the JSch documentation).
 *
 * @param timeout the session timeout
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setTimeout(int timeout) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._timeout = timeout;
}
/**
 * @return the timeout value applied to new sessions
 */
public int getTimeout() {
    return this._timeout;
}
/**
 * Enables or disables debug mode. In debug mode the JSch debug logger is installed when a
 * session is created.
 *
 * @param debug whether debug mode is on
 */
public void setDebug(boolean debug) {
    this._debug = debug;
}
/**
 * @return whether debug mode is on
 */
public boolean isDebug() {
    return this._debug;
}
/**
 * @return the number of sessions this runner has successfully created and connected so far
 */
public int getCreatedSessions() {
    return this._createdSessions;
}
/**
 * Enables or disables the (experimental) retrying of session connects.
 *
 * @param enableConnectionRetries whether connecting should be retried
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setEnableConnectionRetries(boolean enableConnectionRetries) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._enableConnectionRetries = enableConnectionRetries;
}
/**
 * @return whether session connects are retried
 */
public boolean isEnableConnectionRetries() {
    return this._enableConnectionRetries;
}
/**
 * @return the proxy applied to new sessions, or {@code null} if none is configured
 */
public Proxy getProxy() {
    return this._proxy;
}
/**
 * Sets the proxy applied to new sessions; {@code null} means no proxy.
 *
 * @param proxy the proxy to use, may be {@code null}
 * @throws IllegalStateException if the cached session is currently connected
 */
public void setProxy(Proxy proxy) {
    final boolean connected = isConnected(_cachedSession);
    Preconditions.checkState(!connected, CHANGE_ON_ALREADY_RUNNING_SESSION_ERROR_MESSAGE);
    this._proxy = proxy;
}
/**
 * Opens a session via {@code openSession()}, executes the given command on it and calls
 * {@code disconnect()} on the session afterwards, whether or not the command succeeded.
 * NOTE(review): with session caching enabled the cached session is the one passed to
 * {@code disconnect()}; what that call does for a cached session is not visible here.
 *
 * @param command the command to execute
 * @throws IOException if the command fails, or wrapping any {@code JSchException}
 */
public void run(JschCommand command) throws IOException {
    Session activeSession = null;
    try {
        activeSession = openSession();
        command.execute(activeSession);
    } catch (JSchException jschFailure) {
        throw new IOException(jschFailure);
    } finally {
        if (activeSession != null) {
            activeSession.disconnect();
        }
    }
}
/**
 * Executes the shell command remotely by wrapping it, together with the output handler, in a
 * {@code SshExecDelegateCommand} and running that through {@code run(JschCommand)}.
 *
 * @param command the shell command to execute
 * @param outputHandler the handler receiving the command output
 * @return the result reported by the delegate command
 * @throws IOException if running the command fails
 */
@Override
public <R> R execute(ShellCommand<?> command, ExecOutputHandler<R> outputHandler) throws IOException {
    final SshExecDelegateCommand<R> delegate = new SshExecDelegateCommand<R>(command, outputHandler);
    this.run(delegate);
    return delegate.getResult();
}
/**
 * Opens a session and returns a stream reading the given remote file over it.
 * NOTE(review): the session is handed to the returned stream; presumably closing the stream
 * releases it - confirm against {@code ScpFileInputStream}.
 *
 * @param remoteFile path of the file on the remote host
 * @return a stream over the remote file content
 * @throws IOException if the session or the stream cannot be opened ({@code JSchException}s
 *             are wrapped)
 */
public InputStream openFile(String remoteFile) throws IOException {
    Session session = null;
    boolean streamOpened = false;
    try {
        session = openSession();
        InputStream stream = new ScpFileInputStream(session, remoteFile);
        streamOpened = true;
        return stream;
    } catch (JSchException e) {
        throw new IOException(e);
    } finally {
        // Nobody else owns the session if the stream could not be created, so release it here
        // instead of leaking the connection.
        if (!streamOpened && session != null) {
            session.disconnect();
        }
    }
}
/**
 * Opens a session and returns a stream writing the given remote file over it.
 * NOTE(review): the session is handed to the returned stream; presumably closing the stream
 * releases it - confirm against {@code ScpFileOutputStream}.
 *
 * @param remoteFile path of the file on the remote host
 * @param length the length passed on to {@code ScpFileOutputStream}
 * @return a stream writing to the remote file
 * @throws IOException if the session or the stream cannot be opened ({@code JSchException}s
 *             are wrapped)
 */
public OutputStream createFile(String remoteFile, long length) throws IOException {
    Session session = null;
    boolean streamOpened = false;
    try {
        session = openSession();
        OutputStream stream = new ScpFileOutputStream(session, remoteFile, length);
        streamOpened = true;
        return stream;
    } catch (JSchException e) {
        throw new IOException(e);
    } finally {
        // Nobody else owns the session if the stream could not be created, so release it here
        // instead of leaking the connection.
        if (!streamOpened && session != null) {
            session.disconnect();
        }
    }
}
/**
 * Verifies that a connection can be established by running a command that does nothing:
 * a session is opened and handed straight back for disconnecting.
 *
 * @throws IOException if the connection cannot be established
 */
public void testConnect() throws IOException {
    final JschCommand noOpCommand = new JschCommand() {
        @Override
        public void execute(Session session) throws IOException {
            // intentionally empty: opening the session is the whole test
        }
    };
    run(noOpCommand);
}
/**
 * Repeatedly tries to connect, pausing one second after each failed attempt, until an attempt
 * succeeds or {@code maxWaitTime} milliseconds have elapsed. At least one attempt is made.
 *
 * @param maxWaitTime maximum time in milliseconds to keep trying
 * @throws IOException if no attempt succeeded in time (the last failure is attached as
 *             cause), or if the calling thread was interrupted while waiting
 */
public void testConnect(long maxWaitTime) throws IOException {
    final long startTime = System.currentTimeMillis();
    IOException lastFailure = null;
    do {
        try {
            testConnect();
            return;
        } catch (IOException e) {
            lastFailure = e;
            LOG.warn("Failed to connect with " + targetUrl() + " :" + e.getMessage());
        }
        try {
            Thread.sleep(1000);
        } catch (InterruptedException interrupted) {
            // Restore the interrupt flag for the caller and stop waiting; swallowing the
            // interruption would keep this thread polling until maxWaitTime is used up.
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while waiting for ssh connection to " + targetUrl(), interrupted);
        }
    } while ((System.currentTimeMillis() - startTime) < maxWaitTime);
    throw new IOException("Failed to establish ssh connection to " + targetUrl(), lastFailure);
}
/**
 * @return whether {@code openSession()} reuses a single cached session
 */
public boolean isSessionCacheEnabled() {
    return this._sessionCachingEnabled;
}
/**
 * Returns a connected session. Without session caching a new session is created on every
 * call. With session caching the cached session is returned as long as it passes the
 * {@code isConnected} probe; otherwise it is released and replaced by a new one.
 *
 * @return a connected session
 * @throws JSchException if a new session cannot be created or connected
 */
public Session openSession() throws JSchException {
    if (!isSessionCacheEnabled()) {
        return createFreshSession(false);
    }
    if (!isConnected(_cachedSession)) {
        // The old session may still hold its connection although it is unusable for new
        // channels; release it before it is replaced so that it does not leak.
        final CachedSession staleSession = _cachedSession;
        _cachedSession = null;
        if (staleSession != null) {
            try {
                staleSession.forcedDisconnect();
            } catch (RuntimeException e) {
                // Best effort only: a failing cleanup must not prevent reconnecting.
                LOG.warn("Failed to release stale cached session " + staleSession, e);
            }
        }
        _cachedSession = (CachedSession) createFreshSession(true);
    }
    return _cachedSession;
}
/**
 * Builds the textual identifier of the connection target from user, credential hash, host
 * and port via {@code CachedSession.sshUrl}; used in log and error messages.
 *
 * @return the identifier of the connection target
 */
private String targetUrl() {
    final String url = CachedSession.sshUrl(_user, _credentialHash, _host, _port);
    return url;
}
/**
 * Checks whether the given session is usable. A session counts as usable if it reports
 * itself as connected and a probe {@code exec} channel running {@code true} can be opened
 * and connected on it.
 *
 * @param cachedSession the session to check, may be {@code null}
 * @return {@code true} if the session is non-null, connected and accepted the probe channel
 */
static boolean isConnected(Session cachedSession) {
    if (null == cachedSession || !cachedSession.isConnected()) {
        return false;
    }
    ChannelExec testChannel = null;
    try {
        testChannel = (ChannelExec) cachedSession.openChannel("exec");
        // The command has to be set before connecting; setting it afterwards is too late
        // to take effect.
        testChannel.setCommand("true");
        testChannel.connect();
        testChannel.disconnect();
        return true;
    } catch (Exception e) {
        // Plain concatenation: the session's string form must not be interpreted as a format
        // string, a '%' in it would otherwise throw from inside this handler.
        LOG.info("Dropping cached but unusable session " + cachedSession);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Probe of cached session failed", e);
        }
        if (testChannel != null) {
            // Release the probe channel that was opened but could not be used.
            try {
                testChannel.disconnect();
            } catch (Exception ignored) {
                // nothing more can be done for an unusable channel
            }
        }
        return false;
    }
}
/**
 * Creates, configures and connects a new session using the settings of this runner.
 * <p>
 * The configured key file and/or key content are registered as identities, the known-hosts
 * file is loaded unless trust is enabled or the file is missing, and socket factory, user
 * info, timeout, configuration and proxy are applied before connecting. On success the
 * created-sessions counter is incremented.
 *
 * @param cached whether to create a {@code CachedSession} instead of a plain session
 * @return the connected session
 * @throws JSchException if the session cannot be set up or (without connection retries)
 *             cannot be connected
 * @throws RuntimeException if connection retries are enabled and connecting finally fails
 */
@SuppressWarnings("unchecked")
private Session createFreshSession(boolean cached) throws JSchException {
    JSch jsch = new JSch();
    if (isDebug()) {
        // Note: the JSch logger is a static, process-wide setting.
        JSch.setLogger(DEBUG_LOGGER);
    }
    if (_keyFile != null) {
        jsch.addIdentity(_keyFile.getAbsolutePath());
    }
    if (_keyFileContent != null) {
        Identity identity = IdentityKeyString.newInstance(_keyFileContent, jsch);
        jsch.addIdentity(identity, null);
    }
    if (!_trust && _knownHosts != null && new File(_knownHosts).exists()) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Using known hosts: " + _knownHosts);
        }
        jsch.setKnownHosts(_knownHosts);
    }
    final Session session;
    if (cached) {
        session = new CachedSession(_user, _host, _port, _credentialHash, jsch);
    } else {
        session = jsch.getSession(_user, _host, _port);
    }
    session.setSocketFactory(new SocketFactoryWithConnectTimeout());
    session.setUserInfo(new UserInfoImpl(_password));
    session.setTimeout(_timeout);
    session.setDaemonThread(true);
    session.setConfig(_config);
    if (_proxy != null) {
        session.setProxy(_proxy);
    }
    if (LOG.isDebugEnabled()) {
        LOG.debug("Creating session (cached=" + cached + ") to " + targetUrl());
    }
    if (_enableConnectionRetries) {
        // experimental
        Retry retry = Retry.onExceptions(NoRouteToHostException.class).withMaxRetries(3).withWaitTime(2500);
        retry.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    session.connect();
                } catch (JSchException e) {
                    // The underlying cause is what the retry is configured for, but a
                    // JSchException does not necessarily wrap anything. Fall back to the
                    // exception itself so that the real failure is never replaced by null.
                    Throwable failure = e.getCause() != null ? e.getCause() : e;
                    throw ExceptionUtil.convertToRuntimeException(failure);
                }
            }
        });
    } else {
        session.connect();
    }
    LOG.info("Created session (cached=" + cached + ") for " + targetUrl());
    _createdSessions++;
    return session;
}
/**
 * Socket factory that connects with the connect timeout of the enclosing runner. The streams
 * are taken directly from the socket.
 */
class SocketFactoryWithConnectTimeout implements SocketFactory {
    @Override
    public OutputStream getOutputStream(Socket socket) throws IOException {
        return socket.getOutputStream();
    }

    @Override
    public InputStream getInputStream(Socket socket) throws IOException {
        return socket.getInputStream();
    }

    /**
     * Creates a socket connected to the given host and port, waiting at most the runner's
     * connect timeout (milliseconds). The socket is closed again if binding or connecting
     * fails.
     */
    @Override
    public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
        Socket socket = new Socket();
        boolean connected = false;
        try {
            socket.bind(null);
            socket.connect(new InetSocketAddress(host, port), _connectTimeout);
            connected = true;
            return socket;
        } finally {
            if (!connected) {
                // Do not leak the socket when the connection could not be established.
                try {
                    socket.close();
                } catch (IOException ignored) {
                    // the original failure is the one worth reporting
                }
            }
        }
    }
}
/**
 * Non-interactive {@code UserInfo}: supplies the configured password, an empty passphrase,
 * and answers every prompt (including yes/no questions) affirmatively.
 */
private static class UserInfoImpl implements UserInfo, UIKeyboardInteractive {
    private final String _password;

    public UserInfoImpl(String password) {
        this._password = password;
    }

    @Override
    public String getPassphrase() {
        return "";
    }

    @Override
    public String getPassword() {
        return this._password;
    }

    @Override
    public boolean promptPassphrase(String arg0) {
        return true;
    }

    @Override
    public boolean promptPassword(String arg0) {
        return true;
    }

    @Override
    public boolean promptYesNo(String arg0) {
        return true;
    }

    @Override
    public void showMessage(String message) {
        LOG.info(message);
    }

    /**
     * Answers a keyboard-interactive challenge with the password, but only if there is a
     * password and the challenge consists of exactly one prompt whose input is not echoed;
     * otherwise returns {@code null}.
     */
    @Override
    public String[] promptKeyboardInteractive(String destination, String name, String instruction, String[] prompt, boolean[] echo) {
        final boolean singleHiddenPrompt = prompt.length == 1 && !echo[0];
        if (!singleHiddenPrompt || _password == null) {
            return null;
        }
        return new String[] { _password };
    }
}
/**
 * JSch logger used in debug mode: enabled for every level and printing each message to
 * standard out, prefixed with its level.
 */
protected static awstasks.com.jcraft.jsch.Logger DEBUG_LOGGER = new awstasks.com.jcraft.jsch.Logger() {
    @Override
    public boolean isEnabled(int level) {
        return true;
    }

    @Override
    public void log(int level, String message) {
        final String line = "jsch[" + level + "]: " + message;
        System.out.println(line);
    }
};
/**
 * Looks for a private key file in the standard locations below the user's home directory,
 * checking {@code .ssh/id_rsa} first and {@code .ssh/id_dsa} second.
 *
 * @param failIfNotFound whether to throw instead of returning {@code null} when nothing is
 *            found
 * @return the first existing key file, or {@code null} if there is none and
 *         {@code failIfNotFound} is {@code false}
 * @throws IllegalStateException if {@code failIfNotFound} is {@code true} and either the
 *             {@code user.home} system property is not set or no key file exists
 */
public static File findStandardKeyFile(boolean failIfNotFound) {
    final String userHome = System.getProperty("user.home");
    if (userHome == null) {
        if (!failIfNotFound) {
            return null;
        }
        throw new IllegalStateException("no user.home set");
    }
    final List<File> candidates = new ArrayList<File>();
    for (String keyName : new String[] { "id_rsa", "id_dsa" }) {
        candidates.add(new File(userHome, ".ssh/" + keyName));
    }
    for (File candidate : candidates) {
        if (candidate.exists()) {
            return candidate;
        }
    }
    if (!failIfNotFound) {
        return null;
    }
    throw new IllegalStateException("No private keyfile found in standard locations: " + candidates);
}
/**
 * Forcibly disconnects the cached session, if there is one, and forgets it. Call this once a
 * runner that was created with session caching is no longer needed.
 */
public void disconnect() {
    final CachedSession session = _cachedSession;
    if (session == null) {
        return;
    }
    session.forcedDisconnect();
    _cachedSession = null;
}
}