package org.webpieces.ssl.api;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import javax.net.ssl.SSLEngine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.webpieces.data.api.BufferCreationPool;
import org.webpieces.data.api.BufferPool;
/**
 * Exercises {@link AsyncSSLEngine} by wiring a client engine and a server engine back-to-back
 * through {@link MockSslListener}s. Encrypted packets produced by one side are fed manually to
 * the other, in varying orders and chunkings.
 *
 * <p>{@link #setup()} drives the handshake part-way. Each test then finishes it under a
 * different packet-delivery scenario.
 */
public class TestSSLEngine2 {

    private AsyncSSLEngine clientEngine;
    private AsyncSSLEngine svrEngine;
    // Listeners capture what each engine hands back (outbound packets, tasks to run, etc.).
    // They are never reassigned; JUnit 4 builds a fresh test instance for every test method.
    private final MockSslListener clientListener = new MockSslListener();
    private final MockSslListener svrListener = new MockSslListener();

    /**
     * Creates the client and server engines and advances the handshake.
     *
     * <p>It stops once the client has run the task its listener received. The client's next
     * outbound packets are then pending in {@code clientListener} for each test to deliver.
     * The tests below index up to three of them.
     */
    @Before
    public void setup() throws GeneralSecurityException, IOException {
        MockSSLEngineFactory sslEngineFactory = new MockSSLEngineFactory();
        BufferPool pool = new BufferCreationPool(false, 17000, 1000);
        SSLEngine client = sslEngineFactory.createEngineForSocket();
        SSLEngine svr = sslEngineFactory.createEngineForServerSocket();

        clientEngine = AsyncSSLFactory.create("client", client, pool, clientListener);
        svrEngine = AsyncSSLFactory.create("svr", svr, pool, svrListener);
        Assert.assertEquals(ConnectionState.NOT_STARTED, clientEngine.getConnectionState());
        Assert.assertEquals(ConnectionState.NOT_STARTED, svrEngine.getConnectionState());

        // The client opens the handshake; its first packet is delivered to the server.
        clientEngine.beginHandshake();
        Assert.assertEquals(ConnectionState.CONNECTING, clientEngine.getConnectionState());
        ByteBuffer buffer = clientListener.getToSendToSocket().get(0);
        svrEngine.feedEncryptedPacket(buffer);
        Assert.assertEquals(ConnectionState.CONNECTING, svrEngine.getConnectionState());

        // The server handed a task to its listener. Run it before reading the server's reply.
        Runnable r = svrListener.getRunnable();
        r.run();
        Assert.assertEquals(ConnectionState.CONNECTING, svrEngine.getConnectionState());

        // Deliver the server's first reply to the client, then run the client's task.
        ByteBuffer buf = svrListener.getToSendToSocket().get(0);
        clientEngine.feedEncryptedPacket(buf);
        Assert.assertEquals(ConnectionState.CONNECTING, clientEngine.getConnectionState());
        clientListener.getRunnable().run();
    }

    /**
     * Completes the handshake one packet at a time, checking state transitions on both sides,
     * and then transfers a large plaintext buffer from client to server.
     */
    @Test
    public void testBasic() throws GeneralSecurityException, IOException {
        List<ByteBuffer> buffers = clientListener.getToSendToSocket();
        svrEngine.feedEncryptedPacket(buffers.get(0));
        svrListener.getRunnable().run();
        svrEngine.feedEncryptedPacket(buffers.get(1));
        // The server is the engine receiving packets here. It must not report CONNECTED until
        // the last pending client packet (index 2) has been delivered.
        Assert.assertEquals(ConnectionState.CONNECTING, svrEngine.getConnectionState());
        svrEngine.feedEncryptedPacket(buffers.get(2));
        Assert.assertEquals(ConnectionState.CONNECTED, svrEngine.getConnectionState());
        Assert.assertTrue(svrListener.connected);

        // Deliver the server's two replies. The client connects only after the second one.
        List<ByteBuffer> toClientBuffers = svrListener.getToSendToSocket();
        clientEngine.feedEncryptedPacket(toClientBuffers.get(0));
        Assert.assertEquals(ConnectionState.CONNECTING, clientEngine.getConnectionState());
        clientEngine.feedEncryptedPacket(toClientBuffers.get(1));
        Assert.assertEquals(ConnectionState.CONNECTED, clientEngine.getConnectionState());
        Assert.assertTrue(clientListener.connected);

        transferBigData();
    }

