package org.threadly.litesockets.tcp;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.threadly.concurrent.PriorityScheduler;
import org.threadly.litesockets.Client;
import org.threadly.litesockets.Client.Reader;
import org.threadly.litesockets.Server;
import org.threadly.litesockets.Server.ClientAcceptor;
import org.threadly.litesockets.SocketExecuter;
import org.threadly.litesockets.TCPClient;
import org.threadly.litesockets.TCPServer;
import org.threadly.litesockets.ThreadedSocketExecuter;
import org.threadly.litesockets.buffers.MergedByteBuffers;
import org.threadly.litesockets.buffers.ReuseableMergedByteBuffers;
import org.threadly.litesockets.utils.PortUtils;
import org.threadly.litesockets.utils.SSLUtils;
import org.threadly.test.concurrent.TestCondition;
public class SSLTests {
PriorityScheduler PS;
int port;
final String GET = "hello";
SocketExecuter SE;
TrustManager[] myTMs = new TrustManager [] {new SSLUtils.FullTrustManager() };
KeyStore KS;
KeyManagerFactory kmf;
SSLContext sslCtx;
FakeTCPServerClient serverFC;
@Before
public void start() throws Exception {
PS = new PriorityScheduler(5);
SE = new ThreadedSocketExecuter(PS);
SE.start();
port = PortUtils.findTCPPort();
KS = KeyStore.getInstance(KeyStore.getDefaultType());
File filename = new File(ClassLoader.getSystemClassLoader().getResource("test.pem").getFile());
kmf = SSLUtils.generateKeyStoreFromPEM(filename, filename);
sslCtx = SSLContext.getInstance("SSL");
sslCtx.init(kmf.getKeyManagers(), myTMs, null);
serverFC = new FakeTCPServerClient();
}
@After
public void stop() {
for(Server s: serverFC.getAllServers()) {
s.close();
}
for(Client c: serverFC.getAllClients()) {
c.close();
}
SE.stop();
PS.shutdownNow();
serverFC = new FakeTCPServerClient();
System.gc();
System.out.println("Used Memory:"
+ (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / (1024*1024));
}
@Test(expected=IllegalStateException.class)
public void badSSLStart() throws Exception {
TCPServer server = SE.createTCPServer("localhost", port);
server.setSSLContext(sslCtx);
server.setDoHandshake(true);
serverFC.addTCPServer(server);
final TCPClient client = SE.createTCPClient("localhost", port);
client.startSSL();
}
@Test
public void simpleWriteTest() throws Exception {
long start = System.currentTimeMillis();
TCPServer server = SE.createTCPServer("localhost", port);
server.setSSLContext(sslCtx);
server.setDoHandshake(true);
serverFC.addTCPServer(server);
final TCPClient client = SE.createTCPClient("localhost", port);
SSLEngine sslec = sslCtx.createSSLEngine("localhost", port);
sslec.setUseClientMode(true);
client.setSSLEngine(sslec);
serverFC.addTCPClient(client);
client.connect().get(5000, TimeUnit.MILLISECONDS);
client.startSSL().get(5000, TimeUnit.MILLISECONDS);;
System.out.println(System.currentTimeMillis()-start);
new TestCondition(){
@Override
public boolean get() {
return serverFC.getNumberOfClients() == 2;
}
}.blockTillTrue(5000);
final TCPClient sclient = serverFC.getClientAt(1);
new TestCondition(){
@Override
public boolean get() {
return client.isEncrypted();
}
}.blockTillTrue(5000);
assertTrue(client.isEncrypted());
assertTrue(sclient.isEncrypted());
System.out.println("Write");
sclient.write(TCPTests.SMALL_TEXT_BUFFER.duplicate());
System.out.println("Write Done");
new TestCondition(){
@Override
public boolean get() {
return serverFC.getClientsBuffer(client).remaining() > 2;
}
}.blockTillTrue(5000);
String st = serverFC.getClientsBuffer(client).getAsString(serverFC.getClientsBuffer(client).remaining());
assertEquals(TCPTests.SMALL_TEXT, st);
}
@Test
public void sslClientTimeout() throws IOException, InterruptedException, ExecutionException, TimeoutException {
TCPServer server = SE.createTCPServer("localhost", port);
serverFC.addTCPServer(server);
long start = System.currentTimeMillis();
try {
final TCPClient client = SE.createTCPClient("localhost", port);
SSLEngine ssle = sslCtx.createSSLEngine("localhost", port);
ssle.setUseClientMode(true);
client.setSSLEngine(ssle);
client.setConnectionTimeout(201);
client.connect();
client.startSSL().get(5000, TimeUnit.MILLISECONDS);
fail();
} catch(CancellationException e) {
assertTrue(System.currentTimeMillis()-start >= 200);
}
server.close();
}
@Test
public void largeWriteTest() throws Exception{
TCPServer server = SE.createTCPServer("localhost", port);
server.setSSLContext(sslCtx);
server.setSSLHostName("localhost");
server.setDoHandshake(true);
serverFC.addTCPServer(server);
final TCPClient client = SE.createTCPClient("localhost", port);
SSLEngine sslec = sslCtx.createSSLEngine("localhost", port);
sslec.setUseClientMode(true);
client.setSSLEngine(sslec);
serverFC.addTCPClient(client);
client.startSSL();
new TestCondition(){
@Override
public boolean get() {
return serverFC.getNumberOfClients() == 2 && client.isEncrypted();
}
}.blockTillTrue(5000);
final TCPClient sclient = serverFC.getClientAt(1);
serverFC.addTCPClient(client);
for(int i=0; i<3; i++) {
sclient.write(TCPTests.LARGE_TEXT_BUFFER.duplicate());
}
new TestCondition(){
@Override
public boolean get() {
/*
System.out.println(serverFC.map.get(client).remaining()+":"+(TCPTests.LARGE_TEXT_BUFFER.remaining()*3));
System.out.println("w:"+sclient.finishedHandshake.get()+":"+sclient.startedHandshake.get());
System.out.println("w:"+sclient.ssle.getHandshakeStatus());
System.out.println("r:"+client.ssle.getHandshakeStatus());
System.out.println("r:"+client.getReadBufferSize());
*/
if(serverFC.getClientsBuffer(client) != null) {
return serverFC.getClientsBuffer(client).remaining() == TCPTests.LARGE_TEXT_BUFFER.remaining()*3;
}
return false;
}
}.blockTillTrue(5000);
String st = serverFC.getClientsBuffer(client).getAsString(TCPTests.LARGE_TEXT_BUFFER.remaining());
assertEquals(TCPTests.LARGE_TEXT, st);
st = serverFC.getClientsBuffer(client).getAsString(TCPTests.LARGE_TEXT_BUFFER.remaining());
assertEquals(TCPTests.LARGE_TEXT, st);
st = serverFC.getClientsBuffer(client).getAsString(TCPTests.LARGE_TEXT_BUFFER.remaining());
assertEquals(TCPTests.LARGE_TEXT, st);
}
// @Test(expected=IllegalStateException.class)
// public void useTCPClientPendingReads() throws IOException {
// TCPServer server = SE.createTCPServer("localhost", port);
// serverFC.addTCPServer(server);
//
// final TCPClient tcp_client = SE.createTCPClient("localhost", port);
// //serverFC.addTCPClient(tcp_client);
// SE.addClient(tcp_client);
// tcp_client.setReader(new Reader() {
// @Override
// public void onRead(Client client) {
// System.out.println("GOT READ");
// //We do nothing here
// }});
//
// new TestCondition(){
// @Override
// public boolean get() {
// return serverFC.clients.size() == 1;
// }
// }.blockTillTrue(5000);
// TCPClient sclient = (TCPClient) serverFC.clients.get(0);
//
// sclient.write(TCPTests.SMALL_TEXT_BUFFER.duplicate());
//
// new TestCondition(){
// @Override
// public boolean get() {
// return tcp_client.getReadBufferSize() > 0;
// }
// }.blockTillTrue(5000);
//
// final SSLClient client = new SSLClient(tcp_client, this.sslCtx.createSSLEngine("localhost", port), true, true);
// client.close();
// }
// @Test
// public void loop() throws Exception {
// for(int i=0; i<100; i++) {
// this.doLateSSLhandshake();
// stop();
// start();
// }
// }
@Test
public void doLateSSLhandshake() throws IOException, InterruptedException, ExecutionException, TimeoutException {
TCPServer server = SE.createTCPServer("localhost", port);
server.setSSLContext(sslCtx);
server.setSSLHostName("localhost");
server.setDoHandshake(false);
final AtomicReference<TCPClient> servers_client = new AtomicReference<TCPClient>();
final AtomicReference<String> serversEncryptedString = new AtomicReference<String>();
final AtomicReference<String> clientsEncryptedString = new AtomicReference<String>();
server.setClientAcceptor(new ClientAcceptor() {
@Override
public void accept(Client c) {
final TCPClient sslc = (TCPClient) c;
servers_client.set(sslc);
sslc.setReader(new Reader() {
MergedByteBuffers mbb = new ReuseableMergedByteBuffers();
boolean didSSL = false;
@Override
public void onRead(Client client) {
mbb.add(client.getRead());
if(!didSSL && mbb.remaining() >= 6) {
String tmp = mbb.getAsString(6);
if(tmp.equals("DO_SSL")) {
sslc.write(ByteBuffer.wrap("DO_SSL".getBytes()));
System.out.println("DOSSL-Server");
sslc.startSSL().addListener(new Runnable() {
@Override
public void run() {
didSSL = true;
System.out.println("DIDSSL-Server");
}});
}
} else {
if(mbb.remaining() >= 19) {
String tmp = mbb.getAsString(19);
serversEncryptedString.set(tmp);
client.write(ByteBuffer.wrap("THIS WAS ENCRYPTED!".getBytes()));
}
}
}});
//SE.addClient(sslc.getTCPClient());
}});
server.start();
final TCPClient sslclient = SE.createTCPClient("localhost", port);
SSLEngine sslec = sslCtx.createSSLEngine("localhost", port);
sslec.setUseClientMode(true);
sslclient.setSSLEngine(sslec);
sslclient.setReader(new Reader() {
MergedByteBuffers mbb = new ReuseableMergedByteBuffers();
boolean didSSL = false;
@Override
public void onRead(Client client) {
mbb.add(client.getRead());
if(!didSSL && mbb.remaining() >= 6) {
String tmp = mbb.getAsString(6);
if(tmp.equals("DO_SSL")) {
System.out.println("DOSSL");
sslclient.startSSL().addListener(new Runnable() {
@Override
public void run() {
didSSL = true;
sslclient.write(ByteBuffer.wrap("THIS WAS ENCRYPTED!".getBytes()));
System.out.println("DIDSSL");
}});
}
} else {
if(mbb.remaining() >= 19) {
String tmp = mbb.getAsString(19);
clientsEncryptedString.set(tmp);
}
}
}});
System.out.println(sslclient);
try {
sslclient.connect().get(5000, TimeUnit.MILLISECONDS);
} catch (ExecutionException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
System.out.println("WRITE!!");
try {
sslclient.write(ByteBuffer.wrap("DO_SSL".getBytes())).get(5000, TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
System.out.println("WRITE ERROR!! "+sslclient.getWriteBufferSize());
throw e;
}
System.out.println("WRITE DONE!!");
new TestCondition(){
@Override
public boolean get() {
// if(servers_client.get() != null) {
// System.out.println(servers_client.get().getReadBufferSize());
// }
return clientsEncryptedString.get() != null && serversEncryptedString.get() != null;
}
}.blockTillTrue(5000);
assertEquals(clientsEncryptedString.get(), serversEncryptedString.get());
}
}