/*
* Copyright 2006-2011 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package org.springframework.security.jwt.crypto.sign;
import static org.springframework.security.jwt.codec.Codecs.b64Decode;
import static org.springframework.security.jwt.codec.Codecs.utf8Encode;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.*;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.*;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bouncycastle.asn1.ASN1Sequence;
/**
* Reads RSA keys from PEM and OpenSSH formats. It uses Bouncy Castle's ASN.1 classes for
* parsing, but without the need to specify a crypto provider or have BC registered as one.
*
* @author Luke Taylor
*/
class RsaKeyHelper {

	/** Prefix that distinguishes PEM-armoured input from bare Base64 or OpenSSH-formatted keys. */
	private static final String BEGIN = "-----BEGIN";

	/**
	 * Splits a PEM block into its BEGIN label (group 1), Base64 body (group 2) and END label
	 * (group 3). DOTALL lets the body span multiple lines.
	 */
	private static final Pattern PEM_DATA = Pattern.compile("-----BEGIN (.*)-----(.*)-----END (.*)-----", Pattern.DOTALL);

	/**
	 * OpenSSH public key line: {@code ssh-<alg> <base64 blob> [comment]}. The algorithm group is
	 * deliberately broad so that non-RSA keys get a clear "only RSA" error instead of being passed
	 * to the Base64 decoder, and the trailing comment is optional.
	 */
	private static final Pattern SSH_PUB_KEY = Pattern.compile("ssh-([a-z0-9]+)\\s+([A-Za-z0-9/+]+=*)(?:\\s+(.*))?");

	/** Leading bytes of an ssh-rsa wire blob: a uint32 length of 7 followed by the ASCII string "ssh-rsa". */
	private static final byte[] SSH_RSA_PREFIX = new byte[] { 0, 0, 0, 7, 's', 's', 'h', '-', 'r', 's', 'a' };

	/**
	 * Parses a PEM-encoded RSA key.
	 * <p>
	 * Supported labels:
	 * <ul>
	 * <li>{@code RSA PRIVATE KEY}: a PKCS#1 private key, from which both halves of the pair are built.</li>
	 * <li>{@code PUBLIC KEY}: an X.509 SubjectPublicKeyInfo.</li>
	 * <li>{@code RSA PUBLIC KEY}: a PKCS#1 public key.</li>
	 * </ul>
	 * For the two public-key formats, the returned pair has a {@code null} private key.
	 *
	 * @param pemData the PEM text; surrounding whitespace is ignored
	 * @return the decoded key pair
	 * @throws IllegalArgumentException if the input is not PEM, uses an unsupported label, or is a
	 * private key whose ASN.1 sequence does not have exactly 9 elements
	 * @throws RuntimeException wrapping an {@link InvalidKeySpecException} if the key material is rejected
	 * @throws IllegalStateException if no RSA {@link KeyFactory} is available
	 */
	static KeyPair parseKeyPair(String pemData) {
		Matcher m = PEM_DATA.matcher(pemData.trim());
		if (!m.matches()) {
			throw new IllegalArgumentException("String is not PEM encoded data");
		}
		String type = m.group(1);
		final byte[] content = b64Decode(utf8Encode(m.group(2)));
		PublicKey publicKey;
		PrivateKey privateKey = null;
		try {
			KeyFactory fact = KeyFactory.getInstance("RSA");
			if (type.equals("RSA PRIVATE KEY")) {
				ASN1Sequence seq = ASN1Sequence.getInstance(content);
				// A two-prime PKCS#1 RSAPrivateKey has exactly 9 fields: version, n, e, d, p, q, dP, dQ, qInv.
				// Multi-prime keys carry extra fields that RSAPrivateCrtKeySpec cannot represent.
				if (seq.size() != 9) {
					throw new IllegalArgumentException("Invalid RSA Private Key ASN1 sequence.");
				}
				org.bouncycastle.asn1.pkcs.RSAPrivateKey key = org.bouncycastle.asn1.pkcs.RSAPrivateKey.getInstance(seq);
				RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
				RSAPrivateCrtKeySpec privSpec = new RSAPrivateCrtKeySpec(key.getModulus(), key.getPublicExponent(),
						key.getPrivateExponent(), key.getPrime1(), key.getPrime2(), key.getExponent1(), key.getExponent2(),
						key.getCoefficient());
				publicKey = fact.generatePublic(pubSpec);
				privateKey = fact.generatePrivate(privSpec);
			}
			else if (type.equals("PUBLIC KEY")) {
				KeySpec keySpec = new X509EncodedKeySpec(content);
				publicKey = fact.generatePublic(keySpec);
			}
			else if (type.equals("RSA PUBLIC KEY")) {
				ASN1Sequence seq = ASN1Sequence.getInstance(content);
				org.bouncycastle.asn1.pkcs.RSAPublicKey key = org.bouncycastle.asn1.pkcs.RSAPublicKey.getInstance(seq);
				RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
				publicKey = fact.generatePublic(pubSpec);
			}
			else {
				throw new IllegalArgumentException(type + " is not a supported format");
			}
			return new KeyPair(publicKey, privateKey);
		}
		catch (InvalidKeySpecException e) {
			throw new RuntimeException(e);
		}
		catch (NoSuchAlgorithmException e) {
			throw new IllegalStateException(e);
		}
	}

