package core.aws.task.linux;
import com.amazonaws.services.ec2.model.DescribeTagsRequest;
import com.amazonaws.services.ec2.model.Filter;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.ec2.model.Tag;
import com.amazonaws.services.ec2.model.TagDescription;
import core.aws.client.AWS;
import core.aws.env.Environment;
import core.aws.env.Param;
import core.aws.resource.ec2.KeyPair;
import core.aws.task.ec2.EC2TagHelper;
import core.aws.util.Asserts;
import core.aws.util.Lists;
import core.aws.util.Randoms;
import core.aws.util.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
/**
 * Opens an interactive ssh session (as user {@code ubuntu}) to a running EC2 instance
 * that belongs to a given resource id.
 * <p>
 * When a tunnel resource id is supplied, a second ssh process is started first that
 * forwards a local port to port 22 of the target's private address through the first
 * running instance of the tunnel resource; the interactive session then connects to
 * {@code localhost} on that forwarded port.
 *
 * @author neo
 */
public class SSHRunner {
    private final Logger logger = LoggerFactory.getLogger(SSHRunner.class);
    private final Environment env;
    private final String resourceId;
    private final Integer instanceIndex;
    private final String tunnelResourceId;

    /**
     * @param env              environment used to resolve tags, key files and the auto scaling group name prefix
     * @param resourceId       resource id of the instance(s) to connect to
     * @param instanceIndex    index into the list of running instances; may be {@code null} when only one is running
     * @param tunnelResourceId resource id of the instance to tunnel through; {@code null} for a direct connection
     */
    public SSHRunner(Environment env, String resourceId, Integer instanceIndex, String tunnelResourceId) {
        this.env = env;
        this.resourceId = resourceId;
        this.instanceIndex = instanceIndex;
        this.tunnelResourceId = tunnelResourceId;
    }

