/*
* Copyright (c) 2012 S.C. Axemblr Software Solutions S.R.L
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.axemblr.provisionr.core;
import com.axemblr.provisionr.api.access.AdminAccess;
import com.axemblr.provisionr.api.pool.Machine;
import com.axemblr.provisionr.core.logging.ErrorStreamLogger;
import com.axemblr.provisionr.core.logging.InfoStreamLogger;
import com.google.common.base.Charsets;
import static com.google.common.base.Preconditions.checkArgument;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.PublicKey;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import net.schmizz.sshj.xfer.InMemorySourceFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;
/**
 * Static helpers for opening SSH connections to pool machines, streaming the
 * output of remote commands to a logger and uploading small files over SCP.
 */
public final class Ssh {

    private static final Logger LOG = LoggerFactory.getLogger(Ssh.class);

    /**
     * Fail if a connection can't be established in this amount of time
     */
    public static final int DEFAULT_CONNECT_TIMEOUT = 30 * 1000; /* milliseconds */

    /**
     * Socket read timeout used by {@link #newClient(Machine, AdminAccess)}
     */
    public static final int DEFAULT_READ_TIMEOUT = 10 * 60 * 1000; /* milliseconds */

    private Ssh() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Accept any host key from the remote machines.
     * <p>
     * NOTE: this disables host key verification, so connections are not protected
     * against man-in-the-middle attacks. The fingerprint is only logged for auditing.
     */
    private enum AcceptAnyHostKeyVerifier implements HostKeyVerifier {
        INSTANCE;

        @Override
        public boolean verify(String hostname, int port, PublicKey key) {
            String fingerprint = SecurityUtils.getFingerprint(key);
            LOG.info("Automatically accepting host key for {}:{} with fingerprint {}",
                new Object[]{hostname, port, fingerprint});
            return true;
        }
    }

    /**
     * Create a new {@code SSHClient} connected and authenticated to the remote
     * machine, using {@link #DEFAULT_READ_TIMEOUT} as the socket read timeout.
     *
     * @param machine     remote machine (public DNS name and SSH port are used)
     * @param adminAccess credentials used for public key authentication
     * @return a connected and authenticated client; the caller must disconnect it
     * @throws IOException if the connection or the authentication fails
     */
    public static SSHClient newClient(Machine machine, AdminAccess adminAccess) throws IOException {
        return newClient(machine, adminAccess, DEFAULT_READ_TIMEOUT);
    }

    /**
     * Create a new {@code SSHClient} connected to the remote machine using the
     * AdminAccess credentials as provided.
     * <p>
     * If connecting or authenticating fails, the partially initialised client is
     * disconnected before the exception propagates, so no socket is leaked.
     *
     * @param machine         remote machine (public DNS name and SSH port are used)
     * @param adminAccess     credentials used for public key authentication
     * @param timeoutInMillis socket read timeout in milliseconds; when 0 neither the
     *                        read timeout nor {@link #DEFAULT_CONNECT_TIMEOUT} is set
     *                        and the client defaults apply
     * @return a connected and authenticated client; the caller must disconnect it
     * @throws IllegalArgumentException if {@code timeoutInMillis} is negative
     * @throws IOException              if the connection or the authentication fails
     */
    public static SSHClient newClient(
        Machine machine, AdminAccess adminAccess, int timeoutInMillis
    ) throws IOException {
        checkArgument(timeoutInMillis >= 0, "timeoutInMillis should be positive or 0");

        final SSHClient client = new SSHClient();
        client.addHostKeyVerifier(AcceptAnyHostKeyVerifier.INSTANCE);

        if (timeoutInMillis != 0) {
            client.setConnectTimeout(DEFAULT_CONNECT_TIMEOUT);
            client.setTimeout(timeoutInMillis);
        }

        boolean ready = false;
        try {
            client.connect(machine.getPublicDnsName(), machine.getSshPort());

            OpenSSHKeyFile key = new OpenSSHKeyFile();
            key.init(adminAccess.getPrivateKey(), adminAccess.getPublicKey());
            client.authPublickey(adminAccess.getUsername(), key);

            ready = true;
            return client;

        } finally {
            if (!ready) {
                // The caller never gets a reference to the client on failure, so it
                // cannot close it; release the socket / transport here instead.
                disconnectQuietly(client);
            }
        }
    }

    /**
     * Disconnect a client while an exception is already propagating, logging
     * (instead of throwing) any secondary failure so the original error is kept.
     */
    private static void disconnectQuietly(SSHClient client) {
        try {
            client.disconnect();
        } catch (IOException e) {
            LOG.warn("Failed to disconnect SSH client after a connection error", e);
        }
    }

    /**
     * Stream command output as log message for easy debugging.
     * <p>
     * Starts one {@code InfoStreamLogger} for stdout and one {@code ErrorStreamLogger}
     * for stderr, both tagged with the marker {@code "ssh-" + instanceId}.
     *
     * @param logger     destination logger
     * @param instanceId identifier used to build the log marker
     * @param command    the running remote command whose streams are consumed
     */
    public static void logCommandOutput(Logger logger, String instanceId, Session.Command command) {
        final Marker marker = MarkerFactory.getMarker("ssh-" + instanceId);

        new InfoStreamLogger(command.getInputStream(), logger, marker).start();
        new ErrorStreamLogger(command.getErrorStream(), logger, marker).start();
    }

    /**
     * Create a remote file on SSH from a string.
     * <p>
     * The content is encoded as UTF-8 and uploaded with SCP.
     *
     * @param client      ssh client instance
     * @param content     content for the new file
     * @param permissions unix permissions
     * @param destination destination path
     * @throws IOException if the SCP upload fails
     */
    public static void createFile(
        SSHClient client, String content, final int permissions, String destination
    ) throws IOException {
        final byte[] bytes = content.getBytes(Charsets.UTF_8);

        client.newSCPFileTransfer().upload(new InMemorySourceFile() {
            @Override
            public String getName() {
                return "in-memory";
            }

            @Override
            public long getLength() {
                return bytes.length;
            }

            @Override
            public int getPermissions() throws IOException {
                return permissions;
            }

            @Override
            public InputStream getInputStream() throws IOException {
                return new ByteArrayInputStream(bytes);
            }
        }, destination);
    }
}