	/**
	 * Parses an RSA public key in one of three forms:
	 * <ul>
	 * <li>An OpenSSH line ({@code ssh-rsa <base64> [comment]}).</li>
	 * <li>A bare Base64 ssh-rsa blob.</li>
	 * <li>PEM text accepted by {@link #parseKeyPair(String)}.</li>
	 * </ul>
	 * Surrounding whitespace, such as a trailing newline from a key file, is ignored.
	 *
	 * @param key the encoded key
	 * @return the RSA public key
	 * @throws IllegalArgumentException if an OpenSSH line names a non-RSA algorithm, the SSH blob
	 * does not start with the ssh-rsa prefix, or the PEM data holds no public key
	 */
	static RSAPublicKey parsePublicKey(String key) {
		String trimmed = key.trim();
		Matcher m = SSH_PUB_KEY.matcher(trimmed);
		if (m.matches()) {
			String alg = m.group(1);
			// group(3) is the optional comment (typically user@host); it is not needed to build the key.
			if (!"rsa".equalsIgnoreCase(alg)) {
				throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + alg);
			}
			return parseSSHPublicKey(m.group(2));
		}
		if (!trimmed.startsWith(BEGIN)) {
			// Assume it's the plain Base64 encoded ssh key without the "ssh-rsa" at the start
			return parseSSHPublicKey(trimmed);
		}
		KeyPair kp = parseKeyPair(trimmed);
		if (kp.getPublic() == null) {
			throw new IllegalArgumentException("Key data does not contain a public key");
		}
		return (RSAPublicKey) kp.getPublic();
	}

	/**
	 * Decodes a Base64 ssh-rsa wire blob. The layout is the string "ssh-rsa", then the public
	 * exponent, then the modulus, each as a length-prefixed big-endian two's-complement integer.
	 *
	 * @param encKey the Base64-encoded blob
	 * @return the RSA public key
	 * @throws IllegalArgumentException if the blob does not start with the ssh-rsa prefix
	 * @throws RuntimeException wrapping an {@link IOException} if a length field is truncated or inconsistent
	 */
	private static RSAPublicKey parseSSHPublicKey(String encKey) {
		ByteArrayInputStream in = new ByteArrayInputStream(b64Decode(utf8Encode(encKey)));
		byte[] prefix = new byte[SSH_RSA_PREFIX.length];
		try {
			if (in.read(prefix) != prefix.length || !Arrays.equals(SSH_RSA_PREFIX, prefix)) {
				throw new IllegalArgumentException("SSH key prefix not found");
			}
			// The signed BigInteger(byte[]) constructor matches the two's-complement encoding used here.
			BigInteger e = new BigInteger(readBigInteger(in));
			BigInteger n = new BigInteger(readBigInteger(in));
			return createPublicKey(n, e);
		}
		catch (IOException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Builds an RSA public key from its modulus and public exponent.
	 *
	 * @param n the modulus
	 * @param e the public exponent
	 * @return the RSA public key
	 * @throws RuntimeException wrapping the {@link GeneralSecurityException} if the key cannot be created
	 */
	static RSAPublicKey createPublicKey(BigInteger n, BigInteger e) {
		try {
			return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e));
		}
		catch (GeneralSecurityException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Reads one length-prefixed field: a 4-byte big-endian length followed by that many bytes.
	 *
	 * @param in the stream positioned at the length field
	 * @return the field's bytes
	 * @throws IOException if the length is truncated, negative, or exceeds the remaining input
	 */
	private static byte[] readBigInteger(ByteArrayInputStream in) throws IOException {
		byte[] b = new byte[4];
		if (in.read(b) != 4) {
			throw new IOException("Expected length data as 4 bytes");
		}
		// Mask every byte: Java bytes are signed, so an unmasked byte >= 0x80 sign-extends and
		// corrupts the length. For example, the 129-byte (0x81) modulus of a 1024-bit key would
		// otherwise be read as -127.
		int l = ((b[0] & 0xFF) << 24) | ((b[1] & 0xFF) << 16) | ((b[2] & 0xFF) << 8) | (b[3] & 0xFF);
		// Validate before allocating, so a bogus (possibly attacker-supplied) length cannot trigger a
		// huge allocation. Lengths >= 2^31 appear negative here.
		if (l < 0 || l > in.available()) {
			throw new IOException("Expected " + l + " key bytes");
		}
		b = new byte[l];
		if (in.read(b) != l) {
			throw new IOException("Expected " + l + " key bytes");
		}
		return b;
	}
}