/*
* Copyright (c) 2012-2015 iWave Software LLC
* All Rights Reserved
*/
package com.iwave.utility.ssh;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import com.google.common.collect.Lists;
import com.iwave.ext.command.Command;
import com.iwave.ext.command.CommandException;
import com.iwave.ext.command.CommandExecutor;
import com.iwave.ext.command.CommandOutput;
import com.jcraft.jsch.ChannelShell;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
/**
 * This is a command executor that can execute multiple commands within a single SSH shell
 * session.
 * <p>
 * A session is opened with {@link #connect()} (or lazily by {@link #executeCommand(Command)}),
 * the initial prompt is detected, and each subsequent command is written to the shell's stdin.
 * The output of a command is everything the shell prints up to the next occurrence of the
 * prompt pattern.
 * <p>
 * NOTE(review): instances hold mutable, unsynchronized session state; sharing one instance
 * between threads is presumably unsupported - confirm against callers.
 *
 * @author jonnymiller
 */
public class ShellCommandExecutor implements CommandExecutor {
    private static final Logger LOG = Logger.getLogger(ShellCommandExecutor.class);
    /** Default pattern used to detect the initial prompt. */
    private static final Pattern PROMPT = Pattern.compile("[$#>] ");
    /** Default command timeout, in seconds. */
    private static final int TIMEOUT = 30;
    /** Sequence written after every command to submit it. */
    private static final String ENTER = "\r";
    /** Command sent to the remote shell on disconnect. */
    private static final String EXIT = "exit";
    /** Value handed to both the JSch session timeout and the connect timeout. */
    private static final int CONNECTION_TIMEOUT = 60 * 60 * 1000;
    /** Delay between polls of the captured output while waiting for a pattern. */
    private static final long POLL_INTERVAL_MILLIS = 20;

    /** The connection information. */
    private SSHConnection connection;
    /** The timeout for commands. */
    private int timeoutInSeconds = TIMEOUT;
    /** The current session initial prompt. */
    private String initialPrompt;
    /** The pattern for the command prompt. */
    private Pattern promptPattern;
    /** The shell encoding (defaults to UTF-8). */
    private String encoding = "UTF-8";
    /** The SSH session. */
    private Session session;
    /** The shell. */
    private ChannelShell shell;
    /** The shell configurator (optional). */
    private ShellConfigurator shellConfigurator;
    /** Stdin writer. */
    private Writer stdin;
    /** The consumer of stdout. */
    private CharStreamConsumer stdout;
    /** Tracks the position of stdout between command executions. */
    private int stdoutPos;

    /**
     * Creates an executor without connection information; it must be supplied through
     * {@link #setConnection(SSHConnection)} before connecting.
     */
    public ShellCommandExecutor() {
    }

    /**
     * Creates an executor for the given host on port 22.
     *
     * @param host the remote host.
     * @param username the login user.
     * @param password the login password.
     */
    public ShellCommandExecutor(String host, String username, String password) {
        this(host, 22, username, password);
    }

    /**
     * Creates an executor for the given host and port.
     *
     * @param host the remote host.
     * @param port the SSH port.
     * @param username the login user.
     * @param password the login password.
     */
    public ShellCommandExecutor(String host, int port, String username, String password) {
        this.connection = new SSHConnection(host, port, username, password);
    }

    /**
     * Creates an executor that uses the given connection information.
     *
     * @param connection the connection information.
     */
    public ShellCommandExecutor(SSHConnection connection) {
        this.connection = connection;
    }

    /**
     * Gets the SSH connection information.
     *
     * @return the connection information.
     */
    public SSHConnection getConnection() {
        return connection;
    }

    /**
     * Sets the SSH connection information.
     *
     * @param connection the connection information.
     */
    public void setConnection(SSHConnection connection) {
        this.connection = connection;
    }

    /**
     * Gets the optional shell configurator.
     *
     * @return the shell configurator, or null if none has been set.
     */
    public ShellConfigurator getShellConfigurator() {
        return shellConfigurator;
    }

    /**
     * Sets the configurator that is applied to the shell channel before it is connected.
     *
     * @param shellConfigurator the shell configurator (may be null).
     */
    public void setShellConfigurator(ShellConfigurator shellConfigurator) {
        this.shellConfigurator = shellConfigurator;
    }

    /**
     * Gets the command timeout.
     *
     * @return the timeout in seconds.
     */
    public int getTimeoutInSeconds() {
        return timeoutInSeconds;
    }

