package com.sohu.cache.ssh; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import ch.ethz.ssh2.Connection; import ch.ethz.ssh2.SCPClient; import ch.ethz.ssh2.Session; import ch.ethz.ssh2.StreamGobbler; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.sohu.cache.exception.SSHException; import com.sohu.cache.util.ConstUtils; /** * SSH操作模板类 */ public class SSHTemplate { private static final Logger logger = LoggerFactory.getLogger(SSHTemplate.class); private static final int CONNCET_TIMEOUT = 6000; private static final int OP_TIMEOUT = 12000; private static ThreadPoolExecutor taskPool = new ThreadPoolExecutor( 200, 200, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(1000), new ThreadFactoryBuilder().setNameFormat("SSH-%d").setDaemon(true).build()); public Result execute(String ip, SSHCallback callback) throws SSHException{ return execute(ip,ConstUtils.DEFAULT_SSH_PORT_DEFAULT, ConstUtils.USERNAME, ConstUtils.PASSWORD, callback); } /** * 通过回调执行命令 * @param ip * @param port * @param username * @param password * @param callback 可以使用Session执行多个命令 * @throws SSHException */ public Result execute(String ip, int port, String username, String password, SSHCallback callback) throws SSHException{ Connection conn = null; try { conn = getConnection(ip, port, username, password); return callback.call(new SSHSession(conn, ip+":"+port)); } catch (Exception e) { throw new SSHException("SSH err: " + e.getMessage(), e); } finally { close(conn); } } /** * 获取连接并校验 * @param ip * @param port * @param username * @param password * @return Connection * @throws Exception */ private Connection getConnection(String ip, int port, String username, String password) throws Exception { Connection conn = new Connection(ip, port); conn.connect(null, CONNCET_TIMEOUT, CONNCET_TIMEOUT); boolean isAuthenticated = conn.authenticateWithPassword(username, password); if (isAuthenticated == false) { throw new Exception("SSH authentication failed with [ userName: " + username + ", password: " + password + "]"); } return conn; } /** * 获取调用命令后的返回结果 * @param is 输入流 * @return 如果获取结果有异常或者无结果,那么返回null */ private String getResult(InputStream is) { final StringBuilder buffer = new StringBuilder(); LineProcessor lp = new DefaultLineProcessor() { public void process(String line, int lineNum) throws Exception { if(lineNum > 1) { buffer.append(System.lineSeparator()); } buffer.append(line); } }; processStream(is, lp); return buffer.length() > 0 ? buffer.toString() : null; } /** * 从流中获取内容 * @param is */ private void processStream(InputStream is, LineProcessor lineProcessor) { BufferedReader reader = null; try { reader = new BufferedReader(new InputStreamReader(new StreamGobbler(is))); String line = null; int lineNum = 1; while ((line = reader.readLine()) != null) { try { lineProcessor.process(line, lineNum); } catch (Exception e) { logger.error("err line:"+line, e); } lineNum++; } lineProcessor.finish(); } catch (IOException e) { logger.error(e.getMessage(), e); } finally { close(reader); } } private void close(BufferedReader read) { if (read != null) { try { read.close(); } catch (IOException e) { logger.error(e.getMessage(), e); } } } private void close(Connection conn) { if (conn != null) { try { conn.close(); } catch (Exception e) { logger.error(e.getMessage(), e); } } } private static void close(Session session) { if (session != null) { try { session.close(); } catch (Exception e) { logger.error(e.getMessage(), e); } } } /** * 可以调用多次executeCommand, 并返回结果 */ public class SSHSession{ private String address; private Connection conn; private SSHSession(Connection conn, String address) { this.conn = conn; this.address = address; } /** * 执行命令并返回结果,可以执行多次 * @param cmd * @return 执行成功Result为true,并携带返回信息,返回信息可能为null * 执行失败Result为false,并携带失败信息 * 执行异常Result为false,并携带异常 */ public Result executeCommand(String cmd) { return executeCommand(cmd, OP_TIMEOUT); } public Result executeCommand(String cmd, int timoutMillis) { return executeCommand(cmd, null, timoutMillis); } public Result executeCommand(String cmd, LineProcessor lineProcessor) { return executeCommand(cmd, lineProcessor, OP_TIMEOUT); } /** * 执行命令并返回结果,可以执行多次 * @param cmd * @param lineProcessor 回调处理行 * @return 如果lineProcessor不为null,那么永远返回Result.true */ public Result executeCommand(String cmd, LineProcessor lineProcessor, int timoutMillis) { Session session = null; try { session = conn.openSession(); return executeCommand(session, cmd, timoutMillis, lineProcessor); } catch (Exception e) { logger.error("execute ip:"+conn.getHostname()+" cmd:"+cmd, e); return new Result(e); } finally { close(session); } } public Result executeCommand(final Session session, final String cmd, final int timoutMillis, final LineProcessor lineProcessor) throws Exception{ Future<Result> future = taskPool.submit(new Callable<Result>() { public Result call() throws Exception { session.execCommand(cmd); //如果客户端需要进行行处理,则直接进行回调 if(lineProcessor != null) { processStream(session.getStdout(), lineProcessor); } else { //获取标准输出 String rst = getResult(session.getStdout()); if(rst != null) { return new Result(true, rst); } //返回为null代表可能有异常,需要检测标准错误输出,以便记录日志 Result errResult = tryLogError(session.getStderr(), cmd); if(errResult != null) { return errResult; } } return new Result(true, null); } }); Result rst = null; try { rst = future.get(timoutMillis, TimeUnit.MILLISECONDS); future.cancel(true); } catch (TimeoutException e) { logger.error("exec ip:{} {} timeout:{}", conn.getHostname(), cmd, timoutMillis); throw new SSHException(e); } return rst; } private Result tryLogError(InputStream is, String cmd) { String errInfo = getResult(is); if(errInfo != null) { logger.error("address "+address+" execute cmd:({}), err:{}", cmd, errInfo); return new Result(false, errInfo); } return null; } /** * Copy a set of local files to a remote directory, uses the specified mode when * creating the file on the remote side. * @param localFiles * Path and name of local file. * @param remoteFiles * name of remote file. * @param remoteTargetDirectory * Remote target directory. Use an empty string to specify the default directory. * @param mode * a four digit string (e.g., 0644, see "man chmod", "man open") * @throws IOException */ public Result scp(String[] localFiles, String[] remoteFiles, String remoteTargetDirectory, String mode) { try { SCPClient client = conn.createSCPClient(); client.put(localFiles, remoteFiles, remoteTargetDirectory, mode); return new Result(true); } catch (Exception e) { logger.error("scp local="+Arrays.toString(localFiles)+" to "+ remoteTargetDirectory+" remote="+Arrays.toString(remoteFiles)+" err", e); return new Result(e); } } public Result scpToDir(String localFile, String remoteTargetDirectory) { return scpToDir(localFile, remoteTargetDirectory, "0744"); } public Result scpToDir(String localFile, String remoteTargetDirectory, String mode) { return scp(new String[] { localFile }, null, remoteTargetDirectory, mode); } public Result scpToDir(String[] localFile, String remoteTargetDirectory) { return scp(localFile, null, remoteTargetDirectory, "0744"); } public Result scpToFile(String localFile, String remoteFile, String remoteTargetDirectory) { return scpToFile(localFile, remoteFile, remoteTargetDirectory, "0744"); } public Result scpToFile(String localFile, String remoteFile, String remoteTargetDirectory, String mode) { return scp(new String[] { localFile }, new String[] { remoteFile }, remoteTargetDirectory, "0744"); } } /** * 结果封装 */ public class Result{ private boolean success; private String result; private Exception excetion; public Result(boolean success) { this.success = success; } public Result(boolean success, String result) { this.success = success; this.result = result; } public Result(Exception excetion) { this.success = false; this.excetion = excetion; } public Exception getExcetion() { return excetion; } public void setExcetion(Exception excetion) { this.excetion = excetion; } public boolean isSuccess() { return success; } public void setSuccess(boolean success) { this.success = success; } public String getResult() { return result; } public void setResult(String result) { this.result = result; } @Override public String toString() { return "Result [success=" + success + ", result=" + result + ", excetion=" + excetion + "]"; } } /** * 执行命令回调 */ public interface SSHCallback{ /** * 执行回调 * @param session */ Result call(SSHSession session); } /** * 从流中直接解析数据 */ public static interface LineProcessor{ /** * 处理行 * @param line 内容 * @param lineNum 行号,从1开始 * @throws Exception */ void process(String line, int lineNum) throws Exception; /** * 所有的行处理完毕回调该方法 */ void finish(); } public static abstract class DefaultLineProcessor implements LineProcessor{ public void finish() {} } }