/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package se.kth.karamel.backend.machines;
import java.io.File;
import java.io.IOException;
import java.io.SequenceInputStream;
import java.nio.file.Files;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import org.apache.log4j.Logger;
import se.kth.karamel.backend.running.model.MachineRuntime;
import se.kth.karamel.backend.running.model.tasks.ShellCommand;
import se.kth.karamel.backend.running.model.tasks.Task;
import se.kth.karamel.backend.running.model.tasks.Task.Status;
import se.kth.karamel.common.util.Settings;
import se.kth.karamel.common.exception.KaramelException;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import net.schmizz.sshj.userauth.password.Resource;
import net.schmizz.sshj.xfer.scp.SCPFileTransfer;
import se.kth.karamel.backend.ClusterService;
import se.kth.karamel.backend.LogService;
import se.kth.karamel.backend.running.model.ClusterRuntime;
import se.kth.karamel.backend.running.model.Failure;
import se.kth.karamel.backend.running.model.tasks.KillSessionTask;
import se.kth.karamel.backend.running.model.tasks.RunRecipeTask;
import se.kth.karamel.common.util.Confs;
import se.kth.karamel.common.util.IoUtils;
/**
 * Runs the tasks queued for a single remote machine over an SSH connection.
 * <p>
 * {@link #run()} loops on its own thread until {@link #setStopping(boolean)} is called, taking tasks from a bounded
 * queue and executing their shell commands on the remote host. Because that loop occupies its thread, the control
 * methods (stop, kill, remove, retry, skip) are necessarily called from other threads; the flags they share with the
 * loop are therefore declared {@code volatile}.
 *
 * @author kamal
 */
public class SshMachine implements MachineInterface, Runnable {

  static {
    // Registers the BouncyCastle JCE provider once, when this class is loaded.
    Security.addProvider(new BouncyCastleProvider());
  }

  private static final Logger logger = Logger.getLogger(SshMachine.class);
  // Runtime model of the machine: id, ip, port, ssh user, life/tasks status.
  private final MachineRuntime machineEntity;
  // Key material handed to the ssh client when loading keys.
  private final String serverPubKey;
  private final String serverPrivateKey;
  // Passphrase of the private key; null or empty means the key is not protected.
  private final String passphrase;
  // Created (or re-created) by connect().
  private SSHClient client;
  // Epoch millis of the last heartbeat, see updateHeartbeat()/ping().
  private long lastHeartbeat = 0;
  private final BlockingQueue<Task> taskQueue = new ArrayBlockingQueue<>(Settings.MACHINES_TASKQUEUE_SIZE);
  // Set from outside the worker thread to end the run() loop; volatile so the loop is guaranteed to see the write.
  private volatile boolean stopping = false;
  // Raised by killTaskSession() and cleared by runSshCmd(); volatile for the same cross-thread reason.
  private volatile boolean killing = false;
  private final SshShell shell;
  // Task currently being executed/retried by run(); also reassigned by remove/retry/skip from other threads.
  private volatile Task activeTask;
  // False until the succeeded-task list has been reloaded after a (re)connect.
  private boolean isSucceedTaskHistoryUpdated = false;
  // In-memory copy of the ids of tasks that already succeeded on this machine.
  private final List<String> succeedTasksHistory = new ArrayList<>();
  private static Confs confs = Confs.loadKaramelConfs();
/**
 * Creates the machine handler and its companion {@link SshShell}; no connection is opened here.
 *
 * @param machineEntity runtime model of the target machine (supplies public ip, ssh user and ssh port)
 * @param serverPubKey public key used when loading the ssh key pair
 * @param serverPrivateKey private key used when loading the ssh key pair
 * @param passphrase passphrase of the private key; may be null or empty when the key is not protected
 */
public SshMachine(MachineRuntime machineEntity, String serverPubKey, String serverPrivateKey, String passphrase) {
  this.passphrase = passphrase;
  this.serverPrivateKey = serverPrivateKey;
  this.serverPubKey = serverPubKey;
  this.machineEntity = machineEntity;
  // The interactive shell gets the same credentials and endpoint as this machine.
  this.shell = new SshShell(
      this.serverPrivateKey,
      this.serverPubKey,
      this.machineEntity.getPublicIp(),
      this.machineEntity.getSshUser(),
      this.passphrase,
      this.machineEntity.getSshPort());
}
/**
 * @return the runtime model of the machine this instance works on
 */
public MachineRuntime getMachineEntity() {
  return this.machineEntity;
}
/**
 * @return the {@link SshShell} created together with this machine
 */
public SshShell getShell() {
  return this.shell;
}
/**
 * Sets the flag that {@link #run()} and the command retry loop check before continuing.
 *
 * @param stopping true to ask the worker loop to end
 */
public void setStopping(boolean stopping) {
  this.stopping = stopping;
}
/**
 * Moves the machine to PAUSING, but only when a failure is present and the current tasks status is ordered before
 * PAUSING in the {@code TasksStatus} enum.
 */
public void pause() {
  if (!anyFailure()) {
    return;
  }
  int currentOrdinal = machineEntity.getTasksStatus().ordinal();
  int pausingOrdinal = MachineRuntime.TasksStatus.PAUSING.ordinal();
  if (currentOrdinal < pausingOrdinal) {
    machineEntity.setTasksStatus(MachineRuntime.TasksStatus.PAUSING, null, null);
  }
}
/**
 * Puts the machine back into a runnable tasks status unless a failed task is still present: EMPTY when nothing is
 * queued, ONGOING otherwise.
 */
public void resume() {
  if (anyFailure()) {
    return;
  }
  MachineRuntime.TasksStatus next = taskQueue.isEmpty()
      ? MachineRuntime.TasksStatus.EMPTY
      : MachineRuntime.TasksStatus.ONGOING;
  machineEntity.setTasksStatus(next, null, null);
}
/**
 * @return true when the machine's tasks status is FAILED and at least one of its tasks is in status FAILED
 */
private boolean anyFailure() {
  if (machineEntity.getTasksStatus() != MachineRuntime.TasksStatus.FAILED) {
    return false;
  }
  boolean failedTaskSeen = false;
  // Every task is inspected, no early exit.
  for (Task candidate : machineEntity.getTasks()) {
    failedTaskSeen |= (candidate.getStatus() == Task.Status.FAILED);
  }
  return failedTaskSeen;
}
/**
* Worker loop of this machine. Until the stopping flag is set it either executes the active/next queued task (when
* the machine is CONNECTED and its tasks status is ONGOING or EMPTY) or sleeps for a polling interval. The ssh
* connection is closed when the loop ends, however it ends.
*/
@Override
public void run() {
logger.debug(String.format("%s: Started SSH_Machine d'-'", machineEntity.getId()));
try {
while (!stopping) {
try {
// Tasks are only executed on a connected machine that is not failed/pausing/paused.
if (machineEntity.getLifeStatus() == MachineRuntime.LifeStatus.CONNECTED
&& (machineEntity.getTasksStatus() == MachineRuntime.TasksStatus.ONGOING
|| machineEntity.getTasksStatus() == MachineRuntime.TasksStatus.EMPTY)) {
try {
if (activeTask == null) {
// Publish EMPTY before blocking so observers can see that the machine is idle.
if (taskQueue.isEmpty()) {
machineEntity.setTasksStatus(MachineRuntime.TasksStatus.EMPTY, null, null);
}
// Blocks until a task is available or the thread is interrupted.
activeTask = taskQueue.take();
logger.debug(String.format("%s: Taking a new task from the queue.", machineEntity.getId()));
machineEntity.setTasksStatus(MachineRuntime.TasksStatus.ONGOING, null, null);
} else {
// activeTask is only cleared by runTask on success/skip (or by remove/skipFailedTask), so a non-null value
// here means the previous attempt did not finish.
logger.debug(
String.format("%s: Retrying a task that didn't complete on last execution attempt.",
machineEntity.getId()));
}
logger.debug(String.format("%s: Task for execution.. '%s'", machineEntity.getId(), activeTask.getName()));
runTask(activeTask);
} catch (InterruptedException ex) {
// An interrupt is only expected together with the stopping flag; in that case leave the loop right away.
if (stopping) {
logger.debug(String.format("%s: Stopping SSH_Machine", machineEntity.getId()));
return;
} else {
logger.error(
String.format("%s: Got interrupted without having recieved stopping signal",
machineEntity.getId()));
}
}
} else {
// Not allowed to run tasks right now: acknowledge a pending pause request, then wait and re-check.
if (machineEntity.getTasksStatus() == MachineRuntime.TasksStatus.PAUSING) {
machineEntity.setTasksStatus(MachineRuntime.TasksStatus.PAUSED, null, null);
}
try {
Thread.sleep(Settings.MACHINE_TASKRUNNER_BUSYWAITING_INTERVALS);
} catch (InterruptedException ex) {
if (!stopping) {
logger.error(
String.format("%s: Got interrupted without having recieved stopping signal",
machineEntity.getId()));
}
}
}
} catch (Exception e) {
// Any other failure is logged and the loop keeps going, so one bad iteration does not kill the worker.
logger.error(String.format("%s: ", machineEntity.getId()), e);
}
}
} finally {
disconnect();
}
}
/**
 * Appends a task to this machine's queue, blocking while the bounded queue is full, and marks it as queued.
 *
 * @param task task to queue
 * @throws KaramelException if the calling thread is interrupted while waiting for queue space; the task is marked
 * failed and the thread's interrupt flag is restored before the exception is thrown
 */
public void enqueue(Task task) throws KaramelException {
  logger.debug(String.format("%s: Queuing '%s'", machineEntity.getId(), task.toString()));
  try {
    taskQueue.put(task);
    task.queued();
  } catch (InterruptedException ex) {
    // Catching InterruptedException clears the flag; restore it so callers further up can still see the interrupt.
    Thread.currentThread().interrupt();
    String message = String.format("%s: Couldn't queue task '%s'", machineEntity.getId(), task.getName());
    task.failed(message);
    throw new KaramelException(message, ex);
  }
}
/**
 * Takes a task out of the queue and, when it is the task currently held as active, forgets it.
 *
 * @param task task to remove
 * @throws KaramelException declared for callers; not thrown by the current implementation
 */
public void remove(Task task) throws KaramelException {
  String logLine = String.format("%s: De-queuing '%s'", machineEntity.getId(), task.toString());
  logger.debug(logLine);
  taskQueue.remove(task);
  if (task == activeTask) {
    activeTask = null;
  }
}
/**
 * Requests termination of the given task's remote session. Only acts when the task is the one currently active;
 * otherwise a warning is logged and nothing else happens.
 * <p>
 * NOTE(review): the kill task is executed through runTask/runSshCmd on the calling thread after the killing flag has
 * been raised, while the command loop in runSshCmd only iterates when that flag is down. Whether the kill command is
 * actually sent therefore seems to depend on the worker thread clearing the flag first - confirm this is intended.
 *
 * @param task the task whose session should be killed
 */
public void killTaskSession(Task task) {
  if (activeTask != task) {
    logger.warn(String.format("Request to kill '%s' on '%s' but the task is not ongoing now", task.getName(),
        task.getMachine().getPublicIp()));
    return;
  }
  logger.info(String.format("Killing '%s' on '%s'", task.getName(), task.getMachine().getPublicIp()));
  KillSessionTask killer = new KillSessionTask(machineEntity);
  killing = true;
  runTask(killer);
}
/**
 * Schedules a failed task for another attempt: resolves its failure on the cluster, marks it retried, makes it the
 * active task and resumes the machine.
 *
 * @param task the task to retry
 * @throws KaramelException if the task is not in status FAILED
 */
public void retryFailedTask(Task task) throws KaramelException {
  if (task.getStatus() != Status.FAILED) {
    String msg = String.format("Impossible to retry '%s' on '%s' because the task is not failed", task.getName(),
        task.getMachineId());
    logger.error(msg);
    throw new KaramelException(msg);
  }
  logger.info(String.format("Retrying '%s' on '%s'", task.getName(), task.getMachine().getPublicIp()));
  machineEntity.getGroup().getCluster().resolveFailure(Failure.hash(Failure.Type.TASK_FAILED, task.getUuid()));
  task.retried();
  // The worker loop re-runs whatever is held as the active task.
  activeTask = task;
  resume();
}
/**
 * Gives up on a failed task: resolves its failure on the cluster, marks it skipped, drops it as the active task
 * (when it is the active one) and resumes the machine.
 *
 * @param task the task to skip
 * @throws KaramelException if the task is not in status FAILED
 */
public void skipFailedTask(Task task) throws KaramelException {
  if (task.getStatus() != Status.FAILED) {
    String msg = String.format("Impossible to skip '%s' on '%s' because the task is not failed", task.getName(),
        task.getMachineId());
    logger.error(msg);
    throw new KaramelException(msg);
  }
  logger.info(String.format("Skipping '%s' on '%s'", task.getName(), task.getMachine().getPublicIp()));
  machineEntity.getGroup().getCluster().resolveFailure(Failure.hash(Failure.Type.TASK_FAILED, task.getUuid()));
  task.skipped();
  if (task == activeTask) {
    activeTask = null;
  }
  resume();
}
/**
 * Executes one task on the remote machine.
 * <p>
 * The succeeded-task history is first reloaded if it is flagged as outdated. When skipping of existing tasks is
 * enabled in the configuration and the task is idempotent and already listed in that history, the task is only
 * marked as existing. Otherwise its commands that are not yet DONE are run in order; the first command that does not
 * end up DONE fails the task and stops the iteration. After each successful command the task's results are
 * collected. A task still ONGOING after the loop is marked succeeded, except for a {@link KillSessionTask}, which
 * never touches the active task or the history.
 *
 * @param task task to execute
 */
private void runTask(Task task) {
  logger.debug("start running " + task.getId());
  if (!isSucceedTaskHistoryUpdated) {
    logger.debug("updating the task history");
    loadSucceedListFromMachineToMemory();
    logger.debug("the taks history was updated");
    isSucceedTaskHistoryUpdated = true;
  }

  String skipConf = confs.getProperty(Settings.SKIP_EXISTINGTASKS_KEY);
  boolean alreadyDone = skipConf != null
      && skipConf.equalsIgnoreCase("true")
      && task.isIdempotent()
      && succeedTasksHistory.contains(task.getId());
  if (alreadyDone) {
    task.exists();
    logger.info(String.format("Task skipped due to idempotency '%s'", task.getId()));
    if (!(task instanceof KillSessionTask)) {
      activeTask = null;
    }
    return;
  }

  logger.debug(String.format("task '%s' was not found in the task history, running it", task.getId()));
  try {
    task.started();
    List<ShellCommand> commands = task.getCommands();
    logger.debug(String.format("task %s has %d commands to run", task.getId(), commands.size()));
    for (ShellCommand command : commands) {
      if (command.getStatus() == ShellCommand.Status.DONE) {
        // Finished in an earlier attempt of this task, do not repeat it.
        logger.debug(String.format("skiping this command, status is %s", command.getStatus().toString()));
        continue;
      }
      logger.debug(String.format("command to run %s", command.getCmdStr()));
      runSshCmd(command, task, false);
      if (command.getStatus() != ShellCommand.Status.DONE) {
        task.failed(String.format("%s: Command did not complete: %s", machineEntity.getId(),
            command.getCmdStr()));
        break;
      }
      try {
        task.collectResults(this);
        // Experiment recipes additionally leave a free-form result file on the machine; it is fetched when the
        // command string mentions both "experiment" and "json".
        boolean experimentCommand = command.getCmdStr().contains("experiment")
            && command.getCmdStr().contains("json");
        if (task instanceof RunRecipeTask && experimentCommand) {
          task.downloadExperimentResults(this);
        }
      } catch (KaramelException ex) {
        logger.error(String.format("%s: Error in collecting/downloading the results", machineEntity.getId()),
            ex);
        task.failed(ex.getMessage());
      }
    }
    if (task.getStatus() == Status.ONGOING && !(task instanceof KillSessionTask)) {
      task.succeed();
      succeedTasksHistory.add(task.uniqueId());
      activeTask = null;
    }
  } catch (Exception ex) {
    logger.debug(String.format("failing the task because of the exception %s", ex.getMessage()), ex);
    task.failed(ex.getMessage());
  }
}
/**
 * Runs one shell command of a task on the remote machine, retrying with a growing pause between attempts.
 * <p>
 * Each attempt opens a fresh ssh session, executes the command, waits for it to end, records the outcome on the
 * command (DONE on exit status 0, FAILED otherwise) and serializes the command's output to the task log. The session
 * is always released at the end of an attempt and the killing flag is cleared. Retrying stops when the command
 * succeeded, the retries are used up, or the stopping/killing flag is set. When no session can be opened the method
 * returns with the command left in status ONGOING, which the caller treats as "did not complete".
 *
 * @param shellCommand command to execute; its status is updated in place
 * @param task owning task, used for logging, PTY allocation and the task log
 * @param killcommand currently unused
 */
private void runSshCmd(ShellCommand shellCommand, Task task, boolean killcommand) {
  logger.debug(String.format("recieved a command to run '%s'", shellCommand.getCmdStr()));
  int numCmdRetries = Settings.SSH_CMD_RETRY_NUM;
  int timeBetweenRetries = Settings.SSH_CMD_RETRY_INTERVALS;
  boolean finished = false;
  while (!stopping && !killing && !finished && numCmdRetries > 0) {
    shellCommand.setStatus(ShellCommand.Status.ONGOING);
    // Scoped per attempt so a stale, already closed session from a previous attempt is never touched again.
    Session session = null;
    try {
      logger.info(String.format("%s: Running task: %s", machineEntity.getId(), task.getName()));
      logger.debug(String.format("%s: running: %s", machineEntity.getId(), shellCommand.getCmdStr()));
      session = openSessionWithRetries(task, timeBetweenRetries);
      if (session == null) {
        // No session could be opened; the finally block below still runs before returning.
        return;
      }
      try {
        String cmdStr = shellCommand.getCmdStr();
        String password = ClusterService.getInstance().getCommonContext().getSudoAccountPassword();
        if (password != null && !password.isEmpty()) {
          // Literal replacement: replaceAll would interpret '$' and '\' inside the password as regex
          // replacement syntax and either corrupt the command or throw.
          cmdStr = cmdStr.replace("%password_hidden%", password);
        }
        Session.Command cmd = session.exec(cmdStr);
        cmd.join(Settings.SSH_CMD_MAX_TIOMEOUT, TimeUnit.MINUTES);
        updateHeartbeat();
        // The exit status is a boxed Integer and may be absent; treat a missing status as a failure instead of
        // hitting an unboxing NullPointerException.
        Integer exitStatus = cmd.getExitStatus();
        if (exitStatus != null && exitStatus.intValue() == 0) {
          shellCommand.setStatus(ShellCommand.Status.DONE);
          finished = true;
        } else {
          // Retry just in case there was a network problem somewhere on the server side
          shellCommand.setStatus(ShellCommand.Status.FAILED);
        }
        SequenceInputStream sequenceInputStream = new SequenceInputStream(cmd.getInputStream(), cmd.getErrorStream());
        LogService.serializeTaskLog(task, machineEntity.getPublicIp(), sequenceInputStream);
      } catch (ConnectionException | TransportException ex) {
        // Connection errors are expected while killing a session or terminating the cluster; stay quiet then.
        if (!killing
            && getMachineEntity().getGroup().getCluster().getPhase() != ClusterRuntime.ClusterPhases.TERMINATING) {
          logger.error(String.format("%s: Couldn't excecute command", machineEntity.getId()), ex);
        }
        if (killing) {
          logger.info(String.format("Killed '%s' on '%s' successfully...", task.getName(), machineEntity.getId()));
        }
      }
    } finally {
      // Retry if we have a network problem
      numCmdRetries--;
      if (!finished) {
        try {
          Thread.sleep(timeBetweenRetries);
        } catch (InterruptedException ex) {
          if (!stopping && !killing) {
            logger.warn(
                String.format("%s: Interrupted waiting to retry a command. Continuing...", machineEntity.getId()));
          }
        }
        timeBetweenRetries *= Settings.SSH_CMD_RETRY_SCALE;
      }
      // Regardless of success or failure the session must be released in every attempt.
      releaseSession(session);
      killing = false;
    }
  }
}

/**
 * Opens an ssh session (with a PTY when the task requires one), retrying up to
 * {@code Settings.SSH_SESSION_RETRY_NUM} times with the given pause between attempts.
 *
 * @return the opened session, or null when every attempt failed
 */
private Session openSessionWithRetries(Task task, int timeBetweenRetries) {
  int attemptsLeft = Settings.SSH_SESSION_RETRY_NUM;
  while (attemptsLeft > 0) {
    Session session = null;
    try {
      session = client.startSession();
      if (task.isSudoTerminalReqd()) {
        session.allocateDefaultPTY();
      }
      return session;
    } catch (ConnectionException | TransportException ex) {
      logger.warn(String.format("%s: Couldn't start ssh session, will retry", machineEntity.getId()), ex);
      attemptsLeft--;
      // Release a half-initialised session (e.g. PTY allocation failed) to avoid leaking sessions.
      releaseSession(session);
      if (attemptsLeft > 0) {
        try {
          Thread.sleep(timeBetweenRetries);
        } catch (InterruptedException ex3) {
          if (!stopping && !killing) {
            logger.warn(String.format("%s: Interrupted while waiting to start ssh session. Continuing...",
                machineEntity.getId()));
          }
        }
      }
    }
  }
  logger.error(String.format("%s: Exhasuted retrying to start a ssh session", machineEntity.getId()));
  return null;
}

/**
 * Closes a session if there is one, logging (not propagating) any failure to do so.
 */
private void releaseSession(Session session) {
  if (session == null) {
    return;
  }
  try {
    session.close();
  } catch (TransportException | ConnectionException ex) {
    logger.error(String.format("Couldn't close ssh session to '%s' ", machineEntity.getId()), ex);
  }
}
/**
 * Builds the callback through which the ssh library obtains the private key's passphrase. The passphrase is offered
 * once; no retry is requested. Must only be used when a passphrase is set.
 *
 * @return a finder backed by this machine's passphrase
 */
private PasswordFinder getPasswordFinder() {
  PasswordFinder finder = new PasswordFinder() {
    @Override
    public boolean shouldRetry(Resource<?> resource) {
      return false;
    }

    @Override
    public char[] reqPassword(Resource<?> resource) {
      return passphrase.toCharArray();
    }
  };
  return finder;
}
/**
 * Opens and authenticates the ssh connection unless a connected client already exists.
 * <p>
 * A new client is created, the key pair is loaded (with the passphrase when one is set) and the connection is
 * attempted up to three times. On success the machine is authenticated with its public key and marked CONNECTED; a
 * previously issued SSH_KEY_NOT_AUTH failure is resolved. After every failed attempt the machine is marked
 * UNREACHABLE. Exhausting the attempts is only logged, no exception is thrown. The succeeded-task history is flagged
 * as outdated whenever a (re)connect is attempted.
 * <p>
 * NOTE(review): the client is configured with a promiscuous host key verifier, i.e. the identity of the remote host
 * is not checked.
 *
 * @throws KaramelException when public key authentication is rejected (a SSH_KEY_NOT_AUTH failure is issued on the
 * cluster as well) or when another I/O error occurs while loading keys or authenticating
 */
private void connect() throws KaramelException {
  if (client != null && client.isConnected()) {
    return;
  }
  isSucceedTaskHistoryUpdated = false;
  boolean hasPassphrase = passphrase != null && !passphrase.isEmpty();
  try {
    client = new SSHClient();
    client.addHostKeyVerifier(new PromiscuousVerifier());
    client.setConnectTimeout(Settings.SSH_CONNECTION_TIMEOUT);
    client.setTimeout(Settings.SSH_SESSION_TIMEOUT);
    KeyProvider keys = hasPassphrase
        ? client.loadKeys(serverPrivateKey, serverPubKey, getPasswordFinder())
        : client.loadKeys(serverPrivateKey, serverPubKey, null);
    logger.info(String.format("%s: connecting ...", machineEntity.getId()));
    int numRetries = 3;
    int timeBetweenRetries = 2000;
    float scaleRetryTimeout = 1.0f;
    while (numRetries > 0) {
      try {
        client.connect(machineEntity.getPublicIp(), machineEntity.getSshPort());
      } catch (IOException ex) {
        logger.warn(String.format("%s: Opps!! coudln' t connect :@", machineEntity.getId()));
        if (hasPassphrase) {
          if (numRetries > 1) {
            logger.warn(String.format("%s: Could be a slow network, will retry. ", machineEntity.getId()));
          } else {
            logger.warn(String.format("%s: Could be a network problem. But if your network is fine,"
                + "then you have probably entered an incorrect the passphrase for your private key.",
                machineEntity.getId()));
          }
        }
        // Keep the stack trace available at debug level instead of only the exception's toString().
        logger.debug(String.format("%s: ssh connection attempt failed", machineEntity.getId()), ex);
      }
      if (client.isConnected()) {
        logger.info(String.format("%s: Yey!! connected ^-^", machineEntity.getId()));
        machineEntity.getGroup().getCluster().resolveFailure(Failure.hash(Failure.Type.SSH_KEY_NOT_AUTH,
            machineEntity.getPublicIp()));
        client.authPublickey(machineEntity.getSshUser(), keys);
        machineEntity.setLifeStatus(MachineRuntime.LifeStatus.CONNECTED);
        return;
      }
      machineEntity.setLifeStatus(MachineRuntime.LifeStatus.UNREACHABLE);
      numRetries--;
      if (numRetries > 0) {
        // Only wait when another attempt follows.
        try {
          Thread.sleep(timeBetweenRetries);
        } catch (InterruptedException ex) {
          logger.error(String.format("%s: Interrupted while waiting to retry the ssh connection",
              machineEntity.getId()), ex);
        }
        timeBetweenRetries *= scaleRetryTimeout;
      }
    }
    // Reaching this point means every attempt failed (success returns from inside the loop).
    String message = String.format("%s: Exhausted retry for ssh connection, is the port '%d' open?",
        machineEntity.getId(), machineEntity.getSshPort());
    if (hasPassphrase) {
      message += " or is the passphrase for your private key correct?";
    }
    logger.error(message);
  } catch (UserAuthException ex) {
    String message = String.format("%s: Authentication problem using ssh keys.", machineEntity.getId());
    if (hasPassphrase) {
      message = message + " Is the passphrase for your private key correct?";
    }
    KaramelException exp = new KaramelException(message, ex);
    machineEntity.getGroup().getCluster().issueFailure(new Failure(Failure.Type.SSH_KEY_NOT_AUTH,
        machineEntity.getPublicIp(), message));
    throw exp;
  } catch (IOException e) {
    throw new KaramelException(e);
  }
}
/**
 * Closes the ssh client if it is connected. This is best effort: a failure while closing is logged and otherwise
 * ignored, so it never propagates to the caller.
 */
public void disconnect() {
  logger.info(String.format("%s: Closing ssh session", machineEntity.getId()));
  try {
    if (client != null && client.isConnected()) {
      client.close();
    }
  } catch (IOException ex) {
    // Nothing can be done about a failed close, but leave a trace instead of dropping it silently.
    logger.warn(String.format("%s: Problem while closing the ssh connection", machineEntity.getId()), ex);
  }
}
/**
 * Heartbeat check. Does nothing while the last heartbeat is younger than the ping interval; otherwise it refreshes
 * the heartbeat when the client is connected, or tries to (re)connect when it is not.
 *
 * @throws KaramelException if reconnecting fails with an authentication or I/O error
 */
public void ping() throws KaramelException {
  if (lastHeartbeat >= System.currentTimeMillis() - Settings.SSH_PING_INTERVAL) {
    return;
  }
  if (client == null || !client.isConnected()) {
    connect();
  } else {
    updateHeartbeat();
  }
}
/**
 * Records the current time as the moment of the last heartbeat.
 */
private void updateHeartbeat() {
  this.lastHeartbeat = System.currentTimeMillis();
}
/**
 * Refreshes the in-memory list of succeeded tasks from the file kept on the remote machine. The caller does this
 * only when the list is flagged as outdated, which happens when the ssh connection is (re)established.
 * <p>
 * Any previous local copy is deleted, the remote file is downloaded and its lines become the new history. When the
 * local copy cannot be read afterwards (e.g. the remote file does not exist) the history is emptied. Failures along
 * the way are logged and never propagated as checked exceptions.
 */
private void loadSucceedListFromMachineToMemory() {
  logger.debug(String.format("Loading succeeded tasklist from %s", machineEntity.getPublicIp()));
  String clusterName = machineEntity.getGroup().getCluster().getName().toLowerCase();
  String remoteSucceedPath = Settings.REMOTE_SUCCEEDTASKS_PATH(machineEntity.getSshUser());
  String localSucceedPath = Settings.MACHINE_SUCCEEDTASKS_PATH(clusterName, machineEntity.getPublicIp());
  File localFile = new File(localSucceedPath);
  try {
    Files.deleteIfExists(localFile.toPath());
  } catch (IOException ex) {
    // A stale copy left behind could be read back below, so make the failed cleanup visible.
    logger.warn(String.format("%s: Couldn't delete the local succeeded tasklist '%s'", machineEntity.getId(),
        localSucceedPath), ex);
  }
  try {
    downloadRemoteFile(remoteSucceedPath, localSucceedPath, true);
  } catch (IOException ex) {
    // Typically the remote file does not exist yet.
    logger.info(String.format("Succeeded tasklist does not exist on %s", machineEntity.getPublicIp()));
  } catch (KaramelException ex) {
    // Not expected from the overwrite check (the local file was just deleted), but connecting can fail too.
    logger.warn(String.format("%s: Couldn't download the succeeded tasklist", machineEntity.getId()), ex);
  } finally {
    // Runs on every path so the history always reflects what is (or is not) on disk.
    try {
      String list = IoUtils.readContentFromPath(localSucceedPath);
      String[] items = list.split("\n");
      succeedTasksHistory.clear();
      succeedTasksHistory.addAll(Arrays.asList(items));
    } catch (IOException ex) {
      // Local file does not exist, the list is considered to be empty.
      succeedTasksHistory.clear();
    }
  }
}
/**
 * Downloads a file from the remote machine via SCP, connecting first if necessary.
 *
 * @param remoteFilePath path of the file on the remote machine
 * @param localFilePath local destination path; missing parent directories are created
 * @param overwrite when true an existing local file is deleted before the download, when false an existing local
 * file makes the call fail
 * @throws KaramelException if the local file already exists and {@code overwrite} is false, or if connecting fails
 * @throws IOException if the transfer fails, e.g. because the remote file does not exist
 */
@Override
public void downloadRemoteFile(String remoteFilePath, String localFilePath, boolean overwrite)
    throws KaramelException, IOException {
  connect();
  SCPFileTransfer scp = client.newSCPFileTransfer();
  File localFile = new File(localFilePath);
  // Create the parent directories only. Calling mkdirs() on the target itself would create a directory at the
  // file's path, which made the existence check below fail every time overwrite was false.
  File parentDir = localFile.getAbsoluteFile().getParentFile();
  if (parentDir != null) {
    parentDir.mkdirs();
  }
  // Don't collect logs of values, just overwrite
  if (localFile.exists()) {
    if (!overwrite) {
      throw new KaramelException(String.format("%s: Local file already exist %s",
          machineEntity.getId(), localFilePath));
    }
    localFile.delete();
  }
  // If the remote file doesn't exist, this should quickly throw an IOException
  scp.download(remoteFilePath, localFilePath);
}
}