    /**
     * Sets the command timeout. While waiting for a prompt, the wait is abandoned once the
     * captured output has not grown for this many seconds.
     *
     * @param timeout the timeout in seconds.
     */
    public void setTimeoutInSeconds(int timeout) {
        this.timeoutInSeconds = timeout;
    }

    /**
     * Gets the encoding used for the shell's input and output streams.
     *
     * @return the encoding name.
     */
    public String getEncoding() {
        return encoding;
    }

    /**
     * Sets the encoding used for the shell's input and output streams. It is read when the
     * shell is opened in {@link #connect(Pattern)}.
     *
     * @param encoding the encoding name.
     */
    public void setEncoding(String encoding) {
        this.encoding = encoding;
    }

    /**
     * Gets the prompt that was matched initially after login.
     *
     * @return the initial prompt.
     */
    public String getInitialPrompt() {
        return initialPrompt;
    }

    /**
     * Gets the exact prompt pattern.
     *
     * @return the exact prompt pattern.
     */
    public Pattern getPromptPattern() {
        return promptPattern;
    }

    /**
     * Sets the prompt pattern. This will be set after {@link #connect()}, and may be overridden
     * then if required.
     *
     * @param promptPattern the exact prompt pattern.
     */
    public void setPromptPattern(Pattern promptPattern) {
        this.promptPattern = promptPattern;
    }

    /**
     * Determines if this is connected.
     *
     * @return true if this is connected.
     */
    public boolean isConnected() {
        return (shell != null) && !shell.isClosed();
    }

    /**
     * Connects to the system with the default prompt pattern (<code>[$#>] </code>).
     *
     * @throws SSHException if an error occurs.
     */
    public void connect() throws SSHException {
        connect(PROMPT);
    }

    /**
     * Connects to the system using the specified initial prompt pattern. The actual prompt matched
     * will be available via {@link #getInitialPrompt()}, and the specific prompt pattern may be
     * changed after using {@link #setPromptPattern(Pattern)}. Does nothing if already connected.
     *
     * @param prompt the pattern used to recognise the first prompt after login.
     *
     * @throws SSHException if an error occurs.
     * @throws IllegalStateException if no connection information has been set.
     */
    public void connect(Pattern prompt) throws SSHException {
        if (isConnected()) {
            return;
        }
        if (connection == null) {
            throw new IllegalStateException("No SSH connection information has been set");
        }
        try {
            debug("Connecting to %s:%s as %s", connection.getHost(), connection.getPort(),
                    connection.getUsername());
            // Create a new SSH session
            session = new JSch().getSession(connection.getUsername(), connection.getHost(),
                    connection.getPort());
            session.setUserInfo(new SSHUserInfo(connection.getPassword()));
            session.setTimeout(CONNECTION_TIMEOUT);
            session.connect(CONNECTION_TIMEOUT);

            debug("Opening shell channel, encoding: %s", encoding);
            // Open a shell and setup the input/output
            shell = (ChannelShell) session.openChannel("shell");
            stdin = new BufferedWriter(
                    new OutputStreamWriter(shell.getOutputStream(), encoding));
            stdout = new CharStreamConsumer(shell.getInputStream(), encoding);
            stdout.setLogger(Logger.getLogger(getClass().getName() + ".stdout"));
            // The consumer starts empty, so an offset left over from a previous session
            // must not be carried into this one.
            stdoutPos = 0;

            configureShell(shell);
            shell.connect();
            waitForInitialPrompt(prompt);
        } catch (JSchException e) {
            forceQuit();
            throw new SSHException(e);
        } catch (IOException e) {
            forceQuit();
            throw new SSHException(e);
        } catch (RuntimeException e) {
            // Do not leave a half-open session behind on unexpected failures
            forceQuit();
            throw e;
        }
    }

    /**
     * Configures the shell before connect. Delegates to the shell configurator when one is set.
     *
     * @param shell the shell to configure.
     */
    protected void configureShell(ChannelShell shell) {
        if (shellConfigurator != null) {
            shellConfigurator.configureShell(shell);
        }
    }

    /**
     * Sends a shell command, followed by a carriage return, and flushes stdin.
     *
     * @param command the command to send.
     *
     * @throws IOException if an I/O error occurs.
     */
    protected void send(String command) throws IOException {
        debug("Sending: %s", command);
        stdin.write(command);
        stdin.write(ENTER);
        stdin.flush();
    }

