package org.bouncycastle.tls.test; import java.io.OutputStream; import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.security.SecureRandom; import junit.framework.TestCase; import org.bouncycastle.tls.TlsClientProtocol; import org.bouncycastle.tls.TlsServerProtocol; import org.bouncycastle.util.Arrays; import org.bouncycastle.util.io.Streams; public class TlsPSKProtocolTest extends TestCase { public void testClientServer() throws Exception { SecureRandom secureRandom = new SecureRandom(); PipedInputStream clientRead = new PipedInputStream(); PipedInputStream serverRead = new PipedInputStream(); PipedOutputStream clientWrite = new PipedOutputStream(serverRead); PipedOutputStream serverWrite = new PipedOutputStream(clientRead); TlsClientProtocol clientProtocol = new TlsClientProtocol(clientRead, clientWrite); TlsServerProtocol serverProtocol = new TlsServerProtocol(serverRead, serverWrite); ServerThread serverThread = new ServerThread(serverProtocol); serverThread.start(); MockPSKTlsClient client = new MockPSKTlsClient(null); clientProtocol.connect(client); // NOTE: Because we write-all before we read-any, this length can't be more than the pipe capacity int length = 1000; byte[] data = new byte[length]; secureRandom.nextBytes(data); OutputStream output = clientProtocol.getOutputStream(); output.write(data); byte[] echo = new byte[data.length]; int count = Streams.readFully(clientProtocol.getInputStream(), echo); assertEquals(count, data.length); assertTrue(Arrays.areEqual(data, echo)); output.close(); serverThread.join(); } static class ServerThread extends Thread { private final TlsServerProtocol serverProtocol; ServerThread(TlsServerProtocol serverProtocol) { this.serverProtocol = serverProtocol; } public void run() { try { MockPSKTlsServer server = new MockPSKTlsServer(); serverProtocol.accept(server); Streams.pipeAll(serverProtocol.getInputStream(), serverProtocol.getOutputStream()); serverProtocol.close(); } catch (Exception e) { // throw new RuntimeException(e); } } } }