/**
* Copyright 2010 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package datameer.awstasks.aws.ec2.ssh;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.log4j.Logger;
import datameer.awstasks.ssh.JschRunner;
import datameer.awstasks.ssh.ScpDownloadCommand;
import datameer.awstasks.ssh.ScpUploadCommand;
import datameer.awstasks.ssh.SshExecCommand;
import datameer.awstasks.util.ExceptionUtil;
import datameer.awstasks.util.IoUtil;
import datameer.com.google.common.base.Throwables;
import datameer.com.google.common.collect.Lists;
/**
 * {@link SshClient} implementation that executes commands and scp transfers against a list of
 * hosts through {@link JschRunner}, authenticating either with a private key file or with a
 * password.
 * <p>
 * Command execution and uploads are submitted to a thread pool with one task per host; downloads
 * are processed host after host on the calling thread.
 */
public class SshClientImpl implements SshClient {

    protected static final Logger LOG = Logger.getLogger(SshClientImpl.class);

    protected File _privateKey;
    protected String _password;
    protected final String _username;
    protected final List<String> _hostnames;
    private boolean _enableConnectRetries;

    /**
     * Creates a client that authenticates with a private key file.
     *
     * @param username
     *            the user to log in as
     * @param privateKey
     *            the key file handed to each {@link JschRunner}
     * @param hostnames
     *            the hosts this client operates on; the list is stored as is (not copied)
     */
    public SshClientImpl(String username, File privateKey, List<String> hostnames) {
        _username = username;
        _privateKey = privateKey;
        _hostnames = hostnames;
    }

    /**
     * Creates a client that authenticates with a password.
     *
     * @param username
     *            the user to log in as
     * @param password
     *            the password handed to each {@link JschRunner}
     * @param hostnames
     *            the hosts this client operates on; the list is stored as is (not copied)
     */
    public SshClientImpl(String username, String password, List<String> hostnames) {
        _username = username;
        _password = password;
        _hostnames = hostnames;
    }

    /**
     * Sets the flag that is passed to {@link JschRunner#setEnableConnectionRetries} for every
     * runner created afterwards.
     */
    @Override
    public void setEnableConnectRetries(boolean enable) {
        _enableConnectRetries = enable;
    }

    /**
     * Executes the given command on all hosts of this client.
     *
     * @param command
     *            the command line to execute
     * @param outputStream
     *            receives the command output; see
     *            {@link #executeSshCommand(List, String, File, OutputStream)} for the ordering
     * @throws IOException
     *             if the execution fails on one of the hosts
     */
    @Override
    public void executeCommand(String command, OutputStream outputStream) throws IOException {
        executeCommand(_hostnames, command, outputStream);
    }

    /**
     * Executes the given command on a subset of the hosts of this client.
     *
     * @param targetedInstances
     *            indexes into the host list of this client
     * @throws IOException
     *             if the execution fails on one of the hosts
     */
    @Override
    public void executeCommand(String command, OutputStream outputStream, int[] targetedInstances) throws IOException {
        executeCommand(getHosts(targetedInstances), command, outputStream);
    }

    private void executeCommand(List<String> hostnames, final String command, final OutputStream outputStream) throws IOException {
        executeSshCommand(hostnames, command, null, outputStream);
    }

    /**
     * Executes the given command file on all hosts of this client.
     *
     * @throws IOException
     *             if the execution fails on one of the hosts
     */
    @Override
    public void executeCommandFile(File commandFile, OutputStream outputStream) throws IOException {
        executeSshCommand(_hostnames, null, commandFile, outputStream);
    }

    /**
     * Executes the given command file on a subset of the hosts of this client.
     *
     * @param targetedInstances
     *            indexes into the host list of this client
     * @throws IOException
     *             if the execution fails on one of the hosts
     */
    @Override
    public void executeCommandFile(File commandFile, OutputStream outputStream, int[] targetedInstances) throws IOException {
        executeSshCommand(getHosts(targetedInstances), null, commandFile, outputStream);
    }