    /**
     * Gets the output for the currently running command, i.e. everything captured since the
     * end of the previous command.
     *
     * @return the current command output.
     */
    protected String getCurrentCommandOutput() {
        return StringUtils.substring(getCurrentStandardOutContents(), stdoutPos);
    }

    /**
     * Gets the current standard out contents (from the entire shell session).
     *
     * @return the current standard out contents.
     */
    protected String getCurrentStandardOutContents() {
        return stdout.toString();
    }

    /**
     * Waits for the given pattern to appear in the output of the current command. The output is
     * polled periodically; the wait fails once the output has not grown for longer than the
     * configured timeout.
     *
     * @param pattern the pattern.
     * @return the matched value.
     *
     * @throws IOException if the timeout elapses without new output, or the wait is interrupted.
     */
    private String waitFor(Pattern pattern) throws IOException {
        debug("Waiting for: %s", pattern.pattern());
        // Long arithmetic so that large timeouts do not overflow
        long timeoutMillis = timeoutInSeconds * 1000L;
        String output = getCurrentCommandOutput();
        int lastLength = output.length();
        long lastChange = System.currentTimeMillis();
        try {
            Matcher m = pattern.matcher(output);
            while (!m.find()) {
                if (output.length() != lastLength) {
                    // New output arrived since the previous poll: restart the idle clock
                    lastLength = output.length();
                    lastChange = System.currentTimeMillis();
                }
                else if ((System.currentTimeMillis() - lastChange) > timeoutMillis) {
                    debug("Timeout with output: %s", output);
                    throw new IOException("Timeout waiting for: " + m.pattern().pattern());
                }
                Thread.sleep(POLL_INTERVAL_MILLIS);
                output = getCurrentCommandOutput();
                m.reset(output);
            }
            String match = m.group();
            debug("Found: '%s'", match);
            return match;
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new IOException(e);
        }
    }

    /**
     * Waits for the first prompt, then records the initial prompt text, derives the exact
     * prompt pattern from it and moves the output position past it.
     *
     * @param initialPromptPattern the pattern used to recognise the first prompt.
     *
     * @throws IOException if the prompt does not appear.
     */
    private void waitForInitialPrompt(Pattern initialPromptPattern) throws IOException {
        String prompt = waitFor(initialPromptPattern);
        String output = getCurrentStandardOutContents();
        int index = StringUtils.indexOf(output, prompt) + prompt.length();
        // Find the initial prompt for this session: the text on the last line up to and
        // including the matched prompt
        initialPrompt = StringUtils.substring(output, 0, index);
        if (StringUtils.contains(initialPrompt, '\n')) {
            initialPrompt = StringUtils.substringAfterLast(initialPrompt, "\n");
        }
        debug("Found initial prompt: '%s'", initialPrompt);
        promptPattern = Pattern.compile(Pattern.quote(initialPrompt));
        stdoutPos = index;
    }

    /**
     * Disconnects from the remote host. Sends <code>exit</code> to the shell and then tears
     * down the channel and session, even if sending fails. Does nothing if not connected.
     *
     * @throws SSHException if an error occurs.
     */
    public void disconnect() throws SSHException {
        if (isConnected()) {
            try {
                send(EXIT);
            } catch (IOException e) {
                throw new SSHException(e);
            } finally {
                forceQuit();
            }
            debug("Disconnected");
        }
    }

    /**
     * Forcibly quits the session. Each teardown step is attempted independently so that a
     * failure in one does not prevent the others from running.
     */
    private void forceQuit() {
        try {
            if (stdout != null) {
                stdout.close();
            }
        } catch (RuntimeException e) {
            LOG.error(e.getMessage(), e);
        }
        try {
            if (shell != null) {
                shell.disconnect();
            }
        } catch (RuntimeException e) {
            LOG.error(e.getMessage(), e);
        } finally {
            shell = null;
        }
        try {
            if (session != null) {
                session.disconnect();
            }
        } catch (RuntimeException e) {
            LOG.error(e.getMessage(), e);
        } finally {
            session = null;
        }
    }

    /**
     * Executes a command on the remote system, connecting first if necessary.
     *
     * @param command the command whose command line is sent to the shell.
     * @return the command output.
     *
     * @throws CommandException if the command could not be sent or its prompt never appeared.
     */
    @Override
    public CommandOutput executeCommand(Command command) throws CommandException {
        if (!isConnected()) {
            connect();
        }
        String cli = command.getCommandLine();
        return sendCommand(cli);
    }