    /**
     * Sends a 17000-byte plaintext buffer from client to server and checks three things:
     * <ul>
     *   <li>the buffer is encrypted into two packets;</li>
     *   <li>the server surfaces all 17000 bytes across the two packets;</li>
     *   <li>the client's write future completes only after both per-packet futures do.</li>
     * </ul>
     */
    private void transferBigData() {
        ByteBuffer b = ByteBuffer.allocate(17000);
        b.put((byte) 1);
        b.put((byte) 2);
        b.position(b.limit() - 2); // simulate buffer full of 0's except first 2 and last 2
        b.put((byte) 3);
        b.put((byte) 4);
        b.flip();

        CompletableFuture<Void> future = clientEngine.feedPlainPacket(b);
        List<ByteBuffer> encrypted = clientListener.getToSendToSocket();
        // Results in two ssl packets instead of the one that was fed in. Presumably this is
        // because 17000 bytes exceeds the 16 KiB TLS record plaintext limit.
        Assert.assertEquals(2, encrypted.size());

        svrEngine.feedEncryptedPacket(encrypted.get(0));
        // NOTE(review): index 0 is read after each packet. This assumes getToSendToClient()
        // yields only the newly decrypted data each time. Confirm against MockSslListener.
        ByteBuffer buffer = svrListener.getToSendToClient().get(0);
        svrEngine.feedEncryptedPacket(encrypted.get(1));
        ByteBuffer buffer2 = svrListener.getToSendToClient().get(0);
        Assert.assertEquals(17000, buffer.remaining() + buffer2.remaining());

        // The plaintext write stays pending until every encrypted packet's future is completed.
        Assert.assertFalse(future.isDone());
        List<CompletableFuture<Void>> futures = clientListener.getFutures();
        futures.get(0).complete(null);
        Assert.assertFalse(future.isDone());
        futures.get(1).complete(null);
        Assert.assertTrue(future.isDone());
    }

    /**
     * Delivers all pending client handshake packets to the server as one concatenated buffer.
     * The server should connect and produce two reply packets.
     */
    @Test
    public void testCombineBuffers() {
        List<ByteBuffer> buffers = clientListener.getToSendToSocket();
        ByteBuffer combine = combine(buffers);
        svrEngine.feedEncryptedPacket(combine);
        svrListener.getRunnable().run();
        Assert.assertEquals(ConnectionState.CONNECTED, svrEngine.getConnectionState());
        Assert.assertTrue(svrListener.connected);
        List<ByteBuffer> toClientBuffers = svrListener.getToSendToSocket();
        Assert.assertEquals(2, toClientBuffers.size());
    }

    /**
     * Concatenates the remaining bytes of every buffer, in list order, into one new buffer.
     *
     * @param buffersToSend buffers to join; each is read to its limit and so is consumed
     * @return a new heap buffer, flipped and ready for reading
     */
    private ByteBuffer combine(List<ByteBuffer> buffersToSend) {
        int size = 0;
        for (ByteBuffer b : buffersToSend) {
            size += b.remaining();
        }
        ByteBuffer buf = ByteBuffer.allocate(size);
        for (ByteBuffer b : buffersToSend) {
            buf.put(b);
        }
        buf.flip();
        return buf;
    }

    /**
     * Takes the server task produced by the first client packet and delays it.
     *
     * <p>The remaining two client packets are fed first, and only then is the task run. The
     * server should still connect and produce two reply packets.
     */
    @Test
    public void testRunnableRunAfterNextPacket() {
        List<ByteBuffer> buffers = clientListener.getToSendToSocket();
        svrEngine.feedEncryptedPacket(buffers.get(0));
        Runnable run = svrListener.getRunnable();
        svrEngine.feedEncryptedPacket(buffers.get(1));
        svrEngine.feedEncryptedPacket(buffers.get(2));
        run.run();
        Assert.assertEquals(ConnectionState.CONNECTED, svrEngine.getConnectionState());
        Assert.assertTrue(svrListener.connected);
        List<ByteBuffer> toClientBuffers = svrListener.getToSendToSocket();
        Assert.assertEquals(2, toClientBuffers.size());
    }

