package org.zstack.utils.ssh; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.direct.Session; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.verification.HostKeyVerifier; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.time.StopWatch; import org.zstack.utils.CollectionUtils; import org.zstack.utils.DebugUtils; import org.zstack.utils.Utils; import org.zstack.utils.function.Function; import org.zstack.utils.logging.CLogger; import org.zstack.utils.path.PathUtil; import java.io.File; import java.io.IOException; import java.security.PublicKey; import java.util.*; import java.util.concurrent.TimeUnit; import static org.zstack.utils.CollectionDSL.e; import static org.zstack.utils.CollectionDSL.map; import static org.zstack.utils.StringDSL.ln; import static org.zstack.utils.StringDSL.s; /** */ public class Ssh { private static final CLogger logger = Utils.getLogger(Ssh.class); private String hostname; private String username; private String privateKey; private String password; private int port = 22; private int timeout = Integer.MAX_VALUE; private List<SshRunner> commands = new ArrayList<SshRunner>(); private SSHClient ssh; private File privateKeyFile; private boolean closed = false; private boolean suppressException = false; private ScriptRunner script; private boolean init = false; private interface SshRunner { SshResult run(); String getCommand(); } private class ScriptRunner { String scriptName; File scriptFile; SshRunner scriptCommand; String scriptContent; ScriptRunner(String scriptName, String parameters, Map token) { this.scriptName = scriptName; String scriptPath = PathUtil.findFileOnClassPath(scriptName, true).getAbsolutePath(); try { if (parameters == null) { parameters = ""; } if (token == null) { token = new HashMap(); } String contents = FileUtils.readFileToString(new File(scriptPath)); String srcScript = String.format("zstack-script-%s", UUID.randomUUID().toString()); scriptFile = new File(PathUtil.join(PathUtil.getFolderUnderZStackHomeFolder("temp-scripts"), srcScript)); scriptContent = s(contents).formatByMap(token); String remoteScript = ln( "/bin/bash << EOF", "cat << EOF1 > {remotePath}", "{scriptContent}", "EOF1", "/bin/bash {remotePath} {parameters} 1>{stdout} 2>{stderr}", "ret=$?", "test -f {stdout} && cat {stdout}", "test -f {stderr} && cat {stderr} 1>&2", "rm -f {remotePath}", "rm -f {stdout}", "rm -f {stderr}", "exit $ret", "EOF" ).formatByMap(map(e("remotePath", String.format("/tmp/%s", UUID.randomUUID().toString())), e("scriptContent", scriptContent), e("parameters", parameters), e("stdout", String.format("/tmp/%s", UUID.randomUUID().toString())), e("stderr", String.format("/tmp/%s", UUID.randomUUID().toString())) )); scriptCommand = createCommand(remoteScript); } catch (IOException e) { throw new RuntimeException(e); } } ScriptRunner(String script) { String remoteScript = ln( "/bin/bash << EOF", "cat << EOF1 > {remotePath}", "{scriptContent}", "EOF1", "/bin/bash {remotePath} 1>{stdout} 2>{stderr}", "ret=$?", "test -f {stdout} && cat {stdout}", "test -f {stderr} && cat {stderr} 1>&2", "rm -f {remotePath}", "rm -f {stdout}", "rm -f {stderr}", "exit $ret", "EOF" ).formatByMap(map(e("remotePath", String.format("/tmp/%s", UUID.randomUUID().toString())), e("scriptContent", script), e("stdout", String.format("/tmp/%s", UUID.randomUUID().toString())), e("stderr", String.format("/tmp/%s", UUID.randomUUID().toString())) )); scriptCommand = createCommand(remoteScript); } SshResult run() { return scriptCommand.run(); } void cleanup() { if (scriptFile != null) { scriptFile.delete(); } } } public String getHostname() { return hostname; } public Ssh setHostname(String hostname) { this.hostname = hostname; return this; } public String getUsername() { return username; } public Ssh setUsername(String username) { this.username = username; return this; } public String getPrivateKey() { return privateKey; } public Ssh setPrivateKey(String privateKey) { this.privateKey = privateKey; return this; } public String getPassword() { return password; } public Ssh setPassword(String password) { this.password = password; return this; } public int getPort() { return port; } public Ssh setPort(int port) { this.port = port; return this; } public boolean isSuppressException() { return suppressException; } public Ssh setSuppressException(boolean suppressException) { this.suppressException = suppressException; return this; } public Ssh command(String...cmds) { for (String cmd : cmds) { commands.add(createCommand(cmd)); } return this; } private SshRunner createCommand(final String cmd) { return new SshRunner() { @Override public SshResult run() { SshResult ret = new SshResult(); ret.setCommandToExecute(cmd); Session.Command sshCmd = null; try { Session session = null; try { session = ssh.startSession(); if (logger.isTraceEnabled()) { logger.trace(String.format("[start SSH] %s", cmd)); } sshCmd = session.exec(cmd); sshCmd.join(timeout, TimeUnit.SECONDS); String output = IOUtils.readFully(sshCmd.getInputStream()).toString(); String stderr = IOUtils.readFully(sshCmd.getErrorStream()).toString(); ret.setReturnCode(sshCmd.getExitStatus()); ret.setStderr(stderr); ret.setStdout(output); if (logger.isTraceEnabled()) { logger.trace(String.format("[end SSH] %s", cmd)); } } finally { if (session != null) { session.close(); } } } catch (Exception e) { if (e instanceof ConnectionException || e instanceof IOException || e instanceof TransportException) { ret.setSshFailure(true); } StringBuilder sb = new StringBuilder(String.format("exec ssh command: %s, exception\n", cmd)); sb.append(String.format("[host:%s, port:%s, user:%s, timeout:%s]\n", hostname, port, username, timeout)); if (!suppressException) { logger.warn(sb.toString(), e); } ret.setExitErrorMessage(e.getMessage()); ret.setReturnCode(1); } finally { if (sshCmd != null) { try { sshCmd.close(); } catch (Exception e) { logger.warn(String.format("failed close ssh channel for command[%s, host:%s, port:%s]", cmd, hostname, port), e); } } } return ret; } @Override public String getCommand() { return cmd; } }; } public Ssh scp(final String src, final String dst) { commands.add(createScpCommand(src, dst)); return this; } private SshRunner createScpCommand(final String src, final String dst) { return new SshRunner() { @Override public SshResult run() { SshResult ret = new SshResult(); String cmd = getCommand(); ret.setCommandToExecute(cmd); try { ssh.newSCPFileTransfer().upload(src, dst); if (logger.isTraceEnabled()) { logger.trace(String.format("[SCP done]: %s", cmd)); } ret.setReturnCode(0); } catch (IOException e) { if (!suppressException) { logger.warn(String.format("[SCP failed]: %s", cmd), e); } ret.setSshFailure(true); ret.setReturnCode(1); ret.setExitErrorMessage(e.getMessage()); } return ret; } @Override public String getCommand() { return String.format("scp -P %d %s %s@%s:%s", port, src, username, hostname, dst); } }; } public Ssh checkTool(String...toolNames) { String tool = StringUtils.join(Arrays.asList(toolNames), " "); String cmdstr = s("EXIT (){ echo \"$1\"; exit 1;}; cmds=\"{0}\"; for cmd in $cmds; do which $cmd >/dev/null 2>&1 || EXIT \"Not find command: $cmd\"; done").format(tool); return command(cmdstr); } public Ssh shell(String script) { DebugUtils.Assert(this.script==null, String.format("every Ssh object can only specify one script")); this.script = new ScriptRunner(script); return this; } public Ssh script(String scriptName, String parameters, Map token) { DebugUtils.Assert(script==null, String.format("every Ssh object can only specify one script")); script = new ScriptRunner(scriptName, parameters, token); return this; } public Ssh script(String scriptName, Map tokens) { return script(scriptName, null, tokens); } public Ssh script(String scriptName, String parameters) { return script(scriptName, parameters, null); } public Ssh script(String scriptName) { return script(scriptName, null, null); } private void build() throws IOException { if (init) { return; } ssh = new SSHClient(); ssh.addHostKeyVerifier(new HostKeyVerifier() { @Override public boolean verify(String arg0, int arg1, PublicKey arg2) { return true; } }); ssh.connect(hostname, port); if (privateKey != null) { privateKeyFile = File.createTempFile("zstack", "tmp"); FileUtils.writeStringToFile(privateKeyFile, privateKey); ssh.authPublickey(username, privateKeyFile.getAbsolutePath()); } else { ssh.authPassword(username, password); } init = true; } public void close() { if (closed) { return; } closed = true; try { ssh.disconnect(); if (privateKeyFile != null) { privateKeyFile.delete(); } if (script != null) { script.cleanup(); } } catch (IOException e) { StringBuilder sb = new StringBuilder(String.format("failed to close connection")); sb.append(String.format("[host:%s, port:%s, user:%s, timeout:%s]\n", hostname, port, username, timeout)); logger.warn(sb.toString(), e); } } public SshResult run() { if (closed) { throw new SshException("this Ssh instance has been closed, you can not call run() after close()"); } StopWatch watch = new StopWatch(); watch.start(); try { build(); if (commands.isEmpty() && script == null) { throw new IllegalArgumentException(String.format("no command or scp command or script specified")); } if (!commands.isEmpty() && script != null) { throw new IllegalArgumentException(String.format("you cannot use script with command or scp")); } if (privateKey == null && password == null) { throw new IllegalArgumentException(String.format("no password and private key specified")); } if (username == null) { throw new IllegalArgumentException(String.format("no username specified")); } if (hostname == null) { throw new IllegalArgumentException(String.format("no hostname specified")); } if (script != null) { if (logger.isTraceEnabled()) { logger.trace(String.format("run script remotely[ip: %s, port: %s]:\n%s\n", hostname, port, script.scriptContent)); } return script.run(); } else { SshResult ret = null; for (SshRunner runner : commands) { ret = runner.run(); if (ret.getReturnCode() != 0) { return ret; } } return ret; } } catch (IOException e) { StringBuilder sb = new StringBuilder(String.format("ssh exception\n")); sb.append(String.format("[host:%s, port:%s, user:%s, timeout:%s]\n", hostname, port, username, timeout)); if (!suppressException) { logger.warn(sb.toString(), e); } SshResult ret = new SshResult(); ret.setSshFailure(true); ret.setExitErrorMessage(e.getMessage()); ret.setReturnCode(1); return ret; } finally { watch.stop(); if (logger.isTraceEnabled()) { if (script != null) { logger.trace(String.format("execute script[%s], cost time:%s", script.scriptName, watch.getTime())); } else { String cmd = StringUtils.join(CollectionUtils.transformToList(commands, new Function<String, SshRunner>() { @Override public String call(SshRunner arg) { return arg.getCommand(); } }), ","); String info = s( "\nssh execution[host: {0}, port:{1}]\n", "command: {2}\n", "cost time: {3}ms\n" ).format(hostname, port, cmd, watch.getTime()); logger.trace(info); } } } } public int getTimeout() { return timeout; } public Ssh setTimeout(int timeout) { this.timeout = timeout; return this; } public Ssh reset() { commands = new ArrayList<SshRunner>(); return this; } public SshResult runAndClose() { SshResult ret = run(); close(); return ret; } public void runErrorByExceptionAndClose() { SshResult ret = run(); close(); ret.raiseExceptionIfFailed(); } public void runErrorByException() { SshResult ret = run(); try { ret.raiseExceptionIfFailed(); } catch (SshException e) { close(); throw e; } } }