package io.airlift.airship.coordinator.auth.ssh;
import com.google.common.base.Splitter;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.PeekingIterator;
import com.google.common.io.CharSource;
import java.io.IOException;
import java.math.BigInteger;
import java.security.PrivateKey;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Iterators.peekingIterator;
import static org.apache.commons.codec.binary.Base64.decodeBase64;
/**
 * Decodes PEM-encoded text (a "-----BEGIN type-----" line, optional "Name: value" headers,
 * a base64 body and a matching "-----END type-----" line). Supports unencrypted
 * "DSA PRIVATE KEY" and "RSA PRIVATE KEY" blocks as used for ssh keys.
 */
public class PemDecoder
{
    // Matches a BEGIN boundary line and captures the block type, e.g. "RSA PRIVATE KEY".
    // Compiled once because parseBegin may be called for every leading line of the input.
    private static final Pattern BEGIN_PATTERN = Pattern.compile("-----BEGIN (.*)-----");

    /**
     * Decodes an unencrypted DSA or RSA private key from PEM text.
     *
     * @param pemData the PEM text; it is also handed unchanged to the returned key object
     * @return the decoded private key
     * @throws IOException if the text cannot be split into lines
     * @throws IllegalArgumentException if the PEM framing is invalid, the key carries a
     *         {@code Proc-Type} header (encrypted), the DER structure is not a single SEQUENCE,
     *         the key version is not 0, a DSA key has trailing data, or the PEM type is
     *         neither {@code DSA PRIVATE KEY} nor {@code RSA PRIVATE KEY}
     */
    public static PrivateKey decodeSshPrivateKey(String pemData)
            throws IOException
    {
        Pem pem = parsePem(pemData);

        // encrypted keys are marked with a "Proc-Type" header; decryption is not implemented
        if (pem.getHeaders().containsKey("Proc-Type")) {
            throw new IllegalArgumentException("Encrypted keys are not supported");
        }

        // the decoded body must consist of exactly one DER SEQUENCE with nothing after it
        DerReader reader = new DerReader(pem.getData());
        byte[] sequence = reader.readEntry(DerType.SEQUENCE);
        if (!reader.isComplete()) {
            throw new IllegalArgumentException("Invalid ssh key");
        }

        // presumably readEntry returned the SEQUENCE contents; read its integers in order,
        // starting with the format version
        reader = new DerReader(sequence);
        BigInteger version = reader.readBigInteger();
        if (!version.equals(BigInteger.ZERO)) {
            throw new IllegalArgumentException("Unknown ssh key version " + version);
        }

        if (pem.getType().equals("DSA PRIVATE KEY")) {
            // field order after the version: p, q, g, y, x
            BigInteger p = reader.readBigInteger();
            BigInteger q = reader.readBigInteger();
            BigInteger g = reader.readBigInteger();
            reader.readBigInteger(); // skip y: not needed by the key object
            BigInteger x = reader.readBigInteger();
            if (!reader.isComplete()) {
                throw new IllegalArgumentException("Invalid ssh key");
            }
            return new PemDsaPrivateKey(pemData, p, q, g, x);
        }
        if (pem.getType().equals("RSA PRIVATE KEY")) {
            // field order after the version starts with: n, e, d
            BigInteger n = reader.readBigInteger();
            reader.readBigInteger(); // skip e: not needed by the key object
            BigInteger d = reader.readBigInteger();
            // rsa key contains several more numbers which we don't need, so the reader is
            // intentionally not checked for completeness here
            return new PemRsaPrivateKey(pemData, d, n);
        }
        throw new IllegalArgumentException("Unknown key type " + pem.getType());
    }

