package com.sequenceiq.cloudbreak.service.credential;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.DSAPublicKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.NoSuchElementException;
import java.util.StringTokenizer;
import org.apache.commons.codec.binary.Base64;
public final class PublicKeyReaderUtil {
private static final String BEGIN_PUB_KEY = "---- BEGIN SSH2 PUBLIC KEY ----";
private static final String END_PUB_KEY = "---- END SSH2 PUBLIC KEY ----";
private static final String SSH2_DSA_KEY = "ssh-dsa";
private static final String SSH2_RSA_KEY = "ssh-rsa";
private PublicKeyReaderUtil() {
}
public static PublicKey load(final String key) throws PublicKeyParseException {
final int c = key.charAt(0);
final String base64;
if (c == 's') {
base64 = PublicKeyReaderUtil.extractOpenSSHBase64(key);
} else if (c == '-') {
base64 = PublicKeyReaderUtil.extractSecSHBase64(key);
} else {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.UNKNOWN_PUBLIC_KEY_FILE_FORMAT);
}
final SSH2DataBuffer buf = new SSH2DataBuffer(Base64.decodeBase64(base64.getBytes()));
final String type = buf.readString();
final PublicKey ret;
if (PublicKeyReaderUtil.SSH2_DSA_KEY.equals(type)) {
ret = decodeDSAPublicKey(buf);
} else if (PublicKeyReaderUtil.SSH2_RSA_KEY.equals(type)) {
ret = decodePublicKey(buf);
} else {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.UNKNOWN_PUBLIC_KEY_CERTIFICATE_FORMAT);
}
return ret;
}
public static PublicKey loadOpenSsh(final String key) throws PublicKeyParseException {
final int c = key.charAt(0);
final String base64;
if (c == 's') {
base64 = PublicKeyReaderUtil.extractOpenSSHBase64(key);
} else {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.UNKNOWN_PUBLIC_KEY_FILE_FORMAT);
}
final SSH2DataBuffer buf = new SSH2DataBuffer(Base64.decodeBase64(base64.getBytes()));
final String type = buf.readString();
final PublicKey ret;
if (PublicKeyReaderUtil.SSH2_RSA_KEY.equals(type)) {
ret = decodePublicKey(buf);
} else {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.UNKNOWN_PUBLIC_KEY_CERTIFICATE_FORMAT);
}
return ret;
}
public static String extractOpenSSHBase64(final String key)
throws PublicKeyParseException {
final String base64;
try {
final StringTokenizer st = new StringTokenizer(key);
st.nextToken();
base64 = st.nextToken();
} catch (final NoSuchElementException e) {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.CORRUPT_OPENSSH_PUBLIC_KEY_STRING);
}
return base64;
}
private static String extractSecSHBase64(final String key) throws PublicKeyParseException {
final StringBuilder base64Data = new StringBuilder();
boolean startKey = false;
boolean startKeyBody = false;
boolean endKey = false;
boolean nextLineIsHeader = false;
for (final String line : key.split("\n")) {
final String trimLine = line.trim();
if (!startKey && trimLine.equals(PublicKeyReaderUtil.BEGIN_PUB_KEY)) {
startKey = true;
} else if (startKey) {
if (trimLine.equals(PublicKeyReaderUtil.END_PUB_KEY)) {
endKey = true;
break;
} else if (nextLineIsHeader) {
if (!trimLine.endsWith("\\")) {
nextLineIsHeader = false;
}
} else if (trimLine.indexOf(':') > 0) {
if (startKeyBody) {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.CORRUPT_SECSSH_PUBLIC_KEY_STRING);
} else if (trimLine.endsWith("\\")) {
nextLineIsHeader = true;
}
} else {
startKeyBody = true;
base64Data.append(trimLine);
}
}
}
if (!endKey) {
throw new PublicKeyParseException(
PublicKeyParseException.ErrorCode.CORRUPT_SECSSH_PUBLIC_KEY_STRING);
}
return base64Data.toString();
}
private static PublicKey decodeDSAPublicKey(final SSH2DataBuffer buffer) throws PublicKeyParseException {
final BigInteger p = buffer.readMPint();
final BigInteger q = buffer.readMPint();
final BigInteger g = buffer.readMPint();
final BigInteger y = buffer.readMPint();
try {
final KeyFactory dsaKeyFact = KeyFactory.getInstance("DSA");
final DSAPublicKeySpec dsaPubSpec = new DSAPublicKeySpec(y, p, q, g);
return dsaKeyFact.generatePublic(dsaPubSpec);
} catch (final Exception e) {
throw new PublicKeyParseException(
PublicKeyParseException.ErrorCode.SSH2DSA_ERROR_DECODING_PUBLIC_KEY_BLOB, e);
}
}
private static PublicKey decodePublicKey(final SSH2DataBuffer buffer) throws PublicKeyParseException {
final BigInteger e = buffer.readMPint();
final BigInteger n = buffer.readMPint();
try {
final KeyFactory rsaKeyFact = KeyFactory.getInstance("RSA");
final RSAPublicKeySpec rsaPubSpec = new RSAPublicKeySpec(n, e);
return rsaKeyFact.generatePublic(rsaPubSpec);
} catch (final Exception ex) {
throw new PublicKeyParseException(PublicKeyParseException.ErrorCode.SSH2RSA_ERROR_DECODING_PUBLIC_KEY_BLOB, ex);
}
}
private static class SSH2DataBuffer {
public static final int INT1 = 24;
public static final int INT2 = 16;
public static final int INT3 = 8;
private final byte[] data;
private int pos;
SSH2DataBuffer(final byte[] data) {
this.data = data;
}
public BigInteger readMPint() throws PublicKeyParseException {
final byte[] raw = readByteArray();
return (raw.length > 0) ? new BigInteger(raw) : BigInteger.valueOf(0);
}
public String readString() throws PublicKeyParseException {
return new String(readByteArray());
}
private int readUInt32() {
final int byte1 = this.data[this.pos++];
final int byte2 = this.data[this.pos++];
final int byte3 = this.data[this.pos++];
final int byte4 = this.data[this.pos++];
return (byte1 << INT1) + (byte2 << INT2) + (byte3 << INT3) + (byte4 << 0);
}
private byte[] readByteArray() throws PublicKeyParseException {
final int len = readUInt32();
if ((len < 0) || (len > (this.data.length - this.pos))) {
throw new PublicKeyParseException(
PublicKeyParseException.ErrorCode.CORRUPT_BYTE_ARRAY_ON_READ);
}
final byte[] str = new byte[len];
System.arraycopy(this.data, this.pos, str, 0, len);
this.pos += len;
return str;
}
}
public static final class PublicKeyParseException extends Exception {
private final ErrorCode errorCode;
private PublicKeyParseException(final ErrorCode errorCode) {
super(errorCode.message);
this.errorCode = errorCode;
}
private PublicKeyParseException(final ErrorCode errorCode,
final Throwable cause) {
super(errorCode.message, cause);
this.errorCode = errorCode;
}
public ErrorCode getErrorCode() {
return this.errorCode;
}
public enum ErrorCode {
UNKNOWN_PUBLIC_KEY_FILE_FORMAT("Corrupt or unknown public key file format"),
UNKNOWN_PUBLIC_KEY_CERTIFICATE_FORMAT("Corrupt or unknown public key certificate format"),
CORRUPT_OPENSSH_PUBLIC_KEY_STRING("Corrupt OpenSSH public key string"),
CORRUPT_SECSSH_PUBLIC_KEY_STRING("Corrupt SECSSH public key string"),
SSH2DSA_ERROR_DECODING_PUBLIC_KEY_BLOB("SSH2DSA: error decoding public key blob"),
SSH2RSA_ERROR_DECODING_PUBLIC_KEY_BLOB("SSH2RSA: error decoding public key blob"),
CORRUPT_BYTE_ARRAY_ON_READ("Public key length is shorter than 2048 bits or byte array is corrupt.");
private final String message;
ErrorCode(final String message) {
this.message = message;
}
}
}
}