package org.threadly.litesockets.utils; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.io.FileInputStream; import java.io.IOException; import java.net.Socket; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; import java.security.KeyStore; import java.util.Arrays; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.TrustManager; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.threadly.concurrent.future.FutureUtils; import org.threadly.concurrent.future.ListenableFuture; import org.threadly.litesockets.Client; import org.threadly.litesockets.NoThreadSocketExecuter; import org.threadly.litesockets.SocketExecuter; import org.threadly.litesockets.WireProtocol; import org.threadly.litesockets.buffers.MergedByteBuffers; import org.threadly.litesockets.buffers.ReuseableMergedByteBuffers; public class SSLProcessorTests { static final String STRING = "hello"; static final ByteBuffer STRINGBB = ByteBuffer.wrap(STRING.getBytes()); static final String LARGE_STRING; static final ByteBuffer LARGE_STRINGBB; static final String[] SIMPLE_ENCRYPT = new String[] {"SSL_DH_anon_WITH_DES_CBC_SHA"}; static final String[] OTHER_ENCRYPT = new String[] {"SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA"}; static { StringBuilder sb = new StringBuilder(); for(int i=0; i<100; i++) { sb.append(STRING); } LARGE_STRING = sb.toString(); LARGE_STRINGBB = ByteBuffer.wrap(LARGE_STRING.getBytes()); } SocketExecuter SE; TrustManager[] myTMs = new TrustManager [] {new SSLUtils.FullTrustManager() }; KeyStore KS; KeyManagerFactory kmf; SSLContext sslCtx; @Before public void start() throws Exception { SE = new NoThreadSocketExecuter(); SE.start(); KS = KeyStore.getInstance(KeyStore.getDefaultType()); System.out.println(ClassLoader.getSystemClassLoader().getResource("keystore.jks")); String filename = ClassLoader.getSystemClassLoader().getResource("keystore.jks").getFile(); FileInputStream ksf = new FileInputStream(filename); KS.load(ksf, "password".toCharArray()); kmf = KeyManagerFactory.getInstance("SunX509"); kmf.init(KS, "password".toCharArray()); sslCtx = SSLContext.getInstance("SSL"); sslCtx.init(kmf.getKeyManagers(), myTMs, null); System.out.println(Arrays.toString(sslCtx.createSSLEngine().getSupportedCipherSuites())); } @After public void stop() { System.gc(); System.out.println("Used Memory:" + (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / (1024*1024)); } @Test public void notEncrypted() { FakeClient fc = new FakeClient(SE); SSLProcessor sp = new SSLProcessor(fc, sslCtx.createSSLEngine()); MergedByteBuffers mbb = sp.encrypt(STRINGBB.duplicate()); assertEquals(STRING, mbb.duplicate().getAsString(mbb.remaining())); MergedByteBuffers mbb2 = sp.decrypt(mbb); assertEquals(STRING, mbb2.duplicate().getAsString(mbb2.remaining())); } @Test public void encrypted() throws IOException { FakeClient fc = new FakeClient(SE); SSLEngine see = sslCtx.createSSLEngine(); see.setEnabledCipherSuites(SIMPLE_ENCRYPT); see.setUseClientMode(true); SSLProcessor sp = new SSLProcessor(fc, see); fc.setSSLProcessor(sp); FakeClient fc2 = new FakeClient(SE); SSLEngine see2 = sslCtx.createSSLEngine(); see2.setEnabledCipherSuites(SIMPLE_ENCRYPT); see2.setUseClientMode(false); SSLProcessor sp2 = new SSLProcessor(fc2, see2); fc2.setSSLProcessor(sp2); assertFalse(sp2.isEncrypted()); assertFalse(sp.isEncrypted()); assertFalse(sp.handShakeStarted()); sp.doHandShake(); assertTrue(sp.handShakeStarted()); assertFalse(sp2.handShakeStarted()); sp2.doHandShake(); assertTrue(sp2.handShakeStarted()); assertFalse(sp2.isEncrypted()); assertFalse(sp.isEncrypted()); while(true) { if(fc.canWrite()) { ByteBuffer bb = fc.getWriteBuffer(); fc2.addReadBuffer(bb); sp2.decrypt(fc2.getRead()); } else if(fc2.canWrite()) { ByteBuffer bb = fc2.getWriteBuffer(); fc.addReadBuffer(bb); sp.decrypt(fc.getRead()); } else { break; } } assertTrue(sp2.isEncrypted()); assertTrue(sp.isEncrypted()); MergedByteBuffers mbb = sp.encrypt(STRINGBB.duplicate()); byte[] ba = new byte[mbb.remaining()]; mbb.duplicate().get(ba); assertFalse(Arrays.equals(STRINGBB.array(), ba)); MergedByteBuffers dmbb = sp2.decrypt(mbb); assertEquals(STRING, dmbb.getAsString(dmbb.remaining())); } @Test public void noCommonCipher() throws IOException { FakeClient fc = new FakeClient(SE); SSLEngine see = sslCtx.createSSLEngine(); see.setEnabledCipherSuites(SIMPLE_ENCRYPT); see.setUseClientMode(true); SSLProcessor sp = new SSLProcessor(fc, see); fc.setSSLProcessor(sp); FakeClient fc2 = new FakeClient(SE); SSLEngine see2 = sslCtx.createSSLEngine(); see2.setEnabledCipherSuites(OTHER_ENCRYPT); see2.setUseClientMode(false); SSLProcessor sp2 = new SSLProcessor(fc2, see2); fc2.setSSLProcessor(sp2); assertFalse(sp2.isEncrypted()); assertFalse(sp.isEncrypted()); assertFalse(sp.handShakeStarted()); sp.doHandShake(); assertTrue(sp.handShakeStarted()); assertFalse(sp2.handShakeStarted()); sp2.doHandShake(); assertTrue(sp2.handShakeStarted()); assertFalse(sp2.isEncrypted()); assertFalse(sp.isEncrypted()); while(true) { if(fc.canWrite()) { ByteBuffer bb = fc.getWriteBuffer(); fc2.addReadBuffer(bb); sp2.decrypt(fc2.getRead()); } else if(fc2.canWrite()) { ByteBuffer bb = fc2.getWriteBuffer(); fc.addReadBuffer(bb); sp.decrypt(fc.getRead()); } else { break; } } assertFalse(sp2.isEncrypted()); assertFalse(sp.isEncrypted()); } @Test public void largeEncrypted() throws IOException { FakeClient fc = new FakeClient(SE); SSLEngine see = sslCtx.createSSLEngine(); see.setEnabledCipherSuites(SIMPLE_ENCRYPT); see.setUseClientMode(true); SSLProcessor sp = new SSLProcessor(fc, see); fc.setSSLProcessor(sp); FakeClient fc2 = new FakeClient(SE); SSLEngine see2 = sslCtx.createSSLEngine(); see2.setEnabledCipherSuites(SIMPLE_ENCRYPT); see2.setUseClientMode(false); SSLProcessor sp2 = new SSLProcessor(fc2, see2); fc2.setSSLProcessor(sp2); sp.doHandShake(); sp2.doHandShake(); while(true) { if(fc.canWrite()) { ByteBuffer bb = fc.getWriteBuffer(); fc2.addReadBuffer(bb); sp2.decrypt(fc2.getRead()); } else if(fc2.canWrite()) { ByteBuffer bb = fc2.getWriteBuffer(); fc.addReadBuffer(bb); sp.decrypt(fc.getRead()); } else { break; } } MergedByteBuffers mbb = sp.encrypt(LARGE_STRINGBB.duplicate()); byte[] ba = new byte[mbb.remaining()]; mbb.duplicate().get(ba); assertFalse(Arrays.equals(LARGE_STRINGBB.array(), ba)); MergedByteBuffers dmbb = new ReuseableMergedByteBuffers(); while(mbb.remaining() > 0) { MergedByteBuffers tmpmbb = sp2.decrypt(mbb.pullBuffer(1)); dmbb.add(tmpmbb); } assertEquals(LARGE_STRING, dmbb.getAsString(dmbb.remaining())); } public static class FakeClient extends Client { MergedByteBuffers writeBuffers = new ReuseableMergedByteBuffers(false); SSLProcessor sp; public FakeClient(SocketExecuter se) { super(se); } public void setSSLProcessor(SSLProcessor sp) { this.sp = sp; } @Override public void addReadBuffer(ByteBuffer bb) { super.addReadBuffer(bb); } @Override public boolean canWrite() { return writeBuffers.remaining() > 0; } @Override public boolean hasConnectionTimedOut() { // TODO Auto-generated method stub return false; } @Deprecated @Override public boolean setSocketOption(SocketOption so, int value) { // TODO Auto-generated method stub return false; } @Override public ListenableFuture<Boolean> connect() { // TODO Auto-generated method stub return null; } @Override protected void setConnectionStatus(Throwable t) { // TODO Auto-generated method stub } @Override public void setConnectionTimeout(int timeout) { // TODO Auto-generated method stub } @Override public int getTimeout() { // TODO Auto-generated method stub return 0; } @Override public int getWriteBufferSize() { // TODO Auto-generated method stub return 0; } @Override protected ByteBuffer getWriteBuffer() { return writeBuffers.popBuffer(); } @Override protected void reduceWrite(int size) { // TODO Auto-generated method stub } @Override protected SocketChannel getChannel() { // TODO Auto-generated method stub return null; } @Override public WireProtocol getProtocol() { // TODO Auto-generated method stub return null; } @Override protected Socket getSocket() { // TODO Auto-generated method stub return null; } @Override public void close() { // TODO Auto-generated method stub } @Override public SocketAddress getRemoteSocketAddress() { // TODO Auto-generated method stub return null; } @Override public SocketAddress getLocalSocketAddress() { // TODO Auto-generated method stub return null; } @Override public ListenableFuture<?> write(final ByteBuffer bb) { writeBuffers.add(sp.encrypt(bb)); return FutureUtils.immediateResultFuture(true); } @Override public ClientOptions clientOptions() { // TODO Auto-generated method stub return null; } } }