package org.ovirt.engine.core.uutils.ssh;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Pattern;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Helpers for converting RSA public keys to and from the OpenSSH formats, computing key fingerprints and
 * performing a basic syntactic validation of <code>authorized_keys</code> style lines.
 *
 * Only the <code>ssh-rsa</code> key type is supported.
 */
public final class OpenSSHUtils {

    // The log:
    private static final Logger log = LoggerFactory.getLogger(OpenSSHUtils.class);

    // Names of supported algorithms:
    private static final String SSH_RSA = "ssh-rsa";
    private static final String MD5 = "MD5";

    // Digest used by getKeyFingerprint when the caller does not name one:
    private static final String DEFAULT_DIGEST = "SHA-256";

    // Key type token of a public key line: one or more US-ASCII letters, digits or dashes. As per the
    // java.util.regex.Pattern documentation these character classes are US-ASCII only.
    private static final Pattern KEY_TYPE_PATTERN = Pattern.compile("[\\p{Alpha}\\p{Digit}-]+");

    private OpenSSHUtils() {
        // No instances allowed.
    }

    /**
     * Reads one length-prefixed field: a 4 byte big-endian length followed by that many bytes.
     *
     * @param dataInputStream
     *            the stream to read from (always backed by an in-memory buffer in this class)
     * @return the bytes of the field
     * @throws IOException
     *             if the length is negative, exceeds the remaining data, or the stream ends early
     */
    private static byte[] getByteArrayOfData(DataInputStream dataInputStream) throws IOException {
        final int length = dataInputStream.readInt();
        // The stream wraps a ByteArrayInputStream, so available() is the exact number of remaining bytes.
        // Checking it first rejects negative or oversized lengths coming from untrusted input before any
        // allocation happens.
        if (length < 0 || length > dataInputStream.available()) {
            throw new IOException("Invalid ASN1 array");
        }
        final byte[] contents = new byte[length];
        // Unlike read(), readFully() never returns a partial result.
        dataInputStream.readFully(contents);
        return contents;
    }

    /**
     * Convert a public key string to real public key.
     *
     * @param key
     *            a line of the form <code>ssh-rsa BASE64 [comment]</code>
     * @return the decoded RSA public key
     * @throws IOException
     *             if the base64 payload is truncated or malformed
     * @throws GeneralSecurityException
     *             if the key type is not <code>ssh-rsa</code> or the key material is invalid
     */
    public static PublicKey decodeKeyString(final String key) throws IOException, GeneralSecurityException {
        final String[] words = key.split("\\s+", 3);
        if (words.length < 2 || !SSH_RSA.equals(words[0])) {
            throw new GeneralSecurityException("Unsupported SSH public key");
        }
        try (
            ByteArrayInputStream inputStream = new ByteArrayInputStream(Base64.decodeBase64(words[1]));
            DataInputStream dataInputStream = new DataInputStream(inputStream)) {
            // The payload repeats the key type, then holds the exponent and the modulus:
            if (!Arrays.equals(getByteArrayOfData(dataInputStream), SSH_RSA.getBytes(StandardCharsets.UTF_8))) {
                throw new GeneralSecurityException("Unsupported SSH public key");
            }
            final byte[] exponentBytes = getByteArrayOfData(dataInputStream);
            final byte[] modulusBytes = getByteArrayOfData(dataInputStream);
            // new BigInteger(byte[]) throws an unchecked NumberFormatException for empty arrays, so report
            // empty fields as an invalid key instead:
            if (exponentBytes.length == 0 || modulusBytes.length == 0) {
                throw new GeneralSecurityException("Invalid SSH RSA public key");
            }
            return KeyFactory.getInstance("RSA").generatePublic(
                    new RSAPublicKeySpec(
                            new BigInteger(modulusBytes),
                            new BigInteger(exponentBytes)));
        }
    }