    /**
     * Runs either {@code command} (if non-null) or {@code commandFile} on every given host, one
     * task per host.
     * <p>
     * For exactly one host the output is written directly to {@code outputStream}. Otherwise each
     * host's output is collected in memory and copied to {@code outputStream} once that host's
     * task has been awaited, which happens in the order of {@code hostnames}.
     */
    private void executeSshCommand(final List<String> hostnames, final String command, final File commandFile, final OutputStream outputStream) throws IOException {
        List<SshCallable> sshCallables = Lists.newArrayList();
        if (hostnames.size() == 1) {
            // don't cache the outputstream
            sshCallables.add(new SshCallable() {
                @Override
                protected void execute() throws IOException {
                    executeCommandOrCommandFile(hostnames.get(0), command, commandFile, outputStream);
                }
            });
        } else {
            // cache the outputstream for ordering the results
            for (final String host : hostnames) {
                sshCallables.add(new SshCallable() {
                    ByteArrayOutputStream _byteArrayOutputStream;

                    @Override
                    protected void execute() throws IOException {
                        _byteArrayOutputStream = new ByteArrayOutputStream();
                        executeCommandOrCommandFile(host, command, commandFile, _byteArrayOutputStream);
                    }

                    @Override
                    public void close() {
                        // only invoked for tasks whose execute() returned, so the buffer is set
                        try {
                            outputStream.write(_byteArrayOutputStream.toByteArray());
                        } catch (IOException e) {
                            throw ExceptionUtil.convertToRuntimeException(e);
                        }
                    }
                });
            }
        }
        executeCallables(sshCallables);
    }

    /**
     * Submits all callables to a fresh thread pool and waits for their completion.
     * <p>
     * The pool is always shut down afterwards so its worker threads do not linger after every
     * operation. {@code shutdown()} (rather than {@code shutdownNow()}) is used so that tasks
     * which are still running when a sibling task failed are not interrupted.
     */
    private void executeCallables(List<SshCallable> sshCallables) throws IOException {
        ExecutorService executor = Executors.newCachedThreadPool();
        try {
            List<Future<SshCallable>> futureList = Lists.newArrayListWithCapacity(sshCallables.size());
            for (SshCallable sshCallable : sshCallables) {
                futureList.add(executor.submit(sshCallable));
            }
            waitForSshCommandCompletion(futureList);
        } finally {
            executor.shutdown();
        }
    }

    /** Creates a runner for the host and runs the command, or the command file if no command is given. */
    private void executeCommandOrCommandFile(final String host, final String command, final File commandFile, OutputStream outputStream) throws IOException {
        JschRunner jschRunner = createJschRunner(host);
        if (command != null) {
            LOG.info(String.format("executing command '%s' on '%s'", command, host));
            jschRunner.run(new SshExecCommand(command, outputStream));
        } else {
            LOG.info(String.format("executing command-file '%s' on '%s'", commandFile.getAbsolutePath(), host));
            jschRunner.run(new SshExecCommand(commandFile, outputStream));
        }
    }

