package org.bouncycastle.tls.test; import java.io.OutputStream; import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.security.SecureRandom; import junit.framework.TestCase; import org.bouncycastle.tls.ProtocolVersion; import org.bouncycastle.util.Arrays; import org.bouncycastle.util.io.Streams; public class TlsTestCase extends TestCase { private static void checkTLSVersion(ProtocolVersion version) { if (version != null && !version.isTLS()) { throw new IllegalStateException("Non-TLS version"); } } protected final TlsTestConfig config; public TlsTestCase(String name) { super(name); this.config = null; } public TlsTestCase(TlsTestConfig config, String name) { super(name); checkTLSVersion(config.clientMinimumVersion); checkTLSVersion(config.clientOfferVersion); checkTLSVersion(config.serverMaximumVersion); checkTLSVersion(config.serverMinimumVersion); this.config = config; } public void testDummy() { // Avoid "No tests found" warning from junit } protected void runTest() throws Throwable { // Disable the test if it is not being run via TlsTestSuite if (config == null) { return; } SecureRandom secureRandom = new SecureRandom(); PipedInputStream clientRead = new PipedInputStream(); PipedInputStream serverRead = new PipedInputStream(); PipedOutputStream clientWrite = new PipedOutputStream(serverRead); PipedOutputStream serverWrite = new PipedOutputStream(clientRead); NetworkInputStream clientNetIn = new NetworkInputStream(clientRead); NetworkInputStream serverNetIn = new NetworkInputStream(serverRead); NetworkOutputStream clientNetOut = new NetworkOutputStream(clientWrite); NetworkOutputStream serverNetOut = new NetworkOutputStream(serverWrite); TlsTestClientProtocol clientProtocol = new TlsTestClientProtocol(clientNetIn, clientNetOut, config); TlsTestServerProtocol serverProtocol = new TlsTestServerProtocol(serverNetIn, serverNetOut, config); TlsTestClientImpl clientImpl = new TlsTestClientImpl(config); TlsTestServerImpl serverImpl = new TlsTestServerImpl(config); ServerThread serverThread = new ServerThread(serverProtocol, serverImpl); serverThread.start(); Exception caught = null; try { clientProtocol.connect(clientImpl); // NOTE: Because we write-all before we read-any, this length can't be more than the pipe capacity int length = 1000; byte[] data = new byte[length]; secureRandom.nextBytes(data); OutputStream output = clientProtocol.getOutputStream(); output.write(data); byte[] echo = new byte[data.length]; int count = Streams.readFully(clientProtocol.getInputStream(), echo); assertEquals(count, data.length); assertTrue(Arrays.areEqual(data, echo)); assertNotNull(clientImpl.tlsUnique); assertNotNull(serverImpl.tlsUnique); assertTrue(Arrays.areEqual(clientImpl.tlsUnique, serverImpl.tlsUnique)); output.close(); } catch (Exception e) { caught = e; logException(caught); } serverThread.allowExit(); serverThread.join(); assertTrue("Client InputStream not closed", clientNetIn.isClosed()); assertTrue("Client OutputStream not closed", clientNetOut.isClosed()); assertTrue("Server InputStream not closed", serverNetIn.isClosed()); assertTrue("Server OutputStream not closed", serverNetOut.isClosed()); assertEquals("Client fatal alert connection end", config.expectFatalAlertConnectionEnd, clientImpl.firstFatalAlertConnectionEnd); assertEquals("Server fatal alert connection end", config.expectFatalAlertConnectionEnd, serverImpl.firstFatalAlertConnectionEnd); assertEquals("Client fatal alert description", config.expectFatalAlertDescription, clientImpl.firstFatalAlertDescription); assertEquals("Server fatal alert description", config.expectFatalAlertDescription, serverImpl.firstFatalAlertDescription); if (config.expectFatalAlertConnectionEnd == -1) { assertNull("Unexpected client exception", caught); assertNull("Unexpected server exception", serverThread.caught); } } protected void logException(Exception e) { if (TlsTestConfig.DEBUG) { e.printStackTrace(); } } class ServerThread extends Thread { protected final TlsTestServerProtocol serverProtocol; protected final TlsTestServerImpl serverImpl; boolean canExit = false; Exception caught = null; ServerThread(TlsTestServerProtocol serverProtocol, TlsTestServerImpl serverImpl) { this.serverProtocol = serverProtocol; this.serverImpl = serverImpl; } synchronized void allowExit() { canExit = true; this.notifyAll(); } public void run() { try { serverProtocol.accept(serverImpl); Streams.pipeAll(serverProtocol.getInputStream(), serverProtocol.getOutputStream()); serverProtocol.close(); } catch (Exception e) { caught = e; logException(caught); } waitExit(); } protected synchronized void waitExit() { while (!canExit) { try { this.wait(); } catch (InterruptedException e) { } } } } }