    /**
     * Sends a command to the remote host. This returns a CommandOutput, but the exit value is
     * always 0 since it is being handled by the remote shell.
     *
     * @param command the command to send.
     * @return the command output, with the echoed command and the trailing prompt removed.
     */
    protected CommandOutput sendCommand(String command) {
        try {
            send(command);
            String matched = waitFor(promptPattern);
            String output = getCurrentCommandOutput();
            // Everything captured so far belongs to this command
            stdoutPos += StringUtils.length(output);

            // Strip the command from the start of the output (ignoring any inserted line breaks)
            output = IWaveStringUtils.removeStartIgnoringWhiteSpace(output, command);
            // Strip the single line break that follows the echoed command
            if (StringUtils.startsWith(output, "\r\n")) {
                output = StringUtils.removeStart(output, "\r\n");
            }
            else if (StringUtils.startsWith(output, "\r") || StringUtils.startsWith(output, "\n")) {
                output = StringUtils.substring(output, 1);
            }
            // Strip the prompt from the end of the output
            output = StringUtils.removeEnd(output, matched);
            return new CommandOutput(output, null, 0);
        } catch (Exception e) {
            CommandException ce = new CommandException(e);
            // Attach whatever was captured to help diagnose the failure
            ce.setOutput(tryGetCommandOutput());
            throw ce;
        }
    }

    /**
     * Attempts to capture the output of the current command for error reporting.
     *
     * @return the output captured so far, or null if it could not be read.
     */
    private CommandOutput tryGetCommandOutput() {
        try {
            String output = getCurrentCommandOutput();
            return new CommandOutput(output, null, 0);
        } catch (RuntimeException e) {
            return null;
        }
    }

    /**
     * Executes the given commands in order using this executor.
     *
     * @param commands the commands to execute.
     * @return the output of each command, in execution order.
     *
     * @throws CommandException if a command fails.
     */
    public List<CommandOutput> executeCommands(Command... commands) throws CommandException {
        return executeCommands(Arrays.asList(commands));
    }

    /**
     * Executes the given commands in order. Each command is assigned this executor before it is
     * executed.
     *
     * @param commands the commands to execute.
     * @return the output of each command, in execution order.
     *
     * @throws CommandException if a command fails.
     */
    public List<CommandOutput> executeCommands(List<? extends Command> commands)
            throws CommandException {
        List<CommandOutput> results = Lists.newArrayList();
        for (Command command : commands) {
            command.setCommandExecutor(this);
            command.execute();
            results.add(command.getOutput());
        }
        return results;
    }

    /**
     * Logs a message at info level. The message is only treated as a format string when
     * arguments are supplied.
     *
     * @param message the message or format string.
     * @param args the optional format arguments.
     */
    protected void info(String message, Object... args) {
        if (LOG.isInfoEnabled()) {
            if ((args != null) && (args.length > 0)) {
                LOG.info(String.format(message, args));
            }
            else {
                LOG.info(message);
            }
        }
    }

    /**
     * Logs a message at debug level. The message is only treated as a format string when
     * arguments are supplied.
     *
     * @param message the message or format string.
     * @param args the optional format arguments.
     */
    protected void debug(String message, Object... args) {
        if (LOG.isDebugEnabled()) {
            if ((args != null) && (args.length > 0)) {
                LOG.debug(String.format(message, args));
            }
            else {
                LOG.debug(message);
            }
        }
    }

    /**
     * Interface for configuring the shell before connect.
     */
    public static interface ShellConfigurator {
        /**
         * Configures the shell channel before it is connected.
         *
         * @param shell the shell to configure.
         */
        public void configureShell(ChannelShell shell);
    }

    /**
     * Shell configurator for setting a VT100 terminal.
     */
    public static class VT100ShellConfigurator implements ShellConfigurator {
        private final int columns;
        private final int rows;

        /**
         * Creates a configurator for a VT100 terminal of the given size.
         *
         * @param columns the terminal width in columns.
         * @param rows the terminal height in rows.
         */
        public VT100ShellConfigurator(int columns, int rows) {
            this.columns = columns;
            this.rows = rows;
        }

        @Override
        public void configureShell(ChannelShell shell) {
            shell.setPtyType("vt100", columns, rows, 0, 0);
        }
    }
}