    /**
     * Waits for the futures in list order and closes each successfully completed task.
     * <p>
     * If the waiting thread is interrupted, the future being waited on and all remaining futures
     * are cancelled and the interrupt status of the thread is restored before returning. The
     * first task failure is rethrown: as is if its cause is an {@link IOException}, otherwise
     * via {@link Throwables#propagate(Throwable)}.
     */
    private static void waitForSshCommandCompletion(List<Future<SshCallable>> futureList) throws IOException {
        boolean interrupted = false;
        try {
            for (Future<SshCallable> future : futureList) {
                try {
                    if (interrupted) {
                        future.cancel(true);
                    } else {
                        SshCallable sshTask = null;
                        try {
                            sshTask = future.get();
                        } finally {
                            IoUtil.closeQuietly(sshTask);
                        }
                    }
                } catch (InterruptedException ex) {
                    interrupted = true;
                    // the task we were waiting for must not be left running either
                    future.cancel(true);
                } catch (ExecutionException ex) {
                    Throwables.propagateIfInstanceOf(ex.getCause(), IOException.class);
                    Throwables.propagate(ex.getCause());
                }
            }
        } finally {
            if (interrupted) {
                // hand the interruption on to the caller instead of swallowing it
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Uploads the local file to the target path on all hosts of this client.
     *
     * @throws IOException
     *             if the upload fails on one of the hosts
     */
    @Override
    public void uploadFile(File localFile, String targetPath) throws IOException {
        uploadFile(_hostnames, localFile, targetPath);
    }

    /**
     * Uploads the local file to the target path on a subset of the hosts of this client.
     *
     * @param instanceIndex
     *            indexes into the host list of this client
     * @throws IOException
     *             if the upload fails on one of the hosts
     */
    @Override
    public void uploadFile(File localFile, String targetPath, int[] instanceIndex) throws IOException {
        List<String> hostnames = getHosts(instanceIndex);
        uploadFile(hostnames, localFile, targetPath);
    }

    private void uploadFile(List<String> hostnames, final File localFile, final String targetPath) throws IOException {
        List<SshCallable> callables = Lists.newArrayList();
        for (final String host : hostnames) {
            callables.add(new SshCallable() {
                @Override
                protected void execute() throws IOException {
                    LOG.info(String.format("uploading file '%s' to '%s'", localFile.getAbsolutePath(), constructRemotePath(host, targetPath)));
                    JschRunner jschRunner = createJschRunner(host);
                    jschRunner.run(new ScpUploadCommand(localFile, targetPath));
                }
            });
        }
        executeCallables(callables);
    }

    /**
     * Downloads the remote file from all hosts of this client, one host after another. Every
     * host's download is directed to the same {@code localPath}.
     *
     * @throws IOException
     *             if a download fails
     */
    @Override
    public void downloadFile(String remoteFile, File localPath, boolean recursiv) throws IOException {
        downloadFiles(_hostnames, remoteFile, localPath, recursiv);
    }

    /**
     * Downloads the remote file from a subset of the hosts of this client, one host after
     * another. Every host's download is directed to the same {@code localPath}.
     *
     * @param instanceIndex
     *            indexes into the host list of this client
     * @throws IOException
     *             if a download fails
     */
    @Override
    public void downloadFile(String remoteFile, File localPath, boolean recursiv, int[] instanceIndex) throws IOException {
        List<String> hosts = getHosts(instanceIndex);
        downloadFiles(hosts, remoteFile, localPath, recursiv);
    }

    private void downloadFiles(List<String> hostnames, String remoteFile, File localPath, boolean recursiv) throws IOException {
        for (String host : hostnames) {
            LOG.info(String.format("downloading file '%s' to '%s'", constructRemotePath(host, remoteFile), localPath.getAbsolutePath()));
            JschRunner jschRunner = createJschRunner(host);
            jschRunner.run(new ScpDownloadCommand(remoteFile, localPath, recursiv));
        }
    }

    /** Builds the remote location string used in log messages only. */
    private String constructRemotePath(String host, String filePath) {
        return _username + ":" + "@" + host + ":" + filePath;
    }

    /**
     * Creates a runner for the given host, configured with the key file if one was given and with
     * the password otherwise, with trust enabled and the current connect-retry setting.
     */
    protected JschRunner createJschRunner(String host) {
        JschRunner runner = new JschRunner(_username, host);
        if (_privateKey != null) {
            runner.setKeyfile(_privateKey);
        } else {
            runner.setPassword(_password);
        }
        runner.setTrust(true);
        runner.setEnableConnectionRetries(_enableConnectRetries);
        return runner;
    }

    /**
     * Resolves the given indexes against the host list of this client.
     *
     * @param instanceIndex
     *            indexes into the host list
     * @return the host names in the order of the given indexes
     * @throws IndexOutOfBoundsException
     *             if an index does not exist in the host list
     */
    protected List<String> getHosts(int[] instanceIndex) {
        // sized by the number of requested hosts, not by the size of the complete host list
        List<String> hostnames = new ArrayList<String>(instanceIndex.length);
        for (int index : instanceIndex) {
            hostnames.add(_hostnames.get(index));
        }
        return hostnames;
    }

    /**
     * A task that returns itself from {@link #call()} so that the waiting thread can invoke
     * {@link #close()} on it after the task completed.
     */
    private static abstract class SshCallable implements Callable<SshCallable>, Closeable {

        @Override
        public final SshCallable call() throws Exception {
            execute();
            return this;
        }

        /** The actual work of the task. */
        protected abstract void execute() throws IOException;

        @Override
        public void close() {
            // subclasses may override
        }
    }
}