    /**
     * Resolves the target instance, optionally starts the tunnel, and runs the ssh session
     * until the ssh process exits.
     *
     * @throws IOException          if the ssh process can not be started
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void run() throws IOException, InterruptedException {
        List<Instance> instances = runningInstances(resourceId);
        Instance instance = locateInstanceToSSH(instances);

        Integer tunnelPort = null;
        if (tunnelResourceId != null) {
            tunnelPort = startTunnelSSH(instance);
        }

        ssh(instance, tunnelPort);
    }

    /**
     * Starts the port-forwarding ssh process on a daemon thread and blocks until that
     * process produced its first output (or closed its output).
     *
     * @param instance target instance whose private address, port 22, is forwarded
     * @return the local port the tunnel is bound to
     * @throws IllegalStateException if the tunnel thread ended before the tunnel process was up
     * @throws InterruptedException  if the calling thread is interrupted while waiting
     */
    private Integer startTunnelSSH(Instance instance) throws InterruptedException {
        Instance tunnelInstance = runningInstances(tunnelResourceId).get(0);
        // NOTE(review): presumably Randoms.number returns a value within [3000, 10000]; the port is not checked for being free
        Integer localPort = Double.valueOf(Randoms.number(3000, 10000)).intValue();

        CountDownLatch latch = new CountDownLatch(1);
        // written by the tunnel thread before countDown(), read here after await(); the latch orders both accesses
        boolean[] tunnelReady = new boolean[1];

        Thread tunnelThread = new Thread(() -> {
            Process process = null;
            try {
                Path keyPath = KeyPair.keyFile(tunnelInstance.getKeyName(), env);
                String userAndHost = "ubuntu@" + hostName(tunnelInstance);
                String portBinding = Strings.format("{}:{}:22", localPort, instance.getPrivateIpAddress());
                List<String> command = tunnelCommand(keyPath, userAndHost, portBinding);
                logger.info("tunnel command => {}", String.join(" ", command));

                process = new ProcessBuilder().command(command).start();
                process.getInputStream().read(); // wait until there is output
                tunnelReady[0] = true;
                latch.countDown();
                process.waitFor();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (IOException e) {
                throw new IllegalStateException(e);
            } finally {
                // always release the waiting caller, otherwise a failed start would block run() forever
                latch.countDown();
                if (process != null) process.destroy();
            }
        });
        tunnelThread.setDaemon(true);
        tunnelThread.start();

        latch.await();
        if (!tunnelReady[0]) {
            throw new IllegalStateException("failed to start ssh tunnel, tunnelResourceId=" + tunnelResourceId);
        }
        return localPort;
    }

    /**
     * Finds the running instances of a resource: first by the resource-id tag, and if no
     * instance carries the tag, via the auto scaling group named {@code <env.name>-<resourceId>}.
     *
     * @param resourceId resource id to look up (target or tunnel)
     * @return non-empty list of instances whose state name is {@code running}
     * @throws Error if neither instances nor an auto scaling group are found, or none is running
     */
    private List<Instance> runningInstances(String resourceId) {
        Tag tag = new EC2TagHelper(env).resourceId(resourceId);
        DescribeTagsRequest request = new DescribeTagsRequest()
            .withFilters(new Filter("key").withValues(tag.getKey()),
                new Filter("value").withValues(tag.getValue()),
                new Filter("resource-type").withValues("instance"));
        List<TagDescription> remoteTags = AWS.ec2.describeTags(request);
        List<String> instanceIds = remoteTags.stream().map(TagDescription::getResourceId).collect(Collectors.toList());

        if (instanceIds.isEmpty()) {
            // use the method parameter, not the field: this lookup also serves the tunnel resource
            com.amazonaws.services.autoscaling.model.AutoScalingGroup asGroup = AWS.as.describeASGroup(env.name + "-" + resourceId);
            if (asGroup == null) throw new Error("can not find any running instance or asGroup, id=" + resourceId);

            instanceIds = asGroup.getInstances().stream()
                .map(com.amazonaws.services.autoscaling.model.Instance::getInstanceId)
                .collect(Collectors.toList());
        }

        logger.info("find instanceId, {} => {}", resourceId, instanceIds);

        // an auto scaling group may currently hold no instance; do not query EC2 with an empty id list
        // NOTE(review): the raw EC2 API treats an empty id list as "all instances" - confirm how AWS.ec2.describeInstances handles it
        if (instanceIds.isEmpty()) throw new Error("can not find any running instance, id=" + resourceId);

        List<Instance> instances = AWS.ec2.describeInstances(instanceIds)
            .stream().filter(instance -> "running".equals(instance.getState().getName())).collect(Collectors.toList());
        if (instances.isEmpty()) throw new Error("can not find any running instance, id=" + resourceId);
        return instances;
    }

    /**
     * Logs all candidate instances with their index and picks the one to connect to.
     *
     * @param instances running instances of the target resource
     * @return the only instance, or the one at {@code instanceIndex} when there are several
     */
    private Instance locateInstanceToSSH(List<Instance> instances) {
        for (int i = 0; i < instances.size(); i++) {
            Instance remoteInstance = instances.get(i);
            logger.info("index={}, instanceId={}, state={}, publicDNS={}, privateDNS={}",
                i,
                remoteInstance.getInstanceId(),
                remoteInstance.getState().getName(),
                remoteInstance.getPublicDnsName(),
                remoteInstance.getPrivateDnsName());
        }

        // with a single instance the index is not needed and is ignored
        if (instances.size() == 1) return instances.get(0);

        Asserts.isTrue(instanceIndex != null, "more than one remoteInstance, use --{} to specify index", Param.INSTANCE_INDEX.key);
        // report a bad index with a usable message instead of a raw IndexOutOfBoundsException
        Asserts.isTrue(instanceIndex >= 0 && instanceIndex < instances.size(),
            "instance index " + instanceIndex + " is out of range, instances=" + instances.size() + ", use --{} to specify index",
            Param.INSTANCE_INDEX.key);
        return instances.get(instanceIndex);
    }

    /**
     * Runs the interactive ssh process with inherited stdin/stdout/stderr and waits for it to exit.
     *
     * @param instance   target instance
     * @param tunnelPort local forwarded port, or {@code null} to connect to the instance host name directly
     * @throws IOException          if the ssh process can not be started
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private void ssh(Instance instance, Integer tunnelPort) throws IOException, InterruptedException {
        Path keyPath = KeyPair.keyFile(instance.getKeyName(), env);
        String userAndHost = tunnelPort != null ? "ubuntu@localhost" : "ubuntu@" + hostName(instance);

        List<String> command = command(keyPath, userAndHost, tunnelPort);
        logger.info("command => {}", String.join(" ", command));

        Process process = new ProcessBuilder().inheritIO().command(command).start();
        process.waitFor();
        logger.info("session disconnected");
    }

    /**
     * Builds the interactive ssh command line.
     *
     * @param keyPath     private key file passed via {@code -i}
     * @param userAndHost {@code user@host} destination
     * @param tunnelPort  port passed via {@code -p}, or {@code null} for the default port
     */
    private List<String> command(Path keyPath, String userAndHost, Integer tunnelPort) {
        List<String> command = newCommand();
        // keep-alive every 30 seconds, and accept unknown host keys without prompting
        command.addAll(Lists.newArrayList("ssh", "-o", "ServerAliveInterval=30", "-o", "StrictHostKeyChecking=no", "-i", keyPath.toString()));
        if (tunnelPort != null) {
            command.add("-p");
            command.add(String.valueOf(tunnelPort));
        }
        command.add(userAndHost);
        return command;
    }

    /**
     * Builds the port-forwarding ssh command line.
     *
     * @param keyPath     private key file passed via {@code -i}
     * @param userAndHost {@code user@host} of the tunnel instance
     * @param portBinding value for {@code -L}, in {@code localPort:targetHost:targetPort} form
     */
    private List<String> tunnelCommand(Path keyPath, String userAndHost, String portBinding) {
        List<String> command = newCommand();
        command.addAll(Lists.newArrayList("ssh", "-o", "ServerAliveInterval=30", "-o", "StrictHostKeyChecking=no", "-i", keyPath.toString(), "-L", portBinding, userAndHost));
        return command;
    }

    /**
     * Creates a new command list, prefixed with {@code cmd /C start} when the
     * {@code os.name} system property contains "win".
     */
    private List<String> newCommand() {
        List<String> command = Lists.newArrayList();
        // Locale.ROOT keeps the match independent of the user's locale casing rules
        if (System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("win")) {
            command.add("cmd");
            command.add("/C");
            command.add("start");
        }
        return command;
    }

    /**
     * Returns the public DNS name of the instance, falling back to the private DNS name
     * when the public one is {@code null} or empty.
     */
    private String hostName(Instance remoteInstance) {
        String publicDNS = remoteInstance.getPublicDnsName();
        if (publicDNS == null || publicDNS.isEmpty()) return remoteInstance.getPrivateDnsName();
        return publicDNS;
    }
}