    /**
     * Feeds the first two client packets in uneven chunks, running the server task immediately.
     *
     * <p>Chunks, in order:
     * <ol>
     *   <li>the first half of packet 0;</li>
     *   <li>one buffer holding the second half of packet 0 plus the first half of packet 1;</li>
     *   <li>the rest of packet 1.</li>
     * </ol>
     * The server task is run right after the combined buffer.
     */
    @Test
    public void testHalfThenTooMuchFedInPacket() {
        List<ByteBuffer> buffers = clientListener.getToSendToSocket();
        List<ByteBuffer> first = split(buffers.get(0));
        List<ByteBuffer> second = split(buffers.get(1));
        ByteBuffer halfAndHalf = combine(first.get(1), second.get(0));

        svrEngine.feedEncryptedPacket(first.get(0));
        svrEngine.feedEncryptedPacket(halfAndHalf);
        Runnable run = svrListener.getRunnable();
        run.run();
        svrEngine.feedEncryptedPacket(second.get(1));
        svrEngine.feedEncryptedPacket(buffers.get(2));
        Assert.assertEquals(ConnectionState.CONNECTED, svrEngine.getConnectionState());
        Assert.assertTrue(svrListener.connected);
    }

    /**
     * Same chunking as {@link #testHalfThenTooMuchFedInPacket()}, but with a delayed task.
     *
     * <p>The server task is run only after the remainder of packet 1 has also been fed.
     */
    @Test
    public void testHalfThenTooMuchFedInPacketAndRunnableDelayed() {
        List<ByteBuffer> buffers = clientListener.getToSendToSocket();
        List<ByteBuffer> first = split(buffers.get(0));
        List<ByteBuffer> second = split(buffers.get(1));
        ByteBuffer halfAndHalf = combine(first.get(1), second.get(0));

        svrEngine.feedEncryptedPacket(first.get(0));
        svrEngine.feedEncryptedPacket(halfAndHalf);
        Runnable run = svrListener.getRunnable();
        svrEngine.feedEncryptedPacket(second.get(1));
        run.run();
        svrEngine.feedEncryptedPacket(buffers.get(2));
        Assert.assertEquals(ConnectionState.CONNECTED, svrEngine.getConnectionState());
        Assert.assertTrue(svrListener.connected);
    }

    /**
     * Copies the remaining bytes of a buffer into two new heap buffers.
     *
     * <p>The first holds {@code remaining() / 2} bytes and the second holds the rest.
     *
     * @param byteBuffer source buffer; it is read to its limit and so is consumed
     * @return a two-element list: the first half, then the second half
     */
    private List<ByteBuffer> split(ByteBuffer byteBuffer) {
        int splitPoint = byteBuffer.remaining() / 2;
        byte[] one = new byte[splitPoint];
        byte[] two = new byte[byteBuffer.remaining() - splitPoint];
        byteBuffer.get(one);
        byteBuffer.get(two);
        if (byteBuffer.hasRemaining()) {
            throw new RuntimeException("bug, shoudl have consumed it all");
        }

        ByteBuffer buf1 = ByteBuffer.wrap(one);
        ByteBuffer buf2 = ByteBuffer.wrap(two);
        List<ByteBuffer> list = new ArrayList<>();
        list.add(buf1);
        list.add(buf2);
        return list;
    }

    /**
     * Concatenates the remaining bytes of two buffers into one new buffer.
     *
     * @param byteBuffer first part; consumed
     * @param byteBuffer2 second part; consumed
     * @return a new heap buffer, flipped and ready for reading
     */
    private ByteBuffer combine(ByteBuffer byteBuffer, ByteBuffer byteBuffer2) {
        ByteBuffer newBuf = ByteBuffer.allocate(byteBuffer.remaining() + byteBuffer2.remaining());
        newBuf.put(byteBuffer);
        newBuf.put(byteBuffer2);
        newBuf.flip();
        return newBuf;
    }
}