package com.redhat.qe.tools;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
import com.redhat.qe.auto.testng.LogMessageUtil;
import com.trilead.ssh2.ChannelCondition;
import com.trilead.ssh2.Connection;
import com.trilead.ssh2.Session;
import com.trilead.ssh2.StreamGobbler;
/**
 * Runs commands on a remote host over an SSH connection (trilead-ssh2) and gives access to
 * the command's stdout, stderr and exit code.
 * <p>
 * A runner is built either around an existing {@link Connection} or from credentials. In
 * the latter case the constructor opens and authenticates a new connection, and closes it
 * again if authentication fails. Commands are started with {@link #run()} /
 * {@link #runCommand(String)} and waited on with {@link #waitFor()} /
 * {@link #waitForWithTimeout(Long)}, or both at once with the {@code runCommandAndWait}
 * family.
 * <p>
 * Stream creation in {@link #run(LogRecord)} and the stream accessors synchronize on
 * {@link #lock}, so another thread cannot pick up the streams before they are set up.
 */
public class SSHCommandRunner implements Runnable {
    /** Connection the commands are executed over. */
    protected Connection connection;
    /** Login user; null when the runner was built around an existing Connection. */
    protected String user = null;
    /** Session of the most recently started command. */
    protected Session session;
    /** Buffered stdout of the current command (a StreamGobbler once {@link #run} has executed). */
    protected InputStream out;
    protected static Logger log = Logger.getLogger(SSHCommandRunner.class.getName());
    /** Buffered stderr of the current command (a StreamGobbler once {@link #run} has executed). */
    protected InputStream err;
    /** Cached stdout contents, filled on first full read (or by a live logger). */
    protected String s_out = null;
    /** Cached stderr contents, filled on first full read (or by a live logger). */
    protected String s_err = null;
    /**
     * Set by {@link #kill()} to make {@link #waitForWithTimeout(Long)} stop waiting.
     * Volatile because it is written by one thread and polled by another without a shared lock.
     */
    protected volatile boolean kill = false;
    /** Command line executed by the next {@link #run()}. */
    protected String command = null;
    /** Guards creation of, and access to, the session streams. */
    protected Object lock = new Object();

    /**
     * Creates a runner around an already connected and authenticated connection.
     * @param connection connection to run the command over
     * @param command command line to execute
     */
    public SSHCommandRunner(Connection connection,
            String command) {
        super();
        this.connection = connection;
        this.command = command;
    }

    /**
     * Opens a new connection to {@code server} and authenticates with a PEM private key.
     * The connection is closed again if connecting or authenticating fails.
     * @param server host to connect to
     * @param user login user
     * @param sshPemFile private key file
     * @param passphrase passphrase of the private key
     * @param command command line to execute
     * @throws IOException if connecting or reading the key fails
     * @throws RuntimeException if the server rejects the key
     */
    public SSHCommandRunner(String server,
            String user,
            File sshPemFile,
            String passphrase,
            String command) throws IOException{
        super();
        Connection newConn = new Connection(server);
        boolean ok = false;
        try {
            newConn.connect();
            if (!newConn.authenticateWithPublicKey(user, sshPemFile, passphrase)) {
                throw loginFailure(newConn, user);
            }
            ok = true;
        } finally {
            // Never leak a half-opened connection when the constructor fails.
            if (!ok) newConn.close();
        }
        this.connection = newConn;
        this.user = user;
        this.command = command;
    }

    /**
     * Opens a new connection to {@code server}, trying PEM private key authentication first
     * and falling back to password authentication. The fallback happens when the key is
     * rejected, or when key authentication throws (the connection is then re-opened).
     * The connection is closed again if every attempt fails.
     * @param server host to connect to
     * @param user login user
     * @param passphrase password used for the password-authentication fallback
     * @param sshPemFile private key file
     * @param pemPassphrase passphrase of the private key
     * @param command command line to execute
     * @throws IOException if connecting fails
     * @throws RuntimeException if neither method authenticates
     */
    public SSHCommandRunner(String server,
            String user,
            String passphrase,
            File sshPemFile,
            String pemPassphrase,
            String command) throws IOException{
        super();
        Connection newConn = new Connection(server);
        boolean ok = false;
        try {
            newConn.connect();
            boolean authenticated = false;
            try {
                authenticated = newConn.authenticateWithPublicKey(user, sshPemFile, pemPassphrase);
            }
            catch (IOException e) {
                // Key authentication errored out: discard this connection and
                // retry password authentication on a fresh one.
                log.log(Level.FINE, "Public key authentication failed, falling back to password", e);
                newConn.close();
                newConn = new Connection(server);
                newConn.connect();
            }
            // A rejected key (false) may be followed by another method on the same connection.
            if (!authenticated && !newConn.authenticateWithPassword(user, passphrase)) {
                throw loginFailure(newConn, user);
            }
            ok = true;
        } finally {
            if (!ok) newConn.close();
        }
        this.connection = newConn;
        this.user = user;
        this.command = command;
    }

