package com.sequenceiq.cloudbreak.service.stack.flow; import static org.springframework.ui.freemarker.FreeMarkerTemplateUtils.processTemplateIntoString; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import javax.inject.Inject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import com.google.common.io.BaseEncoding; import com.sequenceiq.cloudbreak.core.CloudbreakException; import com.sequenceiq.cloudbreak.core.CloudbreakSecuritySetupException; import com.sequenceiq.cloudbreak.core.bootstrap.service.OrchestratorType; import com.sequenceiq.cloudbreak.core.bootstrap.service.OrchestratorTypeResolver; import com.sequenceiq.cloudbreak.domain.Credential; import com.sequenceiq.cloudbreak.domain.InstanceMetaData; import com.sequenceiq.cloudbreak.domain.Orchestrator; import com.sequenceiq.cloudbreak.domain.Stack; import com.sequenceiq.cloudbreak.repository.InstanceMetaDataRepository; import com.sequenceiq.cloudbreak.repository.SecurityConfigRepository; import com.sequenceiq.cloudbreak.repository.StackRepository; import com.sequenceiq.cloudbreak.service.GatewayConfigService; import com.sequenceiq.cloudbreak.service.PollingService; import com.sequenceiq.cloudbreak.service.TlsSecurityService; import com.sequenceiq.cloudbreak.service.stack.connector.adapter.ServiceProviderConnectorAdapter; import com.sequenceiq.cloudbreak.util.FileReaderUtils; import freemarker.template.Configuration; import freemarker.template.TemplateException; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.connection.channel.direct.Session; import net.schmizz.sshj.transport.verification.HostKeyVerifier; import net.schmizz.sshj.xfer.InMemorySourceFile; @Component public class TlsSetupService { private static final Logger LOGGER = LoggerFactory.getLogger(TlsSetupService.class); private static final int SETUP_TIMEOUT = 180; private static final int SSH_POLLING_INTERVAL = 5000; private static final int SSH_MAX_ATTEMPTS_FOR_HOSTS = 100; @Inject private ServiceProviderConnectorAdapter connector; @Inject private SecurityConfigRepository securityConfigRepository; @Inject private TlsSecurityService tlsSecurityService; @Inject private StackRepository stackRepository; @Inject private PollingService<SshCheckerTaskContext> sshCheckerTaskContextPollingService; @Inject private SshCheckerTask sshCheckerTask; @Inject private Configuration freemarkerConfiguration; @Inject private OrchestratorTypeResolver orchestratorTypeResolver; @Inject private GatewayConfigService gatewayConfigService; @Inject private InstanceMetaDataRepository instanceMetaDataRepository; @Value("#{'${cb.cert.dir:}/${cb.tls.cert.file:}'}") private String tlsCertificatePath; public void setupTls(Stack stack, InstanceMetaData gwInstance, String user, Set<String> sshFingerprints) throws CloudbreakException { String publicIp = gatewayConfigService.getGatewayIp(stack, gwInstance); int sshPort = gwInstance.getSshPort(); LOGGER.info("SSHClient parameters: stackId: {}, publicIp: {}, user: {}", stack.getId(), publicIp, user); if (publicIp == null) { throw new CloudbreakException("Failed to connect to host, IP address not defined."); } SSHClient ssh = new SSHClient(); Orchestrator orchestrator = stack.getOrchestrator(); HostKeyVerifier hostKeyVerifier = new VerboseHostKeyVerifier(sshFingerprints); try { waitForSsh(stack, publicIp, sshPort, hostKeyVerifier, user); String privateKeyLocation = tlsSecurityService.getSshPrivateFileLocation(stack.getId()); setupTemporarySsh(ssh, publicIp, sshPort, hostKeyVerifier, user, privateKeyLocation, stack.getCredential()); uploadTlsSetupScript(orchestrator, ssh, publicIp, stack.getGatewayPort(), stack.getCredential()); executeTlsSetupScript(ssh); downloadAndSavePrivateKey(stack, ssh, gwInstance); } catch (IOException e) { throw new CloudbreakException("Failed to setup TLS through temporary SSH.", e); } catch (TemplateException e) { throw new CloudbreakException("Failed to generate TLS setup script.", e); } finally { try { ssh.disconnect(); } catch (IOException e) { throw new CloudbreakException("Couldn't disconnect temp SSH session", e); } } } public void removeTemporarySShKey(Stack stack, String publicIp, int sshPort, String user, Set<String> sshFingerprints) throws CloudbreakException { SSHClient ssh = new SSHClient(); try { String privateKeyLocation = tlsSecurityService.getSshPrivateFileLocation(stack.getId()); HostKeyVerifier hostKeyVerifier = new VerboseHostKeyVerifier(sshFingerprints); prepareSshConnection(ssh, publicIp, sshPort, hostKeyVerifier, user, privateKeyLocation, stack.getCredential()); removeTemporarySShKey(ssh, user, stack.getCredential()); } catch (IOException e) { LOGGER.info("Unable to delete temporary SSH key for stack {}", stack.getId()); } finally { try { ssh.disconnect(); } catch (IOException e) { throw new CloudbreakException("Couldn't disconnect temp SSH session", e); } } } private void waitForSsh(Stack stack, String publicIp, int sshPort, HostKeyVerifier hostKeyVerifier, String user) throws CloudbreakSecuritySetupException { sshCheckerTaskContextPollingService.pollWithTimeoutSingleFailure( sshCheckerTask, new SshCheckerTaskContext(stack, hostKeyVerifier, publicIp, sshPort, user, tlsSecurityService.getSshPrivateFileLocation(stack.getId())), SSH_POLLING_INTERVAL, SSH_MAX_ATTEMPTS_FOR_HOSTS); } private void setupTemporarySsh(SSHClient ssh, String ip, int port, HostKeyVerifier hostKeyVerifier, String user, String privateKeyLocation, Credential credential) throws IOException { LOGGER.info("Setting up temporary ssh..."); prepareSshConnection(ssh, ip, port, hostKeyVerifier, user, privateKeyLocation, credential); String remoteTlsCertificatePath = "/tmp/cb-client.pem"; ssh.newSCPFileTransfer().upload(tlsCertificatePath, remoteTlsCertificatePath); LOGGER.info("Temporary ssh setup finished succesfully, public key is uploaded to {}", remoteTlsCertificatePath); } private void prepareSshConnection(SSHClient ssh, String ip, int port, HostKeyVerifier hostKeyVerifier, String user, String privateKeyLocation, Credential credential) throws IOException { ssh.addHostKeyVerifier(hostKeyVerifier); ssh.connect(ip, port); if (credential.passwordAuthenticationRequired()) { ssh.authPassword(user, credential.getLoginPassword()); } else { ssh.authPublickey(user, privateKeyLocation); } } private void uploadTlsSetupScript(Orchestrator orchestrator, SSHClient ssh, String publicIp, Integer sslPort, Credential credential) throws IOException, TemplateException, CloudbreakException { LOGGER.info("Uploading tls-setup.sh to the gateway..."); Map<String, Object> model = new HashMap<>(); model.put("publicIp", publicIp); model.put("username", credential.getLoginUserName()); model.put("sudopre", credential.passwordAuthenticationRequired() ? String.format("echo '%s'|", credential.getLoginPassword()) : ""); model.put("sudocheck", credential.passwordAuthenticationRequired() ? "-S" : ""); model.put("sslPort", sslPort.toString()); OrchestratorType type = orchestratorTypeResolver.resolveType(orchestrator.getType()); String tls = processTemplateIntoString( freemarkerConfiguration.getTemplate(String.format("init/%s/tls-setup.sh", type.name().toLowerCase()), "UTF-8"), model); InMemorySourceFile tlsFile = uploadParameterFile(tls, "tls-setup.sh"); ssh.newSCPFileTransfer().upload(tlsFile, "/tmp/tls-setup.sh"); LOGGER.info("tls-setup.sh uploaded to /tmp/tls-setup.sh. Content: {}", tls); if (type.hostOrchestrator()) { String nginxConf = FileReaderUtils.readFileFromClasspath("init/host/ssl.conf"); InMemorySourceFile nginxConfFile = uploadParameterFile(nginxConf, "ssl.conf"); ssh.newSCPFileTransfer().upload(nginxConfFile, "/tmp/ssl.conf"); LOGGER.info("nginx conf uploaded to /tmp/ssl.conf. Content: {}", nginxConf); } } private InMemorySourceFile uploadParameterFile(String generatedTemplate, final String name) { final byte[] tlsScriptBytes = generatedTemplate.getBytes(StandardCharsets.UTF_8); return new InMemorySourceFile() { @Override public String getName() { return name; } @Override public long getLength() { return tlsScriptBytes.length; } @Override public InputStream getInputStream() throws IOException { return new ByteArrayInputStream(tlsScriptBytes); } }; } private void executeTlsSetupScript(SSHClient ssh) throws IOException, CloudbreakException { LOGGER.info("Executing tls-setup.sh on the gateway..."); int exitStatus = executeSshCommand(ssh, "bash /tmp/tls-setup.sh", true, "tls-setup"); LOGGER.info("tls-setup.sh finished with {} exitcode.", exitStatus); if (exitStatus != 0) { throw new CloudbreakException(String.format("TLS setup script exited with error code: %s", exitStatus)); } } private void removeTemporarySShKey(SSHClient ssh, String user, Credential credential) throws IOException, CloudbreakException { if (!credential.passwordAuthenticationRequired()) { LOGGER.info("Removing temporary sshkey from the gateway..."); String removeCommand = String.format("sudo sed -i '/#tmpssh_start/,/#tmpssh_end/{s/./ /g}' /home/%s/.ssh/authorized_keys", user); int exitStatus = executeSshCommand(ssh, removeCommand, false, ""); LOGGER.info("Temporary sshkey removed from the gateway, exitcode: {}", exitStatus); if (exitStatus != 0) { throw new CloudbreakException(String.format("Failed to remove temp SSH key. Error code: %s", exitStatus)); } } } private void downloadAndSavePrivateKey(Stack stack, SSHClient ssh, InstanceMetaData gwInstance) throws IOException, CloudbreakSecuritySetupException { long stackId = stack.getId(); String serverCertDir = tlsSecurityService.createServerCertDir(stackId, gwInstance); LOGGER.info("Server cert directory is created at: " + serverCertDir); ssh.newSCPFileTransfer().download("/tmp/server.pem", serverCertDir + "/ca.pem"); InstanceMetaData metaData = instanceMetaDataRepository.findOne(gwInstance.getId()); metaData.setServerCert(BaseEncoding.base64().encode(tlsSecurityService.readServerCert(stackId, gwInstance).getBytes())); instanceMetaDataRepository.save(metaData); } private Session startSshSession(SSHClient ssh) throws IOException { Session sshSession = ssh.startSession(); sshSession.allocateDefaultPTY(); return sshSession; } private int executeSshCommand(SSHClient ssh, String command, boolean logOutput, String logPrefix) throws IOException { Session session = startSshSession(ssh); Session.Command cmd = session.exec(command); if (logOutput) { logStdOutAndStdErr(cmd, logPrefix); } cmd.join(SETUP_TIMEOUT, TimeUnit.SECONDS); session.close(); return cmd.getExitStatus(); } private void logStdOutAndStdErr(Session.Command command, String commandDesc) throws IOException { LOGGER.info("Standard output of {} command", commandDesc); LOGGER.info(IOUtils.readFully(command.getInputStream()).toString()); LOGGER.info("Standard error of {} command", commandDesc); LOGGER.info(IOUtils.readFully(command.getErrorStream()).toString()); } }