/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package se.kth.karamel.backend.machines;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.SequenceInputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Semaphore;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.PTYMode;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.connection.channel.direct.SessionChannel;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import net.schmizz.sshj.userauth.password.Resource;
import se.kth.karamel.common.exception.KaramelException;
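/**
 * Interactive SSH shell to a single remote machine. It authenticates with a key pair, opens a PTY-backed
 * shell, sends raw command text to it and accumulates the shell's output in memory until it is polled
 * with {@link #readStreams()}.
 *
 * @author kamal
 */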
public class SshShell {

  /*
   * Implementation notes:
   * - A background thread drains the remote shell's output into 'builder'; 'semaphor' (one permit) is the
   *   mutex guarding every access to 'builder', including swapping it for a fresh instance.
   * - connect()/disconnect()/exec() do not synchronize on the connection fields themselves.
   */
  private static final org.apache.log4j.Logger logger = org.apache.log4j.Logger.getLogger(SshShell.class);
  // Explicit charset so that byte<->char conversion does not depend on the local JVM's platform default.
  private static final java.nio.charset.Charset SHELL_CHARSET = java.nio.charset.StandardCharsets.UTF_8;
  // Number of chars drained from the shell per read, i.e. per semaphore round trip.
  private static final int READ_CHUNK_CHARS = 1024;

  private final String privateKey;
  private final String publicKey;
  private final String ipAddress;
  private final String sshUser;
  private final String passphrase;
  private final int sshPort;
  private SSHClient client = null;
  private SessionChannel shell = null;
  private Session session = null;
  private InputStreamReader stdAll;
  private SequenceInputStream sequenceInputStream;
  private final Semaphore semaphor = new Semaphore(1);
  private StringBuilder builder = new StringBuilder();
  private Thread streamReader;

  /**
   * Creates a shell descriptor for a key pair that has no passphrase. No connection is opened.
   *
   * @param privateKey private key, passed as-is to {@code SSHClient.loadKeys}
   * @param publicKey public key, passed as-is to {@code SSHClient.loadKeys}
   * @param ipAddress address of the remote machine
   * @param sshUser user to authenticate as
   * @param sshPort port of the remote SSH server
   */
  public SshShell(String privateKey, String publicKey, String ipAddress, String sshUser, int sshPort) {
    this(privateKey, publicKey, ipAddress, sshUser, null, sshPort);
  }

  /**
   * Creates a shell descriptor. No connection is opened until {@link #connect()} is called.
   *
   * @param privateKey private key, passed as-is to {@code SSHClient.loadKeys}
   * @param publicKey public key, passed as-is to {@code SSHClient.loadKeys}
   * @param ipAddress address of the remote machine
   * @param sshUser user to authenticate as
   * @param passphrase passphrase protecting the private key, or {@code null} if there is none
   * @param sshPort port of the remote SSH server
   */
  public SshShell(String privateKey, String publicKey, String ipAddress, String sshUser,
      String passphrase, int sshPort) {
    this.privateKey = privateKey;
    this.publicKey = publicKey;
    this.ipAddress = ipAddress;
    this.sshUser = sshUser;
    this.passphrase = passphrase;
    this.sshPort = sshPort;
  }

  /**
   * @return the address of the remote machine this shell targets
   */
  public String getIpAddress() {
    return ipAddress;
  }

  // Supplies the configured passphrase exactly once; only used when 'passphrase' is non-null.
  private PasswordFinder getPasswordFinder() {
    return new PasswordFinder() {
      @Override
      public char[] reqPassword(Resource<?> resource) {
        return passphrase.toCharArray();
      }

      @Override
      public boolean shouldRetry(Resource<?> resource) {
        return false;
      }
    };
  }

  /**
   * Opens (or re-opens) the SSH connection, allocates a PTY, starts an interactive shell, starts the
   * background thread that collects the shell's output, and finally sets the remote prompt (PS1).
   * If any step fails, everything opened so far is closed again before the exception is thrown.
   *
   * @throws KaramelException if public-key authentication is rejected or any other step fails
   */
  public void connect() throws KaramelException {
    try {
      if (isConnected()) {
        disconnect();
      } else {
        // The shell may have closed while the client is still connected (or an earlier connect() may
        // have stopped half-way); close those leftovers instead of overwriting and leaking them.
        closeQuietly();
      }
      client = new SSHClient();
      // Accepts any host key, i.e. the remote host's identity is not verified.
      client.addHostKeyVerifier(new PromiscuousVerifier());
      KeyProvider keys;
      if (passphrase == null) {
        keys = client.loadKeys(privateKey, publicKey, null);
      } else {
        keys = client.loadKeys(privateKey, publicKey, getPasswordFinder());
      }
      client.connect(ipAddress, sshPort);
      client.authPublickey(sshUser, keys);
      session = client.startSession();
      Map<PTYMode, Integer> modes = new HashMap<>();
      session.allocatePTY("vt220", 160, 80, 0, 0, modes);
      shell = (SessionChannel) session.startShell();
      // SequenceInputStream only switches to the error stream once the first stream hits end-of-stream.
      // NOTE(review): presumably fine because the PTY merges stderr into stdout on the remote side - confirm.
      sequenceInputStream = new SequenceInputStream(shell.getInputStream(), shell.getErrorStream());
      stdAll = new InputStreamReader(sequenceInputStream, SHELL_CHARSET);
      streamReader = startStreamReader(stdAll);
      exec("PS1=\"" + sshUser + "@" + ipAddress + ":~$\"\r");
    } catch (UserAuthException ex) {
      logger.error("", ex);
      closeQuietly();
      throw new KaramelException("Issue for using ssh keys, make sure you keypair is not password protected..", ex);
    } catch (Exception ex) {
      logger.error("", ex);
      closeQuietly();
      throw new KaramelException("Exception Occured", ex);
    }
  }

  /**
   * Starts the thread that drains {@code source} into {@code builder}.
   * The reader is passed in as a parameter rather than read from the {@code stdAll} field, so that a
   * thread left over from a previous connection can never start consuming a newer connection's stream.
   */
  private Thread startStreamReader(final InputStreamReader source) {
    Thread reader = new Thread("ssh-shell-reader-" + ipAddress) {
      @Override
      public void run() {
        char[] chunk = new char[READ_CHUNK_CHARS];
        int count;
        try {
          while ((count = source.read(chunk)) != -1) {
            semaphor.acquire();
            try {
              builder.append(chunk, 0, count);
            } finally {
              // Always hand the permit back, otherwise readStreams() would block forever.
              semaphor.release();
            }
          }
        } catch (InterruptedException ex) {
          Thread.currentThread().interrupt();
          logger.error("", ex);
        } catch (IOException ex) {
          logger.error("", ex);
        }
      }
    };
    reader.start();
    return reader;
  }

  /**
   * Writes {@code cmdStr} verbatim to the remote shell's input and flushes it. Nothing is appended, so
   * the caller must include the line terminator that makes the shell run the command.
   *
   * @param cmdStr raw text to send to the shell
   * @throws KaramelException if the shell has not been connected or the write fails
   */
  public void exec(String cmdStr) throws KaramelException {
    try {
      if (shell == null) {
        throw new IllegalStateException("ssh shell to " + ipAddress + " is not connected");
      }
      byte[] bytes = cmdStr.getBytes(SHELL_CHARSET);
      shell.getOutputStream().write(bytes);
      shell.getOutputStream().flush();
    } catch (Exception ex) {
      logger.error("", ex);
      throw new KaramelException("Could not send command to ssh shell on " + ipAddress, ex);
    }
  }

  /**
   * Returns everything the shell has printed since the previous call and resets the internal buffer.
   *
   * @return the accumulated shell output, possibly empty
   * @throws KaramelException if the calling thread is interrupted while waiting for the buffer
   */
  public String readStreams() throws KaramelException {
    try {
      semaphor.acquire();
      String s;
      try {
        s = builder.toString();
        builder = new StringBuilder();
      } finally {
        semaphor.release();
      }
      logger.info("shell output:\n" + s + "\n");
      return s;
    } catch (InterruptedException ex) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      logger.error("", ex);
      throw new KaramelException("Exception occured", ex);
    } catch (Exception ex) {
      logger.error("", ex);
      throw new KaramelException("Exception occured", ex);
    }
  }

  /**
   * @return {@code true} only if the client is connected and both the session and the shell channel are open
   * @throws KaramelException if querying the connection state fails unexpectedly
   */
  public boolean isConnected() throws KaramelException {
    try {
      return client != null && client.isConnected() && session != null && session.isOpen() && shell != null
          && shell.isOpen();
    } catch (Exception ex) {
      logger.error("", ex);
      throw new KaramelException("Exception occured", ex);
    }
  }

  /**
   * Stops the output-reader thread and closes the session and the client. Safe to call when nothing was
   * ever connected. The client is disconnected even if closing the session fails.
   *
   * @throws KaramelException wrapping the first I/O failure met while closing
   */
  public void disconnect() throws KaramelException {
    IOException failure = closeResources();
    if (failure != null) {
      throw new KaramelException("Exception occured", failure);
    }
  }

  // Interrupts the reader and closes session and client independently; returns the first I/O failure or null.
  private IOException closeResources() {
    if (streamReader != null && streamReader.isAlive()) {
      streamReader.interrupt();
    }
    IOException firstFailure = null;
    if (session != null) {
      try {
        session.close();
      } catch (IOException ex) {
        logger.error("", ex);
        firstFailure = ex;
      }
    }
    if (client != null) {
      try {
        client.disconnect();
      } catch (IOException ex) {
        logger.error("", ex);
        if (firstFailure == null) {
          firstFailure = ex;
        }
      }
    }
    return firstFailure;
  }

  // Best-effort cleanup used by connect(); must never mask the failure that triggered it.
  private void closeQuietly() {
    try {
      closeResources();
    } catch (RuntimeException ex) {
      logger.warn("", ex);
    }
  }
}