    /**
     * Opens a new connection to {@code server} and authenticates with a password. If the
     * first attempt throws, it is retried once on a fresh connection. The connection is
     * closed again if authentication does not succeed.
     * @param server host to connect to
     * @param user login user
     * @param password login password
     * @param command command line to execute
     * @throws IOException if connecting fails
     * @throws RuntimeException if the server rejects the credentials
     */
    public SSHCommandRunner(String server,
            String user,
            String password,
            String command) throws IOException{
        super();
        Connection newConn = new Connection(server);
        boolean ok = false;
        try {
            newConn.connect();
            boolean authenticated;
            try {
                authenticated = newConn.authenticateWithPassword(user, password);
            }
            catch (IOException e) {
                log.log(Level.FINE, "Password authentication errored, retrying on a new connection", e);
                newConn.close();
                newConn = new Connection(server);
                newConn.connect();
                authenticated = newConn.authenticateWithPassword(user, password);
            }
            if (!authenticated) {
                throw loginFailure(newConn, user);
            }
            ok = true;
        } finally {
            if (!ok) newConn.close();
        }
        this.connection = newConn;
        this.user = user;
        this.command = command;
    }

    /**
     * Same as the File-based key constructor, taking the key path as a String.
     */
    public SSHCommandRunner(String server,
            String user,
            String sshPemFile,
            String passphrase,
            String command) throws IOException{
        this(server, user, new File(sshPemFile), passphrase, command);
    }

    /**
     * Same as the File-based key-with-password-fallback constructor, taking the key path as a String.
     */
    public SSHCommandRunner(String server,
            String user,
            String passphrase,
            String sshPemFile,
            String pemPassphrase,
            String command) throws IOException{
        this(server, user, passphrase, new File(sshPemFile), pemPassphrase, command);
    }

    /**
     * Builds the exception thrown when the server rejects the supplied credentials.
     */
    private static RuntimeException loginFailure(Connection conn, String user) {
        return new RuntimeException("Could not log in to " + conn.getHostname() + " with the given credentials ("+user+").");
    }