    /**
     * Parses the first PEM block found in {@code pemData}. Lines before the first BEGIN
     * boundary, whether blank or explanatory text, are skipped.
     *
     * @param pemData the PEM text
     * @return the parsed block (type, headers and decoded body)
     * @throws IOException if the text cannot be split into lines
     * @throws IllegalArgumentException if no BEGIN line is present or the matching END
     *         line is missing
     */
    public static Pem parsePem(String pemData)
            throws IOException
    {
        List<String> lines = CharSource.wrap(pemData).readLines();
        PeekingIterator<String> iterator = peekingIterator(lines.iterator());
        while (iterator.hasNext()) {
            String type = parseBegin(iterator.next().trim());
            if (type != null) {
                return parsePem(type, iterator);
            }
            // not a BEGIN line: ignore it rather than parsing with a null type
        }
        throw new IllegalArgumentException("Invalid pem data: missing BEGIN");
    }

    /**
     * Extracts the block type from a PEM BEGIN boundary line.
     *
     * @param line a single line; it must match the boundary exactly (no surrounding text)
     * @return the trimmed type between "-----BEGIN " and "-----", or {@code null} if the
     *         line is not a BEGIN boundary
     */
    public static String parseBegin(String line)
    {
        Matcher matcher = BEGIN_PATTERN.matcher(line);
        if (!matcher.matches()) {
            return null;
        }
        return matcher.group(1).trim();
    }

    /**
     * Parses the headers and body that follow a BEGIN line of the given type.
     *
     * @throws IllegalArgumentException if the matching END line is missing
     */
    private static Pem parsePem(String type, PeekingIterator<String> iterator)
    {
        if (!iterator.hasNext()) {
            throw new IllegalArgumentException("Invalid pem data: missing END");
        }
        String end = "-----END " + type + "-----";

        ListMultimap<String, String> headers = readHeaders(iterator, end);
        byte[] body = readBody(iterator, end);
        return new Pem(type, headers, body);
    }

    /**
     * Consumes leading "Name: value1, value2" header lines. Stops, without consuming, at
     * the first line that has no colon, at the END line, or at end of input.
     */
    private static ListMultimap<String, String> readHeaders(PeekingIterator<String> iterator, String keyEnd)
    {
        ListMultimap<String, String> headers = ArrayListMultimap.create();
        while (iterator.hasNext()) {
            // peek() throws on an exhausted iterator, so hasNext() guards the loop
            String line = iterator.peek().trim();
            if (line.startsWith(keyEnd) || !line.contains(":")) {
                break;
            }
            iterator.next();

            // header name and value are separated by the first colon
            List<String> header = copyOf(Splitter.on(':').trimResults().limit(2).split(line));
            String name = header.get(0);

            // values are comma separated
            List<String> values = ImmutableList.of();
            if (header.size() == 2) {
                values = copyOf(Splitter.on(',').trimResults().split(header.get(1)));
            }
            headers.putAll(name, values);
        }
        return headers;
    }

    /**
     * Consumes body lines up to and including the END line and base64-decodes them.
     *
     * @throws IllegalArgumentException if input ends before the END line
     */
    private static byte[] readBody(PeekingIterator<String> iterator, String keyEnd)
    {
        StringBuilder body = new StringBuilder();
        while (iterator.hasNext()) {
            // trim before comparing so an indented END line is not appended to the body
            String line = iterator.next().trim();
            if (line.startsWith(keyEnd)) {
                return decodeBase64(body.toString());
            }
            body.append(line);
        }
        throw new IllegalArgumentException("Invalid pem data: missing END");
    }

    /**
     * One parsed PEM block: its type, headers and decoded body bytes.
     */
    private static class Pem
    {
        private final String type;
        private final ListMultimap<String, String> headers;
        private final byte[] data;

        private Pem(String type, ListMultimap<String, String> headers, byte[] data)
        {
            this.type = type;
            this.headers = headers;
            this.data = data;
        }

        public String getType()
        {
            return type;
        }

        public ListMultimap<String, String> getHeaders()
        {
            return headers;
        }

        /**
         * Returns the decoded body. This is the internal array, not a copy.
         */
        public byte[] getData()
        {
            return data;
        }

        @Override
        public String toString()
        {
            final StringBuilder sb = new StringBuilder();
            sb.append("Pem");
            sb.append("{type='").append(type).append('\'');
            sb.append(", headers=").append(headers);
            sb.append(", data.length=").append(data.length);
            sb.append('}');
            return sb.toString();
        }
    }
}