    /**
     * Convert a public key to the SSH format.
     *
     * Note that only RSA keys are supported at the moment.
     *
     * @param key
     *            the public key to convert
     * @return an array of bytes with the binary representation of the public key, or <code>null</code> if the
     *         key is <code>null</code>, is not an RSA key, or can't be serialized
     */
    public static byte[] getKeyBytes(final PublicKey key) {
        // We only support RSA at the moment (instanceof is also false for null keys):
        if (!(key instanceof RSAPublicKey)) {
            log.error("The key algorithm '{}' is not supported, will return null.",
                    key != null ? key.getAlgorithm() : null);
            return null;
        }

        // Extract the bytes of the exponent and the modulus of the key:
        final RSAPublicKey rsaKey = (RSAPublicKey) key;
        final byte[] exponentBytes = rsaKey.getPublicExponent().toByteArray();
        final byte[] modulusBytes = rsaKey.getModulus().toByteArray();
        if (log.isDebugEnabled()) {
            log.debug("Exponent is {} ({}).", rsaKey.getPublicExponent(), Hex.encodeHexString(exponentBytes));
            log.debug("Modulus is {} ({}).", rsaKey.getModulus(), Hex.encodeHexString(modulusBytes));
        }

        try (ByteArrayOutputStream binaryOut = new ByteArrayOutputStream();
                DataOutputStream dataOut = new DataOutputStream(binaryOut)) {
            // Write the SSH header (4 bytes for the length of the algorithm name and then the algorithm name):
            dataOut.writeInt(SSH_RSA.length());
            dataOut.writeBytes(SSH_RSA);

            // Write the exponent and modulus bytes. BigInteger.toByteArray() produces a two's complement
            // encoding that already includes a leading zero byte when the most significant bit would be set,
            // and RSA exponents and moduli are never negative, so no extra padding is needed:
            dataOut.writeInt(exponentBytes.length);
            dataOut.write(exponentBytes);
            dataOut.writeInt(modulusBytes.length);
            dataOut.write(modulusBytes);
            dataOut.flush();

            // Done, extract the bytes:
            final byte[] keyBytes = binaryOut.toByteArray();
            if (log.isDebugEnabled()) {
                log.debug("Key bytes are {}.", Hex.encodeHexString(keyBytes));
            }
            return keyBytes;
        } catch (IOException exception) {
            log.error("Error while serializing public key, will return null.", exception);
            return null;
        }
    }

    /**
     * Convert a public key to the SSH format used in the <code>authorized_keys</code> files.
     *
     * Note that only RSA keys are supported at the moment.
     *
     * @param key
     *            the public key to convert
     * @param alias
     *            the alias to be appended at the end of the line, if it is <code>null</code> nothing will be appended
     * @return a newline-terminated string that can be directly written to the <code>authorized_keys</code> file,
     *         or <code>null</code> if the conversion can't be performed for whatever reason
     */
    public static String getKeyString(final PublicKey key, String alias) {
        // Get the serialized version of the key:
        final byte[] keyBytes = getKeyBytes(key);
        if (keyBytes == null) {
            log.error("Can't get key bytes, will return null.");
            return null;
        }

        // Encode it using BASE64 (line length 0 means no line breaks):
        final String encoding = new Base64(0).encodeToString(keyBytes);
        if (log.isDebugEnabled()) {
            log.debug("Key encoding is '{}'.", encoding);
        }

        // Return the generated SSH public key:
        final StringBuilder buffer = new StringBuilder(
                SSH_RSA.length() + 1 + encoding.length() + (alias != null ? 1 + alias.length() : 0) + 1);
        buffer.append(SSH_RSA).append(' ').append(encoding);
        if (alias != null) {
            buffer.append(' ').append(alias);
        }
        buffer.append('\n');
        final String keyString = buffer.toString();
        if (log.isDebugEnabled()) {
            log.debug("Key string is '{}'.", keyString);
        }
        return keyString;
    }

    /**
     * Compares an expected fingerprint with the fingerprint of a public key.
     *
     * The expected fingerprint may be in OpenSSH form (<code>SHA256:base64</code>, <code>MD5:aa:bb:...</code>) or
     * in the legacy MD5 form without prefix (<code>aa:bb:...</code>). MD5 fingerprints are compared ignoring case,
     * others exactly.
     *
     * @param expected
     *            the expected fingerprint
     * @param key
     *            the public key to check
     * @param actual
     *            if not <code>null</code>, its content is replaced with the computed fingerprint
     * @return <code>true</code> if the fingerprints match
     */
    public static final boolean checkKeyFingerprint(String expected, final PublicKey key, StringBuilder actual)
            throws Exception {
        String digest = expected.split(":", 2)[0];

        // Legacy MD5 fingerprints carry no digest prefix; recognize them by a hex first octet:
        if (digest.length() == 2) {
            try {
                Integer.parseInt(digest, 16);
                digest = MD5;
                expected = digest + ":" + expected;
            } catch (NumberFormatException ignored) {
                // Not a hex octet, so treat it as a digest name.
            }
        }

        // OpenSSH writes digest names without a dash ("SHA256"), the JCA expects one ("SHA-256"):
        if (!digest.startsWith("MD")) {
            digest = digest.replaceFirst("([0-9])", "-$1");
        }

        final String fingerprint = getKeyFingerprint(key, digest);
        final boolean result = MD5.equals(digest)
                ? expected.equalsIgnoreCase(fingerprint)
                : expected.equals(fingerprint);

        if (actual != null) {
            actual.setLength(0);
            actual.append(fingerprint);
        }
        return result;
    }

