package org.bouncycastle.crypto.tls.test; import java.io.IOException; import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.io.PrintStream; import java.security.SecureRandom; import junit.framework.TestCase; import org.bouncycastle.asn1.x509.Certificate; import org.bouncycastle.crypto.tls.AlertLevel; import org.bouncycastle.crypto.tls.CertificateRequest; import org.bouncycastle.crypto.tls.ClientCertificateType; import org.bouncycastle.crypto.tls.DefaultTlsClient; import org.bouncycastle.crypto.tls.DefaultTlsServer; import org.bouncycastle.crypto.tls.TlsAuthentication; import org.bouncycastle.crypto.tls.TlsClientProtocol; import org.bouncycastle.crypto.tls.TlsCredentials; import org.bouncycastle.crypto.tls.TlsEncryptionCredentials; import org.bouncycastle.crypto.tls.TlsServerProtocol; import org.bouncycastle.crypto.tls.TlsSignerCredentials; public class TlsProtocolTest extends TestCase { public void testClientServer() throws Exception { SecureRandom secureRandom = new SecureRandom(); PipedInputStream clientRead = new PipedInputStream(); PipedInputStream serverRead = new PipedInputStream(); PipedOutputStream clientWrite = new PipedOutputStream(serverRead); PipedOutputStream serverWrite = new PipedOutputStream(clientRead); TlsClientProtocol clientProtocol = new TlsClientProtocol(clientRead, clientWrite, secureRandom); TlsServerProtocol serverProtocol = new TlsServerProtocol(serverRead, serverWrite, secureRandom); ServerThread serverThread = new ServerThread(serverProtocol); serverThread.start(); MyTlsClient client = new MyTlsClient(); clientProtocol.connect(client); // byte[] data = new byte[64]; // secureRandom.nextBytes(data); // // OutputStream output = clientProtocol.getOutputStream(); // output.write(data); // output.close(); // // byte[] echo = Streams.readAll(clientProtocol.getInputStream()); serverThread.join(); // assertTrue(Arrays.areEqual(data, echo)); } static class ServerThread extends Thread { private final TlsServerProtocol serverProtocol; ServerThread(TlsServerProtocol serverProtocol) { this.serverProtocol = serverProtocol; } public void run() { try { MyTlsServer server = new MyTlsServer(); serverProtocol.accept(server); // Streams.pipeAll(serverProtocol.getInputStream(), // serverProtocol.getOutputStream()); serverProtocol.close(); } catch (Exception e) { throw new RuntimeException(e); } } } static class MyTlsClient extends DefaultTlsClient { public void notifyAlertRaised(short alertLevel, short alertDescription, String message, Exception cause) { PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out; out.println("TLS client raised alert (AlertLevel." + alertLevel + ", AlertDescription." + alertDescription + ")"); if (message != null) { out.println(message); } if (cause != null) { cause.printStackTrace(out); } } public void notifyAlertReceived(short alertLevel, short alertDescription) { PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out; out.println("TLS client received alert (AlertLevel." + alertLevel + ", AlertDescription." + alertDescription + ")"); } public TlsAuthentication getAuthentication() throws IOException { return new TlsAuthentication() { public void notifyServerCertificate(org.bouncycastle.crypto.tls.Certificate serverCertificate) throws IOException { Certificate[] chain = serverCertificate.getCertificateList(); System.out.println("Received server certificate chain of length " + chain.length); for (int i = 0; i != chain.length; i++) { Certificate entry = chain[i]; // TODO Create fingerprint based on certificate signature algorithm digest System.out.println(" fingerprint:SHA-256 " + TlsTestUtils.fingerprint(entry) + " (" + entry.getSubject() + ")"); } } public TlsCredentials getClientCredentials(CertificateRequest certificateRequest) throws IOException { short[] certificateTypes = certificateRequest.getCertificateTypes(); if (certificateTypes != null) { for (int i = 0; i < certificateTypes.length; ++i) { if (certificateTypes[i] == ClientCertificateType.rsa_sign) { // TODO Create a distinct client certificate for use here return TlsTestUtils.loadSignerCredentials(context, new String[]{"x509-server.pem", "x509-ca.pem"}, "x509-server-key.pem"); } } } return null; } }; } } static class MyTlsServer extends DefaultTlsServer { public void notifyAlertRaised(short alertLevel, short alertDescription, String message, Exception cause) { PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out; out.println("TLS server raised alert (AlertLevel." + alertLevel + ", AlertDescription." + alertDescription + ")"); if (message != null) { out.println(message); } if (cause != null) { cause.printStackTrace(out); } } public void notifyAlertReceived(short alertLevel, short alertDescription) { PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out; out.println("TLS server received alert (AlertLevel." + alertLevel + ", AlertDescription." + alertDescription + ")"); } public CertificateRequest getCertificateRequest() { return new CertificateRequest(new short[]{ ClientCertificateType.rsa_sign }, null, null); } public void notifyClientCertificate(org.bouncycastle.crypto.tls.Certificate clientCertificate) throws IOException { Certificate[] chain = clientCertificate.getCertificateList(); System.out.println("Received client certificate chain of length " + chain.length); for (int i = 0; i != chain.length; i++) { Certificate entry = chain[i]; // TODO Create fingerprint based on certificate signature algorithm digest System.out.println(" fingerprint:SHA-256 " + TlsTestUtils.fingerprint(entry) + " (" + entry.getSubject() + ")"); } } protected TlsEncryptionCredentials getRSAEncryptionCredentials() throws IOException { return TlsTestUtils.loadEncryptionCredentials(context, new String[]{"x509-server.pem", "x509-ca.pem"}, "x509-server-key.pem"); } protected TlsSignerCredentials getRSASignerCredentials() throws IOException { return TlsTestUtils.loadSignerCredentials(context, new String[]{"x509-server.pem", "x509-ca.pem"}, "x509-server-key.pem"); } } }