package org.bouncycastle.crypto.tls;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Vector;

import org.bouncycastle.asn1.x509.KeyUsage;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
import org.bouncycastle.crypto.params.DHParameters;
import org.bouncycastle.crypto.params.DHPrivateKeyParameters;
import org.bouncycastle.crypto.params.DHPublicKeyParameters;
import org.bouncycastle.crypto.params.RSAKeyParameters;
import org.bouncycastle.crypto.util.PublicKeyFactory;

/**
 * TLS 1.0 PSK key exchange (RFC 4279).
 * <p>
 * Accepts the PSK, DHE_PSK, ECDHE_PSK and RSA_PSK key exchange algorithms. The ECDHE_PSK
 * branches are placeholders (marked TODO[RFC 5489]): they either do nothing or raise an
 * {@code internal_error} alert, as noted on the individual methods.
 */
public class TlsPSKKeyExchange
    extends AbstractTlsKeyExchange
{
    /** Source of the PSK identity and the PSK itself. */
    protected TlsPSKIdentity pskIdentity;

    /** DH group used when generating a DHE_PSK ServerKeyExchange; may be null. */
    protected DHParameters dhParameters;

    /** Stored from the constructor; not otherwise read within this class. */
    protected int[] namedCurves;

    /** Stored from the constructor; not otherwise read within this class. */
    protected short[] clientECPointFormats, serverECPointFormats;

    /** Identity hint: read from the ServerKeyExchange (client) or written into it (server). */
    protected byte[] psk_identity_hint = null;

    /** Local ephemeral DH private key (DHE_PSK only). */
    protected DHPrivateKeyParameters dhAgreePrivateKey = null;

    /** Peer DH public key parsed from the ServerKeyExchange (DHE_PSK only). */
    protected DHPublicKeyParameters dhAgreePublicKey = null;

    /** Public key decoded from the first certificate of the server chain (RSA_PSK only). */
    protected AsymmetricKeyParameter serverPublicKey = null;

    /** {@link #serverPublicKey} after it has been validated as an RSA key (RSA_PSK only). */
    protected RSAKeyParameters rsaServerPublicKey = null;

    /** Credentials recorded by {@link #processServerCredentials(TlsCredentials)}. */
    protected TlsEncryptionCredentials serverCredentials = null;

    /** Value returned as the "other secret" for RSA_PSK; set in generateClientKeyExchange. */
    protected byte[] premasterSecret;

    /**
     * Creates a PSK key exchange.
     *
     * @param keyExchange one of {@code KeyExchangeAlgorithm.PSK}, {@code DHE_PSK},
     *            {@code ECDHE_PSK} or {@code RSA_PSK}
     * @param supportedSignatureAlgorithms passed through to the superclass
     * @param pskIdentity provider of the PSK identity and key
     * @param dhParameters DH group for a server-side DHE_PSK exchange
     * @param namedCurves stored as-is
     * @param clientECPointFormats stored as-is
     * @param serverECPointFormats stored as-is
     * @throws IllegalArgumentException if {@code keyExchange} is not one of the four PSK variants
     */
    public TlsPSKKeyExchange(int keyExchange, Vector supportedSignatureAlgorithms, TlsPSKIdentity pskIdentity,
        DHParameters dhParameters, int[] namedCurves, short[] clientECPointFormats, short[] serverECPointFormats)
    {
        super(keyExchange, supportedSignatureAlgorithms);

        switch (keyExchange)
        {
        case KeyExchangeAlgorithm.DHE_PSK:
        case KeyExchangeAlgorithm.ECDHE_PSK:
        case KeyExchangeAlgorithm.PSK:
        case KeyExchangeAlgorithm.RSA_PSK:
            break;
        default:
            throw new IllegalArgumentException("unsupported key exchange algorithm");
        }

        this.pskIdentity = pskIdentity;
        this.dhParameters = dhParameters;
        this.namedCurves = namedCurves;
        this.clientECPointFormats = clientECPointFormats;
        this.serverECPointFormats = serverECPointFormats;
    }

    /**
     * Called when no server credentials are supplied.
     *
     * @throws IOException a fatal {@code unexpected_message} alert for RSA_PSK, which cannot
     *             proceed without the server's RSA key
     */
    public void skipServerCredentials()
        throws IOException
    {
        if (keyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
        }
    }

    /**
     * Validates and records the server's credentials.
     *
     * @param serverCredentials must be a {@link TlsEncryptionCredentials}
     * @throws IOException a fatal {@code internal_error} alert if the credentials are of another
     *             type, or any alert raised by {@link #processServerCertificate(Certificate)}
     */
    public void processServerCredentials(TlsCredentials serverCredentials)
        throws IOException
    {
        if (!(serverCredentials instanceof TlsEncryptionCredentials))
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        processServerCertificate(serverCredentials.getCertificate());

        this.serverCredentials = (TlsEncryptionCredentials)serverCredentials;
    }

    /**
     * Builds the body of the ServerKeyExchange message.
     *
     * @return the encoded message, or {@code null} when there is no identity hint and the key
     *         exchange does not require a ServerKeyExchange
     * @throws IOException a fatal {@code internal_error} alert if DHE_PSK is selected but no DH
     *             parameters were configured
     */
    public byte[] generateServerKeyExchange()
        throws IOException
    {
        // TODO[RFC 4279] Need a server-side PSK API to determine hint and resolve identities to keys
        this.psk_identity_hint = null;

        if (this.psk_identity_hint == null && !requiresServerKeyExchange())
        {
            return null;
        }

        ByteArrayOutputStream buf = new ByteArrayOutputStream();

        // An absent hint is sent as an empty opaque16 value.
        if (this.psk_identity_hint == null)
        {
            TlsUtils.writeOpaque16(TlsUtils.EMPTY_BYTES, buf);
        }
        else
        {
            TlsUtils.writeOpaque16(this.psk_identity_hint, buf);
        }

        if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            if (this.dhParameters == null)
            {
                throw new TlsFatalAlert(AlertDescription.internal_error);
            }

            this.dhAgreePrivateKey = TlsDHUtils.generateEphemeralServerKeyExchange(context.getSecureRandom(),
                this.dhParameters, buf);
        }
        else if (this.keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            // TODO[RFC 5489]
        }

        return buf.toByteArray();
    }

    /**
     * Extracts and validates the server's RSA public key (RSA_PSK only).
     *
     * @param serverCertificate the server's certificate chain; the first entry is used
     * @throws IOException a fatal alert:
     *             {@code unexpected_message} if the key exchange is not RSA_PSK;
     *             {@code bad_certificate} if the chain is empty;
     *             {@code unsupported_certificate} if the key cannot be decoded or is not an RSA key;
     *             {@code internal_error} if the decoded key claims to be private;
     *             plus anything raised by {@link #validateRSAPublicKey(RSAKeyParameters)},
     *             {@code TlsUtils.validateKeyUsage} or the superclass
     */
    public void processServerCertificate(Certificate serverCertificate)
        throws IOException
    {
        if (keyExchange != KeyExchangeAlgorithm.RSA_PSK)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
        }
        if (serverCertificate.isEmpty())
        {
            throw new TlsFatalAlert(AlertDescription.bad_certificate);
        }

        org.bouncycastle.asn1.x509.Certificate x509Cert = serverCertificate.getCertificateAt(0);

        SubjectPublicKeyInfo keyInfo = x509Cert.getSubjectPublicKeyInfo();
        try
        {
            this.serverPublicKey = PublicKeyFactory.createKey(keyInfo);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.unsupported_certificate);
        }

        // Sanity check the PublicKeyFactory
        if (this.serverPublicKey.isPrivate())
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        // The certificate is peer-supplied: a non-RSA key must produce a TLS alert rather than
        // an unchecked ClassCastException from the cast below.
        if (!(this.serverPublicKey instanceof RSAKeyParameters))
        {
            throw new TlsFatalAlert(AlertDescription.unsupported_certificate);
        }

        this.rsaServerPublicKey = validateRSAPublicKey((RSAKeyParameters)this.serverPublicKey);

        TlsUtils.validateKeyUsage(x509Cert, KeyUsage.keyEncipherment);

        super.processServerCertificate(serverCertificate);
    }

    /**
     * @return {@code true} for DHE_PSK and ECDHE_PSK, {@code false} otherwise
     */
    public boolean requiresServerKeyExchange()
    {
        switch (keyExchange)
        {
        case KeyExchangeAlgorithm.DHE_PSK:
        case KeyExchangeAlgorithm.ECDHE_PSK:
            return true;
        default:
            return false;
        }
    }

    /**
     * Parses a ServerKeyExchange message: the identity hint, followed for DHE_PSK by the
     * server's DH parameters. Nothing beyond the hint is read for ECDHE_PSK (not implemented).
     *
     * @param input stream positioned at the start of the message body
     * @throws IOException on read failure or if the DH public key fails validation
     */
    public void processServerKeyExchange(InputStream input)
        throws IOException
    {
        this.psk_identity_hint = TlsUtils.readOpaque16(input);

        if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            ServerDHParams serverDHParams = ServerDHParams.parse(input);

            this.dhAgreePublicKey = TlsDHUtils.validateDHPublicKey(serverDHParams.getPublicKey());
        }
        else if (this.keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            // TODO[RFC 5489]
        }
    }

    /**
     * Always rejects a CertificateRequest.
     *
     * @throws IOException always, as a fatal {@code unexpected_message} alert
     */
    public void validateCertificateRequest(CertificateRequest certificateRequest)
        throws IOException
    {
        throw new TlsFatalAlert(AlertDescription.unexpected_message);
    }

    /**
     * Always rejects client credentials.
     *
     * @throws IOException always, as a fatal {@code internal_error} alert
     */
    public void processClientCredentials(TlsCredentials clientCredentials)
        throws IOException
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    /**
     * Writes the ClientKeyExchange body: the PSK identity, followed by the algorithm-specific
     * part (an ephemeral DH value for DHE_PSK, an encrypted premaster secret for RSA_PSK).
     * The identity provider is first told about the server's hint, or its absence.
     *
     * @param output destination for the encoded message
     * @throws IOException on write failure, or a fatal {@code internal_error} alert for
     *             ECDHE_PSK (not implemented)
     */
    public void generateClientKeyExchange(OutputStream output)
        throws IOException
    {
        if (psk_identity_hint == null)
        {
            pskIdentity.skipIdentityHint();
        }
        else
        {
            pskIdentity.notifyIdentityHint(psk_identity_hint);
        }

        byte[] psk_identity = pskIdentity.getPSKIdentity();

        TlsUtils.writeOpaque16(psk_identity, output);

        if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            this.dhAgreePrivateKey = TlsDHUtils.generateEphemeralClientKeyExchange(context.getSecureRandom(),
                dhAgreePublicKey.getParameters(), output);
        }
        else if (this.keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            // TODO[RFC 5489]
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }
        else if (this.keyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            this.premasterSecret = TlsRSAUtils.generateEncryptedPreMasterSecret(context,
                this.rsaServerPublicKey, output);
        }
    }

    /**
     * Assembles the premaster secret: the "other secret" followed by the PSK, each written
     * with {@code TlsUtils.writeOpaque16}.
     *
     * @return the encoded premaster secret
     * @throws IOException if the other secret cannot be produced
     */
    public byte[] generatePremasterSecret()
        throws IOException
    {
        byte[] psk = pskIdentity.getPSK();
        byte[] other_secret = generateOtherSecret(psk.length);

        ByteArrayOutputStream buf = new ByteArrayOutputStream(4 + other_secret.length + psk.length);
        TlsUtils.writeOpaque16(other_secret, buf);
        TlsUtils.writeOpaque16(psk, buf);
        return buf.toByteArray();
    }

    /**
     * Produces the algorithm-specific "other secret" half of the premaster secret.
     *
     * @param pskLength length of the PSK; used only for plain PSK
     * @return the DH agreement result (DHE_PSK), the stored premaster secret (RSA_PSK), or a
     *         zero-filled array of {@code pskLength} bytes (plain PSK)
     * @throws IOException a fatal {@code internal_error} alert if DHE_PSK has no private key
     *             yet, or for ECDHE_PSK (not implemented)
     */
    protected byte[] generateOtherSecret(int pskLength)
        throws IOException
    {
        if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            if (dhAgreePrivateKey != null)
            {
                return TlsDHUtils.calculateDHBasicAgreement(dhAgreePublicKey, dhAgreePrivateKey);
            }

            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        if (this.keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            // TODO[RFC 5489]
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        if (this.keyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            return this.premasterSecret;
        }

        // Plain PSK: a freshly allocated (all-zero) array of the PSK's length.
        return new byte[pskLength];
    }

    /**
     * Checks an RSA public key before it is used for encryption.
     *
     * @param key the key to check
     * @return {@code key}, unchanged
     * @throws IOException a fatal {@code illegal_parameter} alert if the public exponent fails
     *             {@code isProbablePrime(2)}
     */
    protected RSAKeyParameters validateRSAPublicKey(RSAKeyParameters key)
        throws IOException
    {
        // TODO What is the minimum bit length required?
        // key.getModulus().bitLength();

        if (!key.getExponent().isProbablePrime(2))
        {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }

        return key;
    }
}