    /**
     * Computes the fingerprint of a public key.
     *
     * @param key
     *            the public key, currently it must be an RSA key
     * @param digest
     *            the JCA digest name, for example <code>SHA-256</code> or <code>MD5</code>; <code>null</code> means
     *            <code>SHA-256</code>
     * @return <code>MD5:xx:xx:...</code> (lower case hex) for MD5, otherwise the digest name without dashes
     *         followed by a colon and the unpadded base64 digest, for example <code>SHA256:...</code>
     * @throws IllegalArgumentException
     *             if the key is <code>null</code> or not supported
     * @throws RuntimeException
     *             wrapping the security exception if the digest algorithm is not available
     */
    public static final String getKeyFingerprint(final PublicKey key, String digest) {
        if (digest == null) {
            digest = DEFAULT_DIGEST;
        }
        final byte[] keyBytes = getKeyBytes(key);
        if (keyBytes == null) {
            throw new IllegalArgumentException("Can't compute fingerprint of an unsupported public key");
        }
        try {
            final byte[] hash = MessageDigest.getInstance(digest).digest(keyBytes);
            final String fingerprint;
            if (MD5.equals(digest)) {
                final StringBuilder s = new StringBuilder(MD5.length() + hash.length * 3);
                s.append(MD5);
                for (byte b : hash) {
                    s.append(':');
                    s.append(String.format("%02x", b));
                }
                fingerprint = s.toString();
            } else {
                fingerprint = String.format(
                        "%s:%s",
                        digest.toUpperCase(Locale.ROOT).replace("-", ""),
                        new Base64(0).encodeToString(hash).replace("=", ""));
            }
            if (log.isDebugEnabled()) {
                log.debug("Fingerprint: {}", fingerprint);
            }
            return fingerprint;
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the SHA-256 fingerprint of a public key, in the form <code>SHA256:base64</code>.
     *
     * @param key
     *            the public key, currently it must be an RSA key
     * @return the fingerprint
     */
    public static final String getKeyFingerprint(final PublicKey key) throws Exception {
        return getKeyFingerprint(key, null);
    }

    /*
     * commons-codec <= 1.4 has Base64.isArrayByteBase64, but it is deprecated; commons-codec >= 1.5 has Base64.isBase64
     * which works on byte[], but it treats whitespace as valid. So, we roll out our own version.
     */
    private static boolean isBase64(byte[] octets) {
        for (byte octet : octets) {
            if (!Base64.isBase64(octet)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Performs a syntactic check of a single OpenSSH public key line.
     *
     * An OpenSSH public key consists of a mandatory key type, a mandatory chunk of base64 encoded data and an
     * optional comment. A single trailing newline is accepted; embedded newlines are not.
     *
     * @param publicKey
     *            the line to check
     * @return <code>true</code> if the line looks like a public key; <code>false</code> otherwise, including for
     *         <code>null</code>
     */
    public static boolean isPublicKeyValid(String publicKey) {
        if (publicKey == null) {
            return false;
        }
        final int newline = publicKey.indexOf('\n');
        if (newline != -1 && newline != publicKey.length() - 1) {
            return false;
        }

        // Trimming first means leading whitespace can't shift the words, and a line without key data
        // ("ssh-rsa \n") yields a single word instead of an empty second word:
        final String[] words = publicKey.trim().split("\\s+", 3);
        if (words.length < 2) {
            return false;
        }
        if (!KEY_TYPE_PATTERN.matcher(words[0]).matches()) {
            return false;
        }
        return isBase64(words[1].getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Checks a newline separated list of public keys.
     *
     * @param publicKeys
     *            the keys, one per line
     * @return <code>true</code> if every line is a valid public key
     */
    public static boolean arePublicKeysValid(String publicKeys) {
        for (String publicKey : publicKeys.split("\n")) {
            if (!isPublicKeyValid(publicKey)) {
                return false;
            }
        }
        return true;
    }
}