    /**
     * Starts {@link #command} in a new session and sets up buffered stdout/stderr streams.
     * Does not wait for the command to finish.
     * @param logRecord record used to log the "ssh host command" line; null means
     *        {@code LogMessageUtil.fine()}
     * @throws RuntimeException wrapping any failure to open the session or start the command
     */
    public void run(LogRecord logRecord) {
        try {
            if (logRecord == null) logRecord = LogMessageUtil.fine();
            // Prevents other threads from fetching the streams before they are set up here.
            synchronized (lock) {
                String target = (this.user != null)
                        ? user + "@" + connection.getHostname()
                        : connection.getHostname();
                logRecord.setMessage("ssh " + target + " " + command);
                log.log(logRecord);
                session = connection.openSession();
                session.execCommand(command);
                out = new StreamGobbler(session.getStdout());
                err = new StreamGobbler(session.getStderr());
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Starts {@link #command}, logging it with {@code LogMessageUtil.action()}.
     */
    public void run() {
        run(LogMessageUtil.action());
    }

    /**
     * Waits, without a time limit, for the running command to finish.
     * @return the exit code, or null if killed or no exit status was reported
     */
    public Integer waitFor(){
        return waitForWithTimeout(null);
    }

    /**
     * Waits for the running command to report both EXIT_STATUS and EOF, then closes the session.
     * The streams are intentionally not drained here, since another thread may be reading
     * the live output.
     * @param timeoutMS time out, in milliseconds; null waits indefinitely
     * @return null if the command was killed, timed out, or its channel closed without an
     *         exit status; the command's exit code otherwise
     */
    public Integer waitForWithTimeout(Long timeoutMS){
        final int cond = ChannelCondition.EXIT_STATUS | ChannelCondition.EOF;
        final long startTime = System.currentTimeMillis();
        int res = 0;
        boolean timedOut = false;
        while (!kill && ((res & cond) != cond)) {
            long pollMS = 1000;
            if (timeoutMS != null) {
                long remaining = timeoutMS.longValue() - (System.currentTimeMillis() - startTime);
                if (remaining < 0) {
                    timedOut = true;
                    break;
                }
                // Don't overshoot the deadline by a whole poll interval; 0 would mean "no timeout".
                pollMS = Math.max(1L, Math.min(pollMS, remaining));
            }
            res = session.waitForCondition(cond, pollMS);
            // Once the channel is closed no exit status can arrive any more (some servers never
            // send one), and waitForCondition would return immediately forever: stop polling.
            if ((res & ChannelCondition.CLOSED) != 0) {
                break;
            }
        }
        Integer exitCode = null;
        if (!(kill || timedOut))
            exitCode = getExitCode();
        session.close();
        kill = false;
        return exitCode;
    }

    /**
     * @return true once a session exists and it has reported an exit status
     */
    public boolean isDone(){
        return session != null && getExitCode() != null;
    }

    /**
     * Reads {@code is} line by line until end of stream and closes it. Each line is
     * terminated with "\n" in the result.
     * A read error is logged and the text read so far is returned.
     * @param is stream to consume
     * @return the stream contents
     */
    protected String convertStreamToString(InputStream is) {
        BufferedReader reader = new BufferedReader(new InputStreamReader(is));
        StringBuilder sb = new StringBuilder();
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line).append('\n');
            }
        } catch (IOException e) {
            log.log(Level.WARNING, "Error while reading command output stream", e);
        } finally {
            try {
                is.close();
            } catch (IOException e) {
                log.log(Level.FINER, "Couldn't close input stream", e);
            }
        }
        return sb.toString();
    }

    /**
     * @return exit code, stdout and stderr of the current command (consumes the streams if needed)
     */
    public SSHCommandResult getSSHCommandResult() {
        return new SSHCommandResult(getExitCode(),getStdout(),getStderr());
    }

    /**
     * @return the session's exit status, or null if none has been received yet
     */
    public Integer getExitCode() {
        return session.getExitStatus();
    }

    /**
     * Consumes entire stdout stream of the command, this will block until the stream is closed.
     * @return entire contents of stdout stream
     */
    public String getStdout() {
        synchronized (lock) {
            if (s_out == null) s_out = convertStreamToString(out);
            return s_out;
        }
    }

    /**
     * Consumes entire stderr stream of the command, this will block until the stream is closed.
     * @return entire contents of stderr stream
     */
    public String getStderr() {
        synchronized (lock) {
            if (s_err == null) s_err = convertStreamToString(err);
            return s_err;
        }
    }

    /**
     * Resets the runner (see {@link #reset()}) and sets the command for the next run.
     */
    public void setCommand(String command) {
        reset();
        this.command = command;
    }

    public String getCommand() {
        return command;
    }

    /**
     * Resets the runner and starts {@code command}, logging it at FINE level. Does not wait.
     */
    public void runCommand(String command){
        runCommand(command,LogMessageUtil.fine());
    }

    /**
     * Resets the runner and starts {@code command}, logging it with {@code logRecord}. Does not wait.
     */
    public void runCommand(String command, LogRecord logRecord){
        reset();
        this.command = command;
        run(logRecord);
    }

    public SSHCommandResult runCommandAndWait(String command){
        return runCommandAndWait(command,null,LogMessageUtil.fine(), false, true);
    }

    public SSHCommandResult runCommandAndWait(String command, boolean liveLogOutput){
        return runCommandAndWait(command,null,LogMessageUtil.fine(), liveLogOutput, true);
    }

    public SSHCommandResult runCommandAndWait(String command, Long timeoutMS){
        return runCommandAndWait(command,timeoutMS,LogMessageUtil.fine(), false, true);
    }

    public SSHCommandResult runCommandAndWait(String command, LogRecord logRecord){
        return runCommandAndWait(command,null,logRecord, false, true);
    }

    public SSHCommandResult runCommandAndWaitWithoutLogging(String command){
        return runCommandAndWait(command,null,LogMessageUtil.fine(), false, false);
    }

    /**
     * @param command - the remote command to run
     * @param timeoutMS - abort if command doesn't complete in this many milliseconds
     * (null means wait for command to complete, no matter how long it takes)
     * @param logRecord - a log record whose Level and Parameters will be used to do all
     * the command output logging. eg, a logRecord whose Level is INFO means log all the
     * output at INFO level. null means {@code LogMessageUtil.fine()}.
     * @param liveLogOutput - if true, log output as the command runs. Good for long running
     * commands, or commands that could potentially hang or timeout. If false, don't log
     * any output until the command has finished running.
     * @param logOutput - if false, the stdout, stderr, and exitCode will not be logged at all
     * @return the command's exit code (null if it was killed or timed out), stdout and stderr
     */
    public SSHCommandResult runCommandAndWait(String command, Long timeoutMS, LogRecord logRecord, boolean liveLogOutput, boolean logOutput){
        // Same default as run(LogRecord), so a null record cannot NPE in the logging below.
        if (logRecord == null) logRecord = LogMessageUtil.fine();
        runCommand(command,logRecord);
        final boolean live = liveLogOutput && logOutput;
        SplitStreamLogger logger = null;
        if (live){
            logger = new SplitStreamLogger(this);
            logger.log(logRecord.getLevel(), logRecord.getLevel());
        }
        waitForWithTimeout(timeoutMS);
        if (live) {
            // The live logger has consumed the streams; reuse what it collected.
            s_out = logger.getStdout();
            s_err = logger.getStderr();
        }
        SSHCommandResult sshCommandResult = getSSHCommandResult();
        if (logOutput){
            Level level = logRecord.getLevel();
            if (!liveLogOutput){
                // Start multi-line output on its own line.
                String o = (this.getStdout().split("\n").length>1)? "\n":"";
                String e = (this.getStderr().split("\n").length>1)? "\n":"";
                log.log(level, "Stdout: "+o+sshCommandResult.getStdout());
                log.log(level, "Stderr: "+e+sshCommandResult.getStderr());
            }
            log.log(level, "ExitCode: "+sshCommandResult.getExitCode());
        }
        return sshCommandResult;
    }

    /**
     * Stop waiting for the command to complete.
     */
    public synchronized void kill(){
        kill= true;
    }

    public InputStream getStdoutStream() {
        synchronized (lock) {
            return out;
        }
    }

    public InputStream getStdErrStream() {
        synchronized (lock) {
            return err;
        }
    }

    /**
     * Closes the current streams and session and clears cached output and the command.
     * Each resource is closed independently, so one failing close does not skip the others.
     */
    public void reset(){
        closeStream(out);
        closeStream(err);
        if (session != null) session.close();
        s_out = null;
        s_err = null;
        command = null;
    }

    /**
     * Closes {@code stream} if non-null, logging (not propagating) any IOException.
     */
    private void closeStream(InputStream stream) {
        if (stream == null) return;
        try {
            stream.close();
        }
        catch(IOException ioe) {
            log.log(Level.FINER, "Couldn't close input stream", ioe);
        }
    }

    public Connection getConnection() {
        return connection;
    }

    /**
     * Runs a command via SSH as specified user and logs its output through a
     * SplitStreamLogger. Returns a String[] containing stdout in 0 position
     * and stderr in 1 position. Waits without a time limit.
     * @param hostname hostname of system
     * @param user user to execute command as
     * @param command command to execute
     * @return output as String[], stdout in 0 pos and stderr in 1 pos; {"fail","fail"} on failure
     */
    public static String[] executeViaSSHWithReturn(String hostname,
            String user, String command){
        return executeViaSSHWithReturnWithTimeout(hostname,
                user,
                command,
                null);
    }

    /**
     * Runs a command via SSH as specified user and logs its output through a
     * SplitStreamLogger. Returns a String[] containing stdout in 0 position and stderr
     * in 1 position. Authenticates with the key {@code $user.dir/.ssh/id_auto_dsa} and the
     * passphrase from system property {@code jon.server.sshkey.passphrase}. The connection
     * opened for the call is always closed before returning.
     * @param hostname hostname of system
     * @param user user to execute command as
     * @param command command to execute
     * @param timeoutMS amount of time to wait for command completion, in milliseconds
     *        (null waits indefinitely)
     * @return output as String[], stdout in 0 pos and stderr in 1 pos; {"fail","fail"} if the
     *         command failed, or did not finish with an exit status in time
     */
    public static String[] executeViaSSHWithReturnWithTimeout(String hostname,
            String user, String command, Long timeoutMS){
        SSHCommandRunner runner = null;
        try{
            runner = new SSHCommandRunner(hostname,
                    user,
                    new File(System.getProperty("user.dir")+"/.ssh/id_auto_dsa"),
                    System.getProperty("jon.server.sshkey.passphrase"),command);
            runner.run();
            SplitStreamLogger logger = new SplitStreamLogger(runner);
            logger.log();
            Integer exitcode = runner.waitForWithTimeout(timeoutMS);
            if (exitcode == null){
                log.log(Level.INFO, "SSH command did not complete within timeout window");
                return failSSH();
            }
            // Collect output before the connection is closed in the finally block.
            return new String[] {logger.getStdout(), logger.getStderr()};
        }
        catch (Exception e){
            log.log(Level.INFO, "SSH command failed:", e);
            return failSSH();
        }
        finally {
            // This connection was opened only for this call; don't leak it.
            if (runner != null) runner.getConnection().close();
        }
    }

    private static String[] failSSH(){
        return new String[] {"fail", "fail"};
    }

    /**
     * Manual smoke test: runs a short command with live output logging on a remote host
     * and prints its stdout. Credentials come from the command line rather than being
     * hard-coded in source.
     * @param args host, user, password
     */
    public static void main(String[] args) throws Exception{
        if (args.length < 3) {
            System.err.println("Usage: SSHCommandRunner <host> <user> <password>");
            return;
        }
        SSHCommandRunner scr = new SSHCommandRunner(args[0], args[1], args[2], (String) null);
        try {
            scr.runCommandAndWait("sleep 5;echo 'hi there';sleep 3", true);
            System.out.println("Result: " + scr.getStdout());
        } finally {
            scr.getConnection().close();
        }
    }
}