/*
* ****************************************************************************
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
* You may not use this product except in compliance with the License.
*
* This product includes a number of subcomponents with
* separate copyright notices and license terms. Your use of these
* subcomponents is subject to the terms and conditions of the
* subcomponent's license, as noted in the LICENSE file.
* ****************************************************************************
*/
package org.cloudfoundry.identity.uaa.oauth;
import org.bouncycastle.asn1.ASN1Sequence;
import org.cloudfoundry.identity.uaa.impl.config.LegacyTokenKey;
import org.cloudfoundry.identity.uaa.oauth.jwt.CommonSignatureVerifier;
import org.cloudfoundry.identity.uaa.oauth.jwt.CommonSigner;
import org.cloudfoundry.identity.uaa.oauth.jwt.Signer;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.security.jwt.crypto.sign.MacSigner;
import org.springframework.security.jwt.crypto.sign.SignatureVerifier;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.springframework.security.jwt.codec.Codecs.b64Decode;
import static org.springframework.security.jwt.codec.Codecs.utf8Encode;
/**
 * A single token signing key together with the artifacts derived from it: the {@link Signer},
 * the {@link SignatureVerifier}, the verification key and, for RSA keys, the RSA public key.
 * <p>
 * A freshly constructed instance holds a randomly generated MAC key until
 * {@link #setSigningKey(String)} is called.
 */
public class KeyInfo {
    /** Matches a PEM block, capturing the BEGIN label, the base64 body and the END label. */
    private static final Pattern PEM_DATA = Pattern.compile("-----BEGIN (.*)-----(.*)-----END (.*)-----", Pattern.DOTALL);
    /** MIME base64 encoder emitting 64-character lines separated by a single '\n'. */
    private static final Base64.Encoder BASE64_ENCODER = Base64.getMimeEncoder(64, new byte[] {'\n'});

    private String keyId;
    private String verifierKey = new RandomValueStringGenerator().generate();
    private String signingKey = verifierKey;
    private Signer signer = new CommonSigner(null, verifierKey);
    private SignatureVerifier verifier = new MacSigner(signingKey);
    private String type = "MAC";
    private RSAPublicKey rsaPublicKey;

    /**
     * Looks up a key of the current identity zone by id.
     *
     * @param keyId the key id to look up
     * @return the matching key, or {@code null} if {@link #getKeys()} has no entry for {@code keyId}
     */
    public static KeyInfo getKey(String keyId) {
        return getKeys().get(keyId);
    }

    /**
     * Builds the signing keys of the current identity zone, keyed by key id.
     * <p>
     * Falls back to the UAA zone's token policy when the current zone has no configuration or no
     * keys, and to the legacy token key when no keys are configured at all.
     *
     * @return a new, mutable map of key id to {@link KeyInfo}; never empty
     */
    public static Map<String, KeyInfo> getKeys() {
        IdentityZoneConfiguration config = IdentityZoneHolder.get().getConfig();
        if (config == null || config.getTokenPolicy().getKeys() == null || config.getTokenPolicy().getKeys().isEmpty()) {
            config = IdentityZoneHolder.getUaaZone().getConfig();
        }
        Map<String, KeyInfo> keys = new HashMap<>();
        Map<String, String> configuredKeys = config.getTokenPolicy().getKeys();
        // The UAA zone may have no keys either; the legacy fallback below handles that case.
        if (configuredKeys != null) {
            for (Map.Entry<String, String> entry : configuredKeys.entrySet()) {
                KeyInfo keyInfo = new KeyInfo();
                keyInfo.setKeyId(entry.getKey());
                keyInfo.setSigningKey(entry.getValue());
                keys.put(entry.getKey(), keyInfo);
            }
        }
        if (keys.isEmpty()) {
            keys.put(LegacyTokenKey.LEGACY_TOKEN_KEY_ID, LegacyTokenKey.getLegacyTokenKeyInfo());
        }
        return keys;
    }

    /**
     * @return the key selected by the active key id, or {@code null} if that id does not match any
     *         key returned by {@link #getKeys()}
     */
    public static KeyInfo getActiveKey() {
        return getKeys().get(getActiveKeyId());
    }

    /**
     * Resolves the active key id using, in order: the zone's configured active key id (the UAA
     * zone's configuration when the current zone has none), the only key if exactly one key is
     * available, the UAA zone's active key id, and finally the legacy token key id.
     */
    private static String getActiveKeyId() {
        IdentityZoneConfiguration config = IdentityZoneHolder.get().getConfig();
        if (config == null) {
            config = IdentityZoneHolder.getUaaZone().getConfig();
        }
        String activeKeyId = config.getTokenPolicy().getActiveKeyId();
        if (!StringUtils.hasText(activeKeyId)) {
            Map<String, KeyInfo> keys = getKeys();
            if (keys.size() == 1) {
                activeKeyId = keys.keySet().iterator().next();
            }
        }
        if (!StringUtils.hasText(activeKeyId)) {
            activeKeyId = IdentityZoneHolder.getUaaZone().getConfig().getTokenPolicy().getActiveKeyId();
        }
        if (!StringUtils.hasText(activeKeyId)) {
            activeKeyId = LegacyTokenKey.LEGACY_TOKEN_KEY_ID;
        }
        return activeKeyId;
    }

    /** @return the signer built from the current key id and signing key */
    public Signer getSigner() {
        return signer;
    }

    /**
     * @return the key used to verify signatures: the PEM-encoded public key for RSA keys,
     *         otherwise the shared MAC secret itself
     */
    public String getVerifierKey() {
        return verifierKey;
    }

    /** @return the (trimmed) signing key */
    public String getSigningKey() {
        return signingKey;
    }

    /** @return {@code "RSA"} or {@code "MAC"} */
    public String getType() {
        return type;
    }

    /** @return the RSA public key for RSA signing keys, or {@code null} for MAC keys */
    public RSAPublicKey getRsaPublicKey() {
        return rsaPublicKey;
    }

    /**
     * @return true if the KeyInfo represents an asymmetric (RSA) key pair
     */
    public boolean isAssymetricKey() {
        return isAssymetricKey(verifierKey);
    }

    /** @return the verifier built from the current verifier key */
    public SignatureVerifier getVerifier() {
        return verifier;
    }

    /**
     * @param key the key text to inspect; {@code null} is treated as not asymmetric
     * @return true if the string begins with a PEM header ({@code -----BEGIN}), i.e. it is
     *         treated as an asymmetric (RSA) key
     */
    public static boolean isAssymetricKey(String key) {
        return key != null && key.startsWith("-----BEGIN");
    }

    /**
     * Sets the JWT signing key and derives the key used to verify signatures produced with it.
     * <p>
     * The signing key is either a shared secret used as a MAC key, or a PEM-encoded RSA key in one
     * of the formats accepted by {@link #parseKeyPair(String)}. Surrounding whitespace is trimmed.
     * All derived values are computed before any field is assigned, so a key that fails to parse
     * leaves this instance unchanged.
     *
     * @param signingKey the key to be used for signing JWTs
     * @throws IllegalArgumentException if {@code signingKey} is null, empty or only whitespace, or
     *                                  is PEM data that {@link #parseKeyPair(String)} rejects
     */
    public void setSigningKey(String signingKey) {
        if (!StringUtils.hasText(signingKey)) {
            throw new IllegalArgumentException("Signing key cannot be empty");
        }
        String trimmedKey = signingKey.trim();

        Signer newSigner = new CommonSigner(keyId, trimmedKey);
        String newVerifierKey;
        String newType;
        RSAPublicKey newRsaPublicKey;
        if (isAssymetricKey(trimmedKey)) {
            KeyPair keyPair = KeyInfo.parseKeyPair(trimmedKey);
            newRsaPublicKey = (RSAPublicKey) keyPair.getPublic();
            newVerifierKey = pemEncodePublicKey(newRsaPublicKey);
            newType = "RSA";
        } else {
            // Assume it's an HMAC key; clear any public key left over from a previous RSA key.
            newRsaPublicKey = null;
            newVerifierKey = trimmedKey;
            newType = "MAC";
        }
        SignatureVerifier newVerifier = new CommonSignatureVerifier(newVerifierKey);

        this.signingKey = trimmedKey;
        this.signer = newSigner;
        this.rsaPublicKey = newRsaPublicKey;
        this.verifierKey = newVerifierKey;
        this.type = newType;
        this.verifier = newVerifier;
    }

    /** @return the key id, or {@code null} if none has been set */
    public String getKeyId() {
        return keyId;
    }

    /**
     * Sets the key id and rebuilds the signer so that it carries the new id.
     *
     * @param keyId the key id; must contain text
     * @throws IllegalArgumentException if {@code keyId} is null, empty or only whitespace
     */
    public void setKeyId(String keyId) {
        if (!StringUtils.hasText(keyId)) {
            throw new IllegalArgumentException("KeyId should not be null or empty");
        }
        this.keyId = keyId;
        this.signer = new CommonSigner(keyId, signingKey);
    }

    /**
     * Parses PEM-encoded RSA key material.
     * <p>
     * Supported PEM labels:
     * <ul>
     *   <li>{@code RSA PRIVATE KEY} (PKCS#1 private key): both public and private key are returned</li>
     *   <li>{@code PUBLIC KEY} (X.509 SubjectPublicKeyInfo): only the public key is returned</li>
     *   <li>{@code RSA PUBLIC KEY} (PKCS#1 public key): only the public key is returned</li>
     * </ul>
     *
     * @param pemData the PEM text; surrounding whitespace is ignored
     * @return a key pair whose private key is {@code null} for public-key-only input
     * @throws IllegalArgumentException if the input is not PEM data, has an unsupported label, or
     *                                  is a malformed PKCS#1 private key
     * @throws RuntimeException         wrapping {@link InvalidKeySpecException} for invalid key specs
     * @throws IllegalStateException    if no RSA {@link KeyFactory} is available
     */
    public static KeyPair parseKeyPair(String pemData) {
        Matcher m = PEM_DATA.matcher(pemData.trim());
        if (!m.matches()) {
            throw new IllegalArgumentException("String is not PEM encoded data");
        }
        String type = m.group(1);
        final byte[] content = b64Decode(utf8Encode(m.group(2)));
        PublicKey publicKey;
        PrivateKey privateKey = null;
        try {
            KeyFactory fact = KeyFactory.getInstance("RSA");
            if (type.equals("RSA PRIVATE KEY")) {
                ASN1Sequence seq = ASN1Sequence.getInstance(content);
                // A two-prime PKCS#1 RSAPrivateKey has exactly 9 fields (version .. coefficient).
                if (seq.size() != 9) {
                    throw new IllegalArgumentException("Invalid RSA Private Key ASN1 sequence.");
                }
                org.bouncycastle.asn1.pkcs.RSAPrivateKey key = org.bouncycastle.asn1.pkcs.RSAPrivateKey.getInstance(seq);
                RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
                RSAPrivateCrtKeySpec privSpec = new RSAPrivateCrtKeySpec(
                        key.getModulus(),
                        key.getPublicExponent(),
                        key.getPrivateExponent(),
                        key.getPrime1(),
                        key.getPrime2(),
                        key.getExponent1(),
                        key.getExponent2(),
                        key.getCoefficient()
                );
                publicKey = fact.generatePublic(pubSpec);
                privateKey = fact.generatePrivate(privSpec);
            } else if (type.equals("PUBLIC KEY")) {
                KeySpec keySpec = new X509EncodedKeySpec(content);
                publicKey = fact.generatePublic(keySpec);
            } else if (type.equals("RSA PUBLIC KEY")) {
                ASN1Sequence seq = ASN1Sequence.getInstance(content);
                org.bouncycastle.asn1.pkcs.RSAPublicKey key = org.bouncycastle.asn1.pkcs.RSAPublicKey.getInstance(seq);
                RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
                publicKey = fact.generatePublic(pubSpec);
            } else {
                throw new IllegalArgumentException(type + " is not a supported format");
            }
            return new KeyPair(publicKey, privateKey);
        } catch (InvalidKeySpecException e) {
            throw new RuntimeException(e);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Encodes a public key as a {@code PUBLIC KEY} PEM block (X.509 encoding, base64 wrapped at
     * 64 characters with '\n' line separators).
     *
     * @param publicKey the key to encode
     * @return the PEM text, without a trailing newline
     */
    public static String pemEncodePublicKey(PublicKey publicKey) {
        String begin = "-----BEGIN PUBLIC KEY-----\n";
        String end = "\n-----END PUBLIC KEY-----";
        // encodeToString avoids a platform-default-charset byte[] -> String conversion.
        String base64encoded = BASE64_ENCODER.encodeToString(publicKey.getEncoded());
        return begin + base64encoded + end;
    }
}