/** * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.airavata.gfac.impl; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; import com.jcraft.jsch.UserInfo; import org.apache.airavata.common.exception.AiravataException; import org.apache.airavata.gfac.core.GFacException; import org.apache.airavata.gfac.core.JobManagerConfiguration; import org.apache.airavata.gfac.core.authentication.AuthenticationInfo; import org.apache.airavata.gfac.core.authentication.SSHKeyAuthentication; import org.apache.airavata.gfac.core.cluster.AbstractRemoteCluster; import org.apache.airavata.gfac.core.cluster.CommandInfo; import org.apache.airavata.gfac.core.cluster.CommandOutput; import org.apache.airavata.gfac.core.cluster.JobSubmissionOutput; import org.apache.airavata.gfac.core.cluster.RawCommandInfo; import org.apache.airavata.gfac.core.cluster.ServerInfo; import org.apache.airavata.model.status.JobStatus; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.UUID; /** * One Remote cluster instance for each compute resource. */ public class HPCRemoteCluster extends AbstractRemoteCluster{ private static final Logger log = LoggerFactory.getLogger(HPCRemoteCluster.class); private static final int MAX_RETRY_COUNT = 3; private final SSHKeyAuthentication authentication; private final JSch jSch; public HPCRemoteCluster(ServerInfo serverInfo, JobManagerConfiguration jobManagerConfiguration, AuthenticationInfo authenticationInfo) throws AiravataException, GFacException { super(serverInfo, jobManagerConfiguration, authenticationInfo); try { if (authenticationInfo instanceof SSHKeyAuthentication) { authentication = (SSHKeyAuthentication) authenticationInfo; } else { throw new AiravataException("Support ssh key authentication only"); } jSch = new JSch(); jSch.addIdentity(UUID.randomUUID().toString(), authentication.getPrivateKey(), authentication.getPublicKey(), authentication.getPassphrase().getBytes()); } catch (JSchException e) { throw new AiravataException("JSch initialization error ", e); } } private Session getOpenSession() throws JSchException { Session newSession = jSch.getSession(serverInfo.getUserName(), serverInfo.getHost(), serverInfo.getPort()); newSession.setUserInfo(new DefaultUserInfo(serverInfo.getUserName(), null, authentication.getPassphrase())); if (authentication.getStrictHostKeyChecking().equals("yes")) { jSch.setKnownHosts(authentication.getKnownHostsFilePath()); } else { newSession.setConfig("StrictHostKeyChecking", "no"); } newSession.connect(); // 0 connection timeout return newSession; } @Override public JobSubmissionOutput submitBatchJob(String jobScriptFilePath, String workingDirectory) throws GFacException { JobSubmissionOutput jsoutput = new JobSubmissionOutput(); copyTo(jobScriptFilePath, workingDirectory); // scp script file to working directory RawCommandInfo submitCommand = jobManagerConfiguration.getSubmitCommand(workingDirectory, jobScriptFilePath); submitCommand.setRawCommand("cd " + workingDirectory + "; " + submitCommand.getRawCommand()); StandardOutReader reader = new StandardOutReader(); executeCommand(submitCommand, reader); // throwExceptionOnError(reader, submitCommand); jsoutput.setJobId(outputParser.parseJobSubmission(reader.getStdOutputString())); if (jsoutput.getJobId() == null) { if (outputParser.isJobSubmissionFailed(reader.getStdOutputString())) { jsoutput.setJobSubmissionFailed(true); jsoutput.setFailureReason("stdout : " + reader.getStdOutputString() + "\n stderr : " + reader.getStdErrorString()); } } jsoutput.setExitCode(reader.getExitCode()); if (jsoutput.getExitCode() != 0) { jsoutput.setJobSubmissionFailed(true); jsoutput.setFailureReason("stdout : " + reader.getStdOutputString() + "\n stderr : " + reader.getStdErrorString()); } jsoutput.setStdOut(reader.getStdOutputString()); jsoutput.setStdErr(reader.getStdErrorString()); return jsoutput; } @Override public void copyTo(String localFile, String remoteFile) throws GFacException { int retry = 3; while (retry > 0) { try { log.info("Transferring localhost:" + localFile + " to " + serverInfo.getHost() + ":" + remoteFile); SSHUtils.scpTo(localFile, remoteFile, getSshSession()); retry = 0; } catch (Exception e) { retry--; if (retry == 0) { throw new GFacException("Failed to scp localhost:" + localFile + " to " + serverInfo.getHost() + ":" + remoteFile, e); } else { log.info("Retry transfer localhost:" + localFile + " to " + serverInfo.getHost() + ":" + remoteFile); } } } } private Session getSshSession() throws GFacException { return Factory.getSSHSession(authenticationInfo, serverInfo); } @Override public void copyFrom(String remoteFile, String localFile) throws GFacException { int retry = 3; while(retry>0) { try { log.info("Transferring " + serverInfo.getHost() + ":" + remoteFile + " To localhost:" + localFile); SSHUtils.scpFrom(remoteFile, localFile, getSession()); retry=0; } catch (Exception e) { retry--; if (retry == 0) { throw new GFacException("Failed to scp " + serverInfo.getHost() + ":" + remoteFile + " to " + "localhost:" + localFile, e); } else { log.info("Retry transfer " + serverInfo.getHost() + ":" + remoteFile + " to localhost:" + localFile); } } } } @Override public void scpThirdParty(String sourceFile, String destinationFile, Session clientSession, DIRECTION direction, boolean ignoreEmptyFile) throws GFacException { int retryCount= 0; try { while (retryCount < MAX_RETRY_COUNT) { retryCount++; log.info("Transferring from:" + sourceFile + " To: " + destinationFile); try { if (direction == DIRECTION.FROM) { SSHUtils.scpThirdParty(sourceFile, getSession(), destinationFile, clientSession, ignoreEmptyFile); } else { SSHUtils.scpThirdParty(sourceFile, clientSession, destinationFile, getSession(), ignoreEmptyFile); } break; // exit while loop } catch (JSchException e) { if (retryCount == MAX_RETRY_COUNT) { log.error("Retry count " + MAX_RETRY_COUNT + " exceeded for transferring from:" + sourceFile + " To: " + destinationFile, e); throw e; } log.error("Issue with jsch, Retry transferring from:" + sourceFile + " To: " + destinationFile, e); } } } catch (IOException | JSchException e) { throw new GFacException("Failed scp file:" + sourceFile + " to remote file " +destinationFile , e); } } /** * This method can be used to get the file name of a file giving the extension. It assumes that there will be only * one file with that extension. In case if there are more than one file one random file name from the matching ones * will be returned. * * @param fileExtension * @param parentPath * @param session * @return */ @Override public String getFileNameFromExtension(String fileExtension, String parentPath, Session session) throws GFacException { try { List<String> fileNames = SSHUtils.listDirectory(parentPath, session); for(String fileName : fileNames){ if(fileName.endsWith(fileExtension)){ return fileName; } } log.warn("No matching file found for extension: " + fileExtension + " in the " + parentPath + " directory"); return null; } catch (Exception e) { e.printStackTrace(); throw new GFacException("Failed to list directory " + parentPath); } } @Override public void makeDirectory(String directoryPath) throws GFacException { int retryCount = 0; try { while (retryCount < MAX_RETRY_COUNT) { retryCount++; log.info("Creating directory: " + serverInfo.getHost() + ":" + directoryPath); try { SSHUtils.makeDirectory(directoryPath, getSession()); break; // Exit while loop } catch (JSchException e) { if (retryCount == MAX_RETRY_COUNT) { log.error("Retry count " + MAX_RETRY_COUNT + " exceeded for creating directory: " + serverInfo.getHost() + ":" + directoryPath, e); throw e; } log.error("Issue with jsch, Retry creating directory: " + serverInfo.getHost() + ":" + directoryPath); } } } catch (JSchException | IOException e) { throw new GFacException("Failed to create directory " + serverInfo.getHost() + ":" + directoryPath, e); } } @Override public JobStatus cancelJob(String jobId) throws GFacException { JobStatus oldStatus = getJobStatus(jobId); RawCommandInfo cancelCommand = jobManagerConfiguration.getCancelCommand(jobId); StandardOutReader reader = new StandardOutReader(); executeCommand(cancelCommand, reader); throwExceptionOnError(reader, cancelCommand); return oldStatus; } @Override public JobStatus getJobStatus(String jobId) throws GFacException { RawCommandInfo monitorCommand = jobManagerConfiguration.getMonitorCommand(jobId); StandardOutReader reader = new StandardOutReader(); executeCommand(monitorCommand, reader); throwExceptionOnError(reader, monitorCommand); return outputParser.parseJobStatus(jobId, reader.getStdOutputString()); } @Override public String getJobIdByJobName(String jobName, String userName) throws GFacException { RawCommandInfo jobIdMonitorCommand = jobManagerConfiguration.getJobIdMonitorCommand(jobName, userName); StandardOutReader reader = new StandardOutReader(); executeCommand(jobIdMonitorCommand, reader); throwExceptionOnError(reader, jobIdMonitorCommand); return outputParser.parseJobId(jobName, reader.getStdOutputString()); } @Override public void getJobStatuses(String userName, Map<String, JobStatus> jobStatusMap) throws GFacException { RawCommandInfo userBasedMonitorCommand = jobManagerConfiguration.getUserBasedMonitorCommand(userName); StandardOutReader reader = new StandardOutReader(); executeCommand(userBasedMonitorCommand, reader); throwExceptionOnError(reader, userBasedMonitorCommand); outputParser.parseJobStatuses(userName, jobStatusMap, reader.getStdOutputString()); } @Override public List<String> listDirectory(String directoryPath) throws GFacException { try { log.info("Creating directory: " + serverInfo.getHost() + ":" + directoryPath); return SSHUtils.listDirectory(directoryPath, getSession()); } catch (JSchException | IOException e) { throw new GFacException("Failed to list directory " + serverInfo.getHost() + ":" + directoryPath, e); } } @Override public boolean execute(CommandInfo commandInfo) throws GFacException { StandardOutReader reader = new StandardOutReader(); executeCommand(commandInfo, reader); return true; } @Override public Session getSession() throws GFacException { return getSshSession(); } @Override public void disconnect() throws GFacException { Factory.disconnectSSHSession(serverInfo); } /** * This method return <code>true</code> if there is an error in standard output. If not return <code>false</code> * * @param reader - command output reader * @param submitCommand - command which executed in remote machine. * @return command has return error or not. */ private void throwExceptionOnError(StandardOutReader reader, RawCommandInfo submitCommand) throws GFacException { String stdErrorString = reader.getStdErrorString(); String command = submitCommand.getCommand().substring(submitCommand.getCommand().lastIndexOf(File.separator) + 1); if (stdErrorString == null) { // noting to do } else if ((stdErrorString.contains(command.trim()) && !stdErrorString.contains("Warning")) || stdErrorString .contains("error")) { log.error("Command {} , Standard Error output {}", command, stdErrorString); throw new GFacException("Error running command " + command + " on remote cluster. StandardError: " + stdErrorString); } } private void executeCommand(CommandInfo commandInfo, CommandOutput commandOutput) throws GFacException { String command = commandInfo.getCommand(); int retryCount = 0; ChannelExec channelExec = null; try { while (retryCount < MAX_RETRY_COUNT) { retryCount++; try { Session session = getSshSession(); channelExec = ((ChannelExec) session.openChannel("exec")); channelExec.setCommand(command); channelExec.setInputStream(null); channelExec.setErrStream(commandOutput.getStandardError()); channelExec.connect(); log.info("Executing command {}", commandInfo.getCommand()); commandOutput.onOutput(channelExec); break; // exit from while loop } catch (JSchException e) { if (retryCount == MAX_RETRY_COUNT) { log.error("Retry count " + MAX_RETRY_COUNT + " exceeded for executing command : " + command, e); throw e; } log.error("Issue with jsch, Retry executing command : " + command, e); } } } catch (JSchException e) { throw new GFacException("Unable to execute command - " + command, e); } finally { //Only disconnecting the channel, session can be reused if (channelExec != null) { commandOutput.exitCode(channelExec.getExitStatus()); channelExec.disconnect(); } } } @Override public ServerInfo getServerInfo() { return this.serverInfo; } @Override public AuthenticationInfo getAuthentication() { return this.authentication; } private class DefaultUserInfo implements UserInfo { private String userName; private String password; private String passphrase; public DefaultUserInfo(String userName, String password, String passphrase) { this.userName = userName; this.password = password; this.passphrase = passphrase; } @Override public String getPassphrase() { return null; } @Override public String getPassword() { return null; } @Override public boolean promptPassword(String s) { return false; } @Override public boolean promptPassphrase(String s) { return false; } @Override public boolean promptYesNo(String s) { return false; } @Override public void showMessage(String s) { } } }