/**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.airavata.gfac.impl.task;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.airavata.common.exception.AiravataException;
import org.apache.airavata.common.exception.ApplicationSettingsException;
import org.apache.airavata.credential.store.store.CredentialStoreException;
import org.apache.airavata.gfac.core.GFacException;
import org.apache.airavata.gfac.core.GFacUtils;
import org.apache.airavata.gfac.core.authentication.AuthenticationInfo;
import org.apache.airavata.gfac.core.cluster.CommandInfo;
import org.apache.airavata.gfac.core.cluster.RawCommandInfo;
import org.apache.airavata.gfac.core.cluster.RemoteCluster;
import org.apache.airavata.gfac.core.cluster.ServerInfo;
import org.apache.airavata.gfac.core.context.ProcessContext;
import org.apache.airavata.gfac.core.context.TaskContext;
import org.apache.airavata.gfac.core.task.Task;
import org.apache.airavata.gfac.core.task.TaskException;
import org.apache.airavata.gfac.impl.Factory;
import org.apache.airavata.model.appcatalog.storageresource.StorageResourceDescription;
import org.apache.airavata.model.application.io.InputDataObjectType;
import org.apache.airavata.model.application.io.OutputDataObjectType;
import org.apache.airavata.model.commons.ErrorModel;
import org.apache.airavata.model.status.ProcessState;
import org.apache.airavata.model.status.TaskState;
import org.apache.airavata.model.status.TaskStatus;
import org.apache.airavata.model.task.DataStagingTaskModel;
import org.apache.airavata.model.task.TaskTypes;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.Map;
/**
* This will be used for both Input file staging and output file staging, hence if you do any changes to a part of logic
* in this class please consider that will works with both input and output cases.
*/
public class SCPDataStageTask implements Task {
private static final Logger log = LoggerFactory.getLogger(SCPDataStageTask.class);
private static final int DEFAULT_SSH_PORT = 22;
private String hostName;
private String inputPath;
@Override
public void init(Map<String, String> propertyMap) throws TaskException {
}
@Override
public TaskStatus execute(TaskContext taskContext) {
TaskStatus status = new TaskStatus(TaskState.EXECUTING);
AuthenticationInfo authenticationInfo = null;
DataStagingTaskModel subTaskModel = null;
String localDataDir = null;
ProcessContext processContext = taskContext.getParentProcessContext();
ProcessState processState = processContext.getProcessState();
try {
subTaskModel = ((DataStagingTaskModel) taskContext.getSubTaskModel());
if (processState == ProcessState.OUTPUT_DATA_STAGING) {
OutputDataObjectType processOutput = taskContext.getProcessOutput();
if (processOutput != null && processOutput.getValue() == null) {
log.error("expId: {}, processId:{}, taskId: {}:- Couldn't stage file {} , file name shouldn't be null",
taskContext.getExperimentId(), taskContext.getProcessId(), taskContext.getTaskId(),
processOutput.getName());
status = new TaskStatus(TaskState.FAILED);
if (processOutput.isIsRequired()) {
status.setReason("File name is null, but this output's isRequired bit is not set");
} else {
status.setReason("File name is null");
}
return status;
}
} else if (processState == ProcessState.INPUT_DATA_STAGING) {
InputDataObjectType processInput = taskContext.getProcessInput();
if (processInput != null && processInput.getValue() == null) {
log.error("expId: {}, processId:{}, taskId: {}:- Couldn't stage file {} , file name shouldn't be null",
taskContext.getExperimentId(), taskContext.getProcessId(), taskContext.getTaskId(),
processInput.getName());
status = new TaskStatus(TaskState.FAILED);
if (processInput.isIsRequired()) {
status.setReason("File name is null, but this input's isRequired bit is not set");
} else {
status.setReason("File name is null");
}
return status;
}
} else {
status.setState(TaskState.FAILED);
status.setReason("Invalid task invocation, Support " + ProcessState.INPUT_DATA_STAGING.name() + " and " +
"" + ProcessState.OUTPUT_DATA_STAGING.name() + " process phases. found " + processState.name());
return status;
}
StorageResourceDescription storageResource = processContext.getStorageResource();
// StoragePreference storagePreference = taskContext.getParentProcessContext().getStoragePreference();
if (storageResource != null) {
hostName = storageResource.getHostName();
} else {
throw new GFacException("Storage Resource is null");
}
inputPath = processContext.getStorageFileSystemRootLocation();
inputPath = (inputPath.endsWith(File.separator) ? inputPath : inputPath + File.separator);
// use rsync instead of scp if source and destination host and user name is same.
URI sourceURI = new URI(subTaskModel.getSource());
String fileName = sourceURI.getPath().substring(sourceURI.getPath().lastIndexOf(File.separator) + 1,
sourceURI.getPath().length());
authenticationInfo = Factory.getComputerResourceSSHKeyAuthentication(processContext);
ServerInfo serverInfo = processContext.getComputeResourceServerInfo();
Session sshSession = Factory.getSSHSession(authenticationInfo, serverInfo);
URI destinationURI = null;
if (subTaskModel.getDestination().startsWith("dummy")) {
destinationURI = TaskUtils.getDestinationURI(taskContext, hostName, inputPath, fileName);
subTaskModel.setDestination(destinationURI.toString());
} else {
destinationURI = new URI(subTaskModel.getDestination());
}
if (sourceURI.getHost().equalsIgnoreCase(destinationURI.getHost())
&& sourceURI.getUserInfo().equalsIgnoreCase(destinationURI.getUserInfo())) {
localDataCopy(taskContext, sourceURI, destinationURI);
status.setState(TaskState.COMPLETED);
status.setReason("Locally copied file using 'cp' command ");
return status;
}
status = new TaskStatus(TaskState.COMPLETED);
//Wildcard for file name. Has to find the correct name.
if(fileName.startsWith("*.")){
String destParentPath = (new File(destinationURI.getPath())).getParentFile().getPath();
String sourceParentPath = (new File(sourceURI.getPath())).getParentFile().getPath();
String temp = taskContext.getParentProcessContext().getDataMovementRemoteCluster()
.getFileNameFromExtension(fileName.substring(2), sourceParentPath, sshSession);
if(temp != null && temp != ""){
fileName = temp;
}
if(destParentPath.endsWith(File.separator)){
destinationURI = new URI(destParentPath + fileName);
}else{
destinationURI = new URI(destParentPath + File.separator + fileName);
}
}
if (processState == ProcessState.INPUT_DATA_STAGING) {
inputDataStaging(taskContext, sshSession, sourceURI, destinationURI);
status.setReason("Successfully staged input data");
} else if (processState == ProcessState.OUTPUT_DATA_STAGING) {
makeDir(taskContext, destinationURI);
// TODO - save updated subtask model with new destination
outputDataStaging(taskContext, sshSession, sourceURI, destinationURI);
status.setReason("Successfully staged output data");
}
} catch (TException e) {
String msg = "Couldn't create subTask model thrift model";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
return status;
} catch (ApplicationSettingsException | FileNotFoundException e) {
String msg = "Failed while reading credentials";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
} catch (URISyntaxException e) {
String msg = "Source or destination uri is not correct source : " + subTaskModel.getSource() + ", " +
"destination : " + subTaskModel.getDestination();
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
} catch (CredentialStoreException e) {
String msg = "Storage authentication issue, could be invalid credential token";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
} catch (AiravataException e) {
String msg = "Error while creating ssh session with client";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
} catch (JSchException | IOException e) {
String msg = "Failed to do scp with client";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
} catch (GFacException e) {
String msg = "Data staging failed";
log.error(msg, e);
status.setState(TaskState.FAILED);
status.setReason(msg);
ErrorModel errorModel = new ErrorModel();
errorModel.setActualErrorMessage(e.getMessage());
errorModel.setUserFriendlyMessage(msg);
taskContext.getTaskModel().setTaskErrors(Arrays.asList(errorModel));
}
return status;
}
private void localDataCopy(TaskContext taskContext, URI sourceURI, URI destinationURI) throws GFacException {
StringBuilder sb = new StringBuilder("rsync -cr ");
sb.append(sourceURI.getPath()).append(" ").append(destinationURI.getPath());
CommandInfo commandInfo = new RawCommandInfo(sb.toString());
taskContext.getParentProcessContext().getDataMovementRemoteCluster().execute(commandInfo);
}
private void inputDataStaging(TaskContext taskContext, Session sshSession, URI sourceURI, URI
destinationURI) throws GFacException, IOException, JSchException {
/**
* scp third party file transfer 'to' compute resource.
*/
taskContext.getParentProcessContext().getDataMovementRemoteCluster().scpThirdParty(sourceURI.getPath(),
destinationURI.getPath(), sshSession, RemoteCluster.DIRECTION.FROM, false);
}
private void outputDataStaging(TaskContext taskContext, Session sshSession, URI sourceURI, URI destinationURI)
throws AiravataException, IOException, JSchException, GFacException {
/**
* scp third party file transfer 'from' comute resource.
*/
taskContext.getParentProcessContext().getDataMovementRemoteCluster().scpThirdParty(sourceURI.getPath(),
destinationURI.getPath(), sshSession, RemoteCluster.DIRECTION.TO, true);
// update output locations
GFacUtils.saveExperimentOutput(taskContext.getParentProcessContext(), taskContext.getProcessOutput().getName(), destinationURI.toString());
GFacUtils.saveProcessOutput(taskContext.getParentProcessContext(), taskContext.getProcessOutput().getName(), destinationURI.toString());
}
private void makeDir(TaskContext taskContext, URI pathURI) throws GFacException {
int endIndex = pathURI.getPath().lastIndexOf('/');
if (endIndex < 1) {
return;
}
String targetPath = pathURI.getPath().substring(0, endIndex);
taskContext.getParentProcessContext().getDataMovementRemoteCluster().makeDirectory(targetPath);
}
@Override
public TaskStatus recover(TaskContext taskContext) {
TaskState state = taskContext.getTaskStatus().getState();
if (state == TaskState.EXECUTING || state == TaskState.CREATED) {
return execute(taskContext);
} else {
// files already transferred or failed
return taskContext.getTaskStatus();
}
}
@Override
public TaskTypes getType() {
return TaskTypes.DATA_STAGING;
}
}