/*
* Copyright 2014 Ricardo Lorenzo<unshakablespirit@gmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package utils.ssh;
import com.jcraft.jsch.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import utils.file.IOStreamUtils;
import utils.security.SSHKey;
import utils.security.SSHKeyStore;
import java.io.*;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
/**
* Created by ricardolorenzo on 25/07/2014.
*/
public class SSHClient {
// SLF4J logger shared by all instances of this class.
private static Logger log = LoggerFactory.getLogger(SSHClient.class);
// Key store queried for the user's key pair in connect() and forwardConnect().
private SSHKeyStore keyStore;
// JSch instance used to register identities and create sessions.
private JSch client;
// Host and port of the server reached by connect().
private String host;
private Integer port;
// Current session; assigned by connect() and re-assigned by forwardConnect().
private Session session;
// Sessions opened through a local port forward, keyed by the forwarded host name.
private Map<String, Session> forwardSessions;
// Channel and input stream of the command started by sendPipedCommand(String...).
private Channel pipedChannel;
// Channels of commands started by sendForwardPipedCommand(), keyed by host name.
private Map<String, Channel> forwardPipedChannels;
private InputStream pipedStream;
// Input streams of commands started by sendForwardPipedCommand(), keyed by host name.
private Map<String, InputStream> forwardPipedStreams;
// Bytes captured by the last sendCommand()/sendForwardCommand() call.
private byte[] output;
// NOTE(review): never read or written anywhere in this class.
private byte[] error;
/**
 * Creates a client for the given SSH endpoint. No connection is opened here;
 * call {@link #connect(String)} afterwards.
 *
 * @param host host name or address of the SSH server
 * @param port TCP port of the SSH server
 * @throws IOException if the key store cannot be created; a
 *         {@link ClassNotFoundException} raised by the key store is attached
 *         as the cause
 */
public SSHClient(String host, int port) throws IOException {
    // JSch.setLogger is a static call, so it is repeated for every new client.
    JSch.setLogger(new SSHLogger());
    forwardSessions = new HashMap<>();
    forwardPipedChannels = new HashMap<>();
    forwardPipedStreams = new HashMap<>();
    client = new JSch();
    try {
        keyStore = new SSHKeyStore();
    } catch(ClassNotFoundException e) {
        // Keep the original exception as the cause so its stack trace is not lost.
        throw new IOException(e.getMessage(), e);
    }
    this.host = host;
    this.port = port;
}
/**
 * Opens the SSH session to the host and port given at construction time,
 * authenticating with the key pair the key store returns for {@code user}.
 * Host key checking is disabled for the session.
 *
 * @param user key store identifier of the user; when it contains an
 *        {@code '@'}, only the part before it is used as the SSH login name
 * @throws SSHException if the host or user is missing, or JSch fails to connect
 */
public void connect(String user) throws SSHException {
    if(this.host == null || this.host.isEmpty()) {
        throw new SSHException("ssh host not defined");
    }
    if(user == null || user.isEmpty()) {
        throw new SSHException("ssh user not defined");
    }
    try {
        SSHKey key = keyStore.getKey(user);
        // The identity is registered under the full user identifier.
        client.addIdentity(user, key.getSSHPrivateKey().getBytes(), key.getSSHPublicKey(user).getBytes(), null);
        int at = user.indexOf("@");
        String login = at >= 0 ? user.substring(0, at) : user;
        Session opened = client.getSession(login, this.host, this.port);
        Properties config = new Properties();
        config.put("StrictHostKeyChecking", "no");
        opened.setConfig(config);
        opened.connect();
        // Only publish the session once it is fully connected.
        this.session = opened;
    } catch(JSchException e) {
        log.error("connection error: " + e.getMessage());
        throw new SSHException(e);
    }
}
/**
 * Opens a session to {@code host} tunnelled through the current session: a
 * local port forward to {@code host:port} is created on the current session
 * and a new session is connected to that local port. The new session is
 * stored under {@code host} for the {@code *Forward*} methods.
 *
 * @param host host to reach through the current session
 * @param user key store identifier of the user; when it contains an
 *        {@code '@'}, only the part before it is used as the SSH login name
 * @param port SSH port on {@code host}
 * @throws SSHException if host or user is missing, no session is connected
 *         yet, or JSch fails to set up the forward or the session
 */
public void forwardConnect(String host, String user, int port) throws SSHException {
    if(host == null || host.isEmpty()) {
        throw new SSHException("ssh host not defined");
    }
    if(user == null || user.isEmpty()) {
        throw new SSHException("ssh user not defined");
    }
    // Without a connected session there is nothing to tunnel through; fail
    // with the same message the other methods use instead of a NullPointerException.
    Session gateway = this.session;
    if(gateway == null || !gateway.isConnected()) {
        throw new SSHException("not connected, please connect first");
    }
    Session session;
    int assignedPort = -1;
    try {
        SSHKey key = keyStore.getKey(user);
        client.addIdentity(user, key.getSSHPrivateKey().getBytes(), key.getSSHPublicKey(user).getBytes(), null);
        if(user.contains("@")) {
            user = user.substring(0, user.indexOf("@"));
        }
        // Local port 0 lets JSch pick a free port; the chosen port is returned.
        assignedPort = gateway.setPortForwardingL(0, host, port);
        session = this.client.getSession(user, "127.0.0.1", assignedPort);
        Properties config = new java.util.Properties();
        config.put("StrictHostKeyChecking", "no");
        session.setConfig(config);
        session.connect();
        forwardSessions.put(host, session);
    } catch(JSchException e) {
        log.error("connection error: " + e.getMessage());
        if(assignedPort > 0) {
            // The tunnel is useless without its session; release the local port.
            try {
                gateway.delPortForwardingL(assignedPort);
            } catch(JSchException cleanupError) {
                log.warn("could not release local forwarded port " + assignedPort + ": "
                        + cleanupError.getMessage());
            }
        }
        throw new SSHException(e);
    }
    // NOTE(review): this replaces the current session with the forwarded one, so
    // the previous session is no longer reachable from disconnect(). Kept as in
    // the original; confirm whether it is intended.
    this.session = session;
}
/**
 * Disconnects every forwarded session and then the current session, and
 * forgets all per-host forward state (sessions, piped channels, piped streams).
 */
public void disconnect() {
    this.forwardPipedChannels.clear();
    this.forwardPipedStreams.clear();
    for(Session forwarded : this.forwardSessions.values()) {
        if(forwarded != null && forwarded.isConnected()) {
            forwarded.disconnect();
        }
    }
    // Drop the now-dead sessions so the map does not keep them alive.
    this.forwardSessions.clear();
    if(this.session != null && this.session.isConnected()) {
        this.session.disconnect();
    }
}
/**
 * Forgets the piped channel and stream registered for {@code host} and
 * disconnects its forwarded session if one is still connected.
 *
 * @param host host name used when the forwarded session was created
 */
public void forwardDisconnect(String host) {
    this.forwardPipedChannels.remove(host);
    this.forwardPipedStreams.remove(host);
    Session forwarded = this.forwardSessions.remove(host);
    if(forwarded == null || !forwarded.isConnected()) {
        return;
    }
    forwarded.disconnect();
}
/**
 * Reads one reply byte from the remote scp process.
 * <p>
 * Bytes {@code 1} and {@code 2} mark an error and are followed by a message
 * terminated by {@code '\n'}; the message is read and reported through the
 * thrown exception. Any other value, including {@code 0} (success) and
 * {@code -1} (end of stream), is returned to the caller unchanged.
 *
 * @param in stream connected to the remote scp process
 * @return the byte read, or {@code -1} at end of stream
 * @throws IOException if the remote side replied with an error, or reading fails
 */
private static int verifyResponse(InputStream in) throws IOException {
    int b = in.read();
    if(b != 1 && b != 2) {
        return b;
    }
    StringBuilder message = new StringBuilder();
    // Also stop at end of stream: otherwise a truncated reply would loop forever.
    for(int c = in.read(); c != -1 && c != '\n'; c = in.read()) {
        message.append((char) c);
    }
    String detail = message.toString().trim();
    if(detail.isEmpty()) {
        throw new IOException("transfer file connection returned an error reply");
    }
    throw new IOException("transfer file connection returned an error reply: " + detail);
}
/**
 * Returns the bytes captured by the most recent {@code sendCommand} or
 * {@code sendForwardCommand} call.
 *
 * @return the stored array itself (not a copy), or {@code null} if nothing
 *         has been stored yet
 */
public byte[] getOutput() {
    return output;
}
/**
 * Returns the captured command output decoded with the platform default charset.
 *
 * @return the output as text, or {@code null} if no output has been stored
 */
public String getStringOutput() {
    return this.output == null ? null : new String(this.output);
}
/**
 * Reads files from the server over the current session and loads their
 * content into memory.
 *
 * @param sourcePath remote path handed to the remote {@code scp -f} command
 * @return file content keyed by the file name reported by the server
 * @throws SSHException if the session is not connected or the transfer fails
 */
public Map<String, byte[]> getFiles(String sourcePath) throws SSHException {
    Session current = this.session;
    return getFiles(current, sourcePath);
}
/**
 * Reads files from a forwarded host and loads their content into memory.
 *
 * @param host host name used in {@code forwardConnect}
 * @param sourcePath remote path handed to the remote {@code scp -f} command
 * @return file content keyed by the file name reported by the server
 * @throws SSHException if no connected forwarded session exists for
 *         {@code host} or the transfer fails
 */
public Map<String, byte[]> getForwardFiles(String host, String sourcePath) throws SSHException {
    return getFiles(this.forwardSessions.get(host), sourcePath);
}
/**
 * Downloads files by running {@code scp -f sourcePath} on the remote side and
 * reading the records it sends back, keeping every file in memory.
 * <p>
 * NOTE(review): {@code sourcePath} is appended to the remote command without
 * any quoting; callers must not pass untrusted input.
 *
 * @param session connected session to run the transfer on
 * @param sourcePath remote path to fetch
 * @return file content keyed by the (trimmed) file name sent by the server
 * @throws SSHException if the session is null or not connected, or on any
 *         I/O or JSch failure during the transfer
 */
private static Map<String, byte[]> getFiles(Session session, String sourcePath) throws SSHException {
    if(session == null || !session.isConnected()) {
        throw new SSHException("not connected, please connect first");
    }
    Map<String, byte[]> files = new HashMap<>();
    try {
        Channel channel = null;
        OutputStream out = null;
        InputStream in = null;
        try {
            channel = session.openChannel("exec");
            ChannelExec.class.cast(channel).setCommand("scp -f " + sourcePath);
            out = channel.getOutputStream();
            in = channel.getInputStream();
            channel.connect();
            while(true) {
                // Send '\0' to ask for the next record.
                IOStreamUtils.write(new byte[] { 0 }, out);
                // Anything other than a 'C' (file) record ends the transfer.
                if(verifyResponse(in) != 'C') {
                    break;
                }
                // Skip the permissions field (ex. '0644 ').
                IOStreamUtils.read(in, 5);
                // Read the decimal file size, terminated by a space.
                long fileSize = 0L;
                while(true) {
                    byte[] digit = IOStreamUtils.read(in, 1);
                    if(digit[0] == ' ') {
                        break;
                    }
                    fileSize = fileSize * 10L + (digit[0] - '0');
                }
                // Read the file name, terminated by a line feed.
                byte[] nameBytes = IOStreamUtils.readUntilDataIsFound(in, new byte[] { (byte) 0x0a },
                        (1024 * 1024) * 1L);
                String fileName = new String(nameBytes).trim();
                // Send '\0' to acknowledge the record header.
                IOStreamUtils.write(new byte[] { 0 }, out);
                // Read exactly fileSize bytes of content, then the trailing status byte.
                ByteArrayOutputStream fileContent = new ByteArrayOutputStream();
                IOStreamUtils.write(in, fileContent, fileSize);
                verifyResponse(in);
                files.put(fileName, fileContent.toByteArray());
            }
            return files;
        } finally {
            // Nested so that a failing close() cannot skip the channel disconnect.
            try {
                if(in != null) {
                    in.close();
                }
            } finally {
                try {
                    if(out != null) {
                        out.close();
                    }
                } finally {
                    if(channel != null) {
                        channel.disconnect();
                    }
                }
            }
        }
    } catch(IOException e) {
        log.error("file receive error: " + e.getMessage());
        throw new SSHException(e);
    } catch(JSchException e) {
        log.error("file receive error: " + e.getMessage());
        throw new SSHException(e);
    }
}
/**
 * Runs a command on the current session, waits for it to finish and stores
 * the bytes it produced so they can be fetched with {@link #getOutput()}.
 *
 * @param command command tokens, joined with single spaces
 * @return the exit status reported by the channel
 * @throws SSHException if the session is not connected or the command fails
 */
public int sendCommand(String... command) throws SSHException {
    Map.Entry<Integer, byte[]> result = sendCommand(this.session, command);
    byte[] captured = result.getValue();
    this.output = captured;
    return result.getKey();
}
/**
 * Runs a command on the forwarded session registered for {@code host}, waits
 * for it to finish and stores the bytes it produced so they can be fetched
 * with {@link #getOutput()}.
 *
 * @param host host name used in {@code forwardConnect}
 * @param command command tokens, joined with single spaces
 * @return the exit status reported by the channel
 * @throws SSHException if no connected forwarded session exists for
 *         {@code host} or the command fails
 */
public int sendForwardCommand(String host, String... command) throws SSHException {
    Map.Entry<Integer, byte[]> result = sendCommand(this.forwardSessions.get(host), command);
    this.output = result.getValue();
    return result.getKey();
}
/**
 * Runs a command on an "exec" channel of the given session and collects
 * everything readable from the channel's input stream until the channel is
 * closed.
 * <p>
 * NOTE(review): the tokens are joined with spaces and sent without any
 * quoting; callers must not pass untrusted input.
 *
 * @param session connected session to run the command on
 * @param command command tokens, joined with single spaces
 * @return an entry whose key is the channel exit status and whose value is
 *         the collected bytes
 * @throws SSHException if the session is null or not connected, or on any
 *         I/O or JSch failure
 */
private static Map.Entry<Integer, byte[]> sendCommand(Session session, String... command) throws SSHException {
    if(session == null || !session.isConnected()) {
        throw new SSHException("not connected, please connect first");
    }
    // Remembers an interrupt received while polling so it can be restored on exit.
    boolean interrupted = false;
    try {
        Channel channel = null;
        InputStream in = null;
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try {
            StringBuilder sb = new StringBuilder();
            for(String tok : command) {
                if(sb.length() > 0) {
                    sb.append(" ");
                }
                sb.append(tok);
            }
            channel = session.openChannel("exec");
            ChannelExec.class.cast(channel).setCommand(sb.toString());
            channel.setOutputStream(null);
            in = channel.getInputStream();
            channel.connect();
            byte[] buffer = new byte[1024];
            while(true) {
                for(int length = in.read(buffer, 0, buffer.length); length > 0;
                        length = in.read(buffer, 0, buffer.length)) {
                    out.write(buffer, 0, length);
                }
                if(channel.isClosed()) {
                    // Drain anything that arrived between the last read and the close.
                    if(in.available() > 0) {
                        continue;
                    }
                    break;
                }
                try {
                    Thread.sleep(250);
                } catch(InterruptedException e) {
                    // Keep waiting for the command as before, but do not lose the
                    // interrupt: it is re-asserted once the method returns.
                    interrupted = true;
                }
            }
            return new AbstractMap.SimpleEntry<Integer, byte[]>(channel.getExitStatus(), out.toByteArray());
        } finally {
            // Nested so that a failing close() cannot skip the channel disconnect.
            try {
                if(in != null) {
                    in.close();
                }
            } finally {
                if(channel != null) {
                    channel.disconnect();
                }
            }
        }
    } catch(IOException e) {
        log.error("command error: " + e.getMessage());
        throw new SSHException(e);
    } catch(JSchException e) {
        log.error("connection error: " + e.getMessage());
        throw new SSHException(e);
    } finally {
        if(interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}
/**
 * Reads the next output line of the command started with
 * {@link #sendPipedCommand(String...)}.
 * <p>
 * A line is read while the channel is connected or while the stream still
 * reports available bytes, which is the same condition
 * {@code readForwardPipedCommandOutputLine} uses, so output that is already
 * buffered when the channel closes is not dropped.
 *
 * @return the next line, or {@code null} if no piped command was started or
 *         there is nothing left to read
 * @throws SSHException declared for symmetry with the forward variant
 * @throws IOException if reading from the stream fails
 */
public String readPipedCommandOutputLine() throws SSHException, IOException {
    if(this.pipedChannel == null || this.pipedStream == null) {
        return null;
    }
    if(this.pipedChannel.isConnected() || this.pipedStream.available() > 0) {
        return readLine(this.pipedStream);
    }
    return null;
}
/**
 * Reads the next output line of the piped command running on a forwarded host.
 *
 * @param host host name used in {@code sendForwardPipedCommand}
 * @return the next line, or {@code null} when the channel is no longer
 *         connected and the stream reports no available bytes
 * @throws SSHException if no piped channel or stream is registered for {@code host}
 * @throws IOException if reading from the stream fails
 */
public String readForwardPipedCommandOutputLine(String host) throws SSHException, IOException {
    Channel forwardChannel = this.forwardPipedChannels.get(host);
    if(forwardChannel == null) {
        throw new SSHException("channel not available for " + host + ", not connected?");
    }
    InputStream forwardStream = this.forwardPipedStreams.get(host);
    if(forwardStream == null) {
        throw new SSHException("stream not available for " + host + ", not connected?");
    }
    // available() is only consulted once the channel is no longer connected.
    boolean readable = forwardChannel.isConnected() || forwardStream.available() > 0;
    return readable ? readLine(forwardStream) : null;
}
/**
 * Reads bytes up to the next line feed or carriage return and returns them as
 * a string decoded with the platform default charset. The terminator is
 * consumed but not returned.
 * <p>
 * Both terminators end a line on their own, so a CR LF pair yields the line
 * followed by an empty string on the next call. A line is cut after 1 MiB to
 * avoid exhausting memory.
 *
 * @param is stream to read from
 * @return the line without its terminator, or {@code null} if the stream was
 *         already at its end
 * @throws IOException if reading fails
 */
public static String readLine(InputStream is) throws IOException {
    final int lineFeed = 0x0a;
    final int carriageReturn = 0x0d;
    final long maxBytes = 1024L * 1024L;
    ByteArrayOutputStream line = new ByteArrayOutputStream();
    for(int c = is.read(); c != -1; c = is.read()) {
        if(c == lineFeed || c == carriageReturn) {
            return new String(line.toByteArray());
        }
        line.write(c);
        if(line.size() >= maxBytes) {
            return new String(line.toByteArray());
        }
    }
    // End of stream: hand back a pending partial line, or signal that nothing is left.
    return line.size() == 0 ? null : new String(line.toByteArray());
}
/**
 * Returns the exit status of the command started with
 * {@link #sendPipedCommand(String...)}.
 *
 * @return the channel exit status once the channel is closed, {@code 0}
 *         while it is still open
 * @throws SSHException if no piped command has been started
 */
public int getPipedCommandStatus() throws SSHException {
    // The original condition was inverted: it threw when a channel existed and
    // dereferenced null when none did.
    if(this.pipedChannel == null) {
        throw new SSHException("channel not available, not connected?");
    }
    if(this.pipedChannel.isClosed()) {
        return this.pipedChannel.getExitStatus();
    }
    return 0;
}
/**
 * Returns the exit status of the piped command running on a forwarded host.
 *
 * @param host host name used in {@code sendForwardPipedCommand}
 * @return the channel exit status once the channel is closed, {@code 0}
 *         while it is still open
 * @throws SSHException if no piped channel is registered for {@code host}
 */
public int getForwardPipedCommandStatus(String host) throws SSHException {
    Channel forwardChannel = this.forwardPipedChannels.get(host);
    if(forwardChannel == null) {
        throw new SSHException("channel not available for " + host + ", not connected?");
    }
    return forwardChannel.isClosed() ? forwardChannel.getExitStatus() : 0;
}
/**
 * Starts a command on the current session without waiting for it to finish
 * and keeps its channel and input stream so the output can be read line by
 * line with {@link #readPipedCommandOutputLine()}.
 * <p>
 * NOTE(review): the input stream is requested after the channel has been
 * connected; the JSch examples request it before {@code connect()} — confirm
 * that early output cannot be missed.
 *
 * @param command command tokens, joined with single spaces
 * @throws SSHException if the session is not connected, the command cannot
 *         be started, or the channel's input stream cannot be obtained
 */
public void sendPipedCommand(String... command) throws SSHException {
    Channel started = sendPipedCommand(this.session, command);
    this.pipedChannel = started;
    try {
        this.pipedStream = started.getInputStream();
    } catch(IOException e) {
        started.disconnect();
        throw new SSHException(e);
    }
}
/**
 * Starts a command on the forwarded session registered for {@code host}
 * without waiting for it to finish, and registers its channel and input
 * stream under {@code host} for the forward piped-read methods.
 *
 * @param host host name used in {@code forwardConnect}
 * @param command command tokens, joined with single spaces
 * @throws SSHException if no connected forwarded session exists for
 *         {@code host}, the command cannot be started, or the channel's
 *         input stream cannot be obtained
 */
public void sendForwardPipedCommand(String host, String... command) throws SSHException {
    Channel started = sendPipedCommand(this.forwardSessions.get(host), command);
    this.forwardPipedChannels.put(host, started);
    try {
        InputStream commandOutput = started.getInputStream();
        this.forwardPipedStreams.put(host, commandOutput);
    } catch(IOException e) {
        if(started != null) {
            started.disconnect();
        }
        throw new SSHException(e);
    }
}
/**
 * Opens an "exec" channel on the session, starts the command and returns the
 * connected channel without waiting for the command to finish.
 *
 * @param session connected session to run the command on
 * @param command command tokens, joined with single spaces
 * @return the connected channel running the command
 * @throws SSHException if the session is null or not connected, or JSch
 *         fails to open or connect the channel
 */
private static Channel sendPipedCommand(Session session, String... command) throws SSHException {
    if(session == null || !session.isConnected()) {
        throw new SSHException("not connected, please connect first");
    }
    // A separator is added only once something has been written, exactly as
    // the token-joining in sendCommand does.
    StringBuilder commandLine = new StringBuilder();
    for(String token : command) {
        if(commandLine.length() > 0) {
            commandLine.append(" ");
        }
        commandLine.append(token);
    }
    try {
        Channel execChannel = session.openChannel("exec");
        ChannelExec.class.cast(execChannel).setCommand(commandLine.toString());
        execChannel.setOutputStream(null);
        execChannel.connect();
        return execChannel;
    } catch(JSchException e) {
        log.error("connection error: " + e.getMessage());
        throw new SSHException(e);
    }
}
/**
 * Stops the command started with {@code sendPipedCommand(String...)}: closes
 * its input stream, ignoring close failures, and disconnects its channel if
 * it is not closed yet.
 */
public void terminatePipedCommand() {
    InputStream commandOutput = this.pipedStream;
    if(commandOutput != null) {
        try {
            commandOutput.close();
        } catch(IOException ignored) {
            // Best effort: the channel is disconnected below regardless.
        }
    }
    terminatePipedCommand(this.pipedChannel);
}
/**
 * Stops the piped command registered for a forwarded host: removes and closes
 * its input stream, ignoring close failures, then removes its channel and
 * disconnects it if it is not closed yet.
 *
 * @param host host name used in {@code sendForwardPipedCommand}
 */
public void terminateForwardPipedCommand(String host) {
    InputStream commandOutput = this.forwardPipedStreams.remove(host);
    if(commandOutput != null) {
        try {
            commandOutput.close();
        } catch(IOException ignored) {
            // Best effort: the channel is disconnected below regardless.
        }
    }
    Channel forwardChannel = this.forwardPipedChannels.remove(host);
    terminatePipedCommand(forwardChannel);
}
/**
 * Disconnects the channel unless it is {@code null} or already closed.
 *
 * @param channel channel to disconnect; may be {@code null}
 */
private static void terminatePipedCommand(Channel channel) {
    if(channel == null || channel.isClosed()) {
        return;
    }
    channel.disconnect();
}
/**
 * Uploads a local file over the current session.
 *
 * @param file local file to send
 * @param destinationPath remote path handed to the remote {@code scp -t} command
 * @param permissions permission digits written into the scp file record
 * @throws SSHException if the session is not connected or the transfer fails
 */
public void sendFile(File file, String destinationPath, FilePermissions permissions) throws SSHException {
    Session current = this.session;
    sendFile(current, file, destinationPath, permissions);
}
/**
 * Uploads a local file over the forwarded session registered for {@code host}.
 *
 * @param host host name used in {@code forwardConnect}
 * @param file local file to send
 * @param destinationPath remote path handed to the remote {@code scp -t} command
 * @param permissions permission digits written into the scp file record
 * @throws SSHException if no connected forwarded session exists for
 *         {@code host} or the transfer fails
 */
public void sendForwardFile(String host, File file, String destinationPath, FilePermissions permissions)
        throws SSHException {
    sendFile(this.forwardSessions.get(host), file, destinationPath, permissions);
}
/**
 * Uploads a local file by running {@code scp -p -t destinationPath} on the
 * remote side and sending one file record followed by the file content.
 * <p>
 * The file name in the record is the part of {@code destinationPath} after
 * its last {@code '/'}; when that part is empty (the path ends with
 * {@code '/'}) the local file name is used instead. No timestamp record is
 * sent.
 * <p>
 * NOTE(review): {@code destinationPath} is appended to the remote command
 * without any quoting; callers must not pass untrusted input.
 *
 * @param session connected session to run the transfer on
 * @param file local file to send
 * @param destinationPath remote destination path
 * @param permissions user/group/all permission digits for the file record
 * @throws SSHException if the session is null or not connected, or on any
 *         I/O or JSch failure during the transfer
 */
private static void sendFile(Session session, File file, String destinationPath, FilePermissions permissions)
        throws SSHException {
    if(session == null || !session.isConnected()) {
        throw new SSHException("not connected, please connect first");
    }
    try {
        Channel channel = null;
        OutputStream out = null;
        InputStream in = null;
        try {
            // Remote execution of the 'scp -t destinationPath' command.
            channel = session.openChannel("exec");
            ChannelExec.class.cast(channel).setCommand("scp -p -t " + destinationPath);
            out = channel.getOutputStream();
            in = channel.getInputStream();
            channel.connect();
            verifyResponse(in);

            // substring(lastIndexOf + 1) yields the whole path when there is no '/'.
            String remoteName = destinationPath.substring(destinationPath.lastIndexOf("/") + 1);
            if(remoteName.isEmpty()) {
                // A record with an empty name is rejected by the remote side.
                remoteName = file.getName();
            }

            // File record: 'C0<user><group><all> <size> <name>\n'.
            StringBuilder sb = new StringBuilder();
            sb.append("C0");
            sb.append(permissions.getUserPermission());
            sb.append(permissions.getGroupPermission());
            sb.append(permissions.getAllPermission());
            sb.append(" ");
            sb.append(file.length());
            sb.append(" ");
            sb.append(remoteName);
            sb.append("\n");
            out.write(sb.toString().getBytes());
            out.flush();
            verifyResponse(in);

            // Send the file content.
            FileInputStream sourceStream = new FileInputStream(file);
            try {
                IOStreamUtils.write(sourceStream, out);
            } finally {
                sourceStream.close();
            }

            // Send the final '\0' and wait for the acknowledgement.
            IOStreamUtils.write(new byte[]{0}, out);
            verifyResponse(in);
        } finally {
            // Nested so that a failing close() cannot skip the channel disconnect.
            try {
                if(in != null) {
                    in.close();
                }
            } finally {
                try {
                    if(out != null) {
                        out.close();
                    }
                } finally {
                    if(channel != null) {
                        channel.disconnect();
                    }
                }
            }
        }
    } catch(IOException e) {
        log.error("file send error: " + e.getMessage());
        throw new SSHException(e);
    } catch(JSchException e) {
        log.error("file send error: " + e.getMessage());
        throw new SSHException(e);
    }
}
}
class SSHLogger implements com.jcraft.jsch.Logger {
private static Logger log = LoggerFactory.getLogger(SSHClient.class);
public boolean isEnabled(int level){
return true;
}
public void log(int level, String message){
switch(level) {
case INFO:
log.info(message);
break;
case WARN:
log.warn(message);
break;
case ERROR:
log.error(message);
break;
case FATAL:
log.error(message);
break;
default:
log.debug(message);
break;
}
}
}