/*
* Copyright (C) 2015 SoftIndex LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.datakernel.eventloop;
import io.datakernel.bytebuf.ByteBuf;
import io.datakernel.eventloop.AsyncTcpSocket.EventHandler;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.jmock.Expectations;
import org.jmock.integration.junit4.JUnitRuleMockery;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import javax.net.ssl.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Random;
import java.util.concurrent.Executor;
import static io.datakernel.bytebuf.ByteBufPool.*;
import static io.datakernel.eventloop.FatalErrorHandlers.rethrowOnAnyError;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*;
@Ignore
public class AsyncSslSocketTest {
// <editor-fold desc="fields">
private static final String KEYSTORE_PATH = "./src/test/resources/keystore.jks";
private static final String KEYSTORE_PASS = "testtest";
private static final String KEY_PASS = "testtest";
private static final String TRUSTSTORE_PATH = "./src/test/resources/truststore.jks";
private static final String TRUSTSTORE_PASS = "testtest";
private Eventloop eventloop = Eventloop.create().withFatalErrorHandler(rethrowOnAnyError());
private AsyncSslSocket serverSslSocket;
private AsyncSslSocket clientSslSocket;
@Rule
public JUnitRuleMockery context = new JUnitRuleMockery();
private EventHandler clientEventHandler = context.mock(EventHandler.class, "clientEventHandler");
private EventHandler serverEventHandler = context.mock(EventHandler.class, "serverEventHandler");
private AsyncTcpSocketStub clientSocketStub;
private AsyncTcpSocketStub serverSocketStub;
// </editor-fold>
// <editor-fold desc="initialization">
@Before
public void init() throws Exception {
Executor executor = new ExecutorStub();
KeyManager[] keyManagers = createKeyManagers(new File(KEYSTORE_PATH), KEYSTORE_PASS, KEY_PASS);
TrustManager[] trustManagers = createTrustManagers(new File(TRUSTSTORE_PATH), TRUSTSTORE_PASS);
SSLContext sslContext = createSslContext("TLSv1.2", keyManagers, trustManagers, new SecureRandom());
SSLEngine serverSSLEngine = sslContext.createSSLEngine();
serverSSLEngine.setUseClientMode(false);
serverSocketStub = new AsyncTcpSocketStub("server", eventloop);
serverSslSocket = AsyncSslSocket.create(eventloop, serverSocketStub, serverSSLEngine, executor);
serverSocketStub.setEventHandler(serverSslSocket);
serverSslSocket.setEventHandler(serverEventHandler);
SSLEngine clientSSLEngine = sslContext.createSSLEngine();
clientSSLEngine.setUseClientMode(true);
clientSocketStub = new AsyncTcpSocketStub("client", eventloop);
clientSslSocket = AsyncSslSocket.create(eventloop, clientSocketStub, clientSSLEngine, executor);
clientSocketStub.setEventHandler(clientSslSocket);
clientSslSocket.setEventHandler(clientEventHandler);
// connect client and server stub sockets directly
clientSocketStub.connect(serverSocketStub);
}
// </editor-fold>
// <editor-fold desc="tests">
@Test
public void performsSimpleMessageExchange() throws NoSuchAlgorithmException {
context.checking(new Expectations() {{
oneOf(serverEventHandler).onRead(with(bytebufOfMessage("Hello")));
oneOf(clientEventHandler).onRead(with(bytebufOfMessage("World")));
allowing(serverEventHandler).onReadEndOfStream();
allowing(serverEventHandler).onWrite();
allowing(clientEventHandler).onReadEndOfStream();
allowing(clientEventHandler).onWrite();
allowing(serverEventHandler).onRegistered();
allowing(clientEventHandler).onRegistered();
}});
eventloop.post(new Runnable() {
@Override
public void run() {
serverSslSocket.onRegistered();
clientSslSocket.onRegistered();
serverSslSocket.read();
clientSslSocket.read();
clientSslSocket.write(createByteBufFromString("Hello"));
serverSslSocket.write(createByteBufFromString("World"));
}
});
eventloop.run();
System.out.println("created: " + getCreatedItems());
System.out.println("in pool: " + getPoolItems());
assertEquals(getPoolItemsString(), getCreatedItems(), getPoolItems());
}
@Test
public void sendsLargeAmountOfDataFromClientToServer() {
final StringBuilder sentData = new StringBuilder();
EventHandlerDataAccumulator serverDataAccumulator = new EventHandlerDataAccumulator(serverSslSocket);
serverSslSocket.setEventHandler(serverDataAccumulator);
context.checking(new Expectations() {{
ignoring(clientEventHandler);
}});
eventloop.post(new Runnable() {
@Override
public void run() {
serverSslSocket.onRegistered();
clientSslSocket.onRegistered();
serverSslSocket.read();
// send large message
String largeMessage = generateLargeString(100_000);
sentData.append(largeMessage);
clientSslSocket.write(createByteBufFromString(largeMessage));
// send lots of small messages
String smallMsg = "data_012345";
for (int i = 0; i < 25_000; i++) {
sentData.append(smallMsg);
clientSslSocket.write(createByteBufFromString(smallMsg));
}
}
});
eventloop.run();
assertThat("received bytes amount", serverDataAccumulator.getAccumulatedData().length(), greaterThan(0));
assertEquals(sentData.toString(), serverDataAccumulator.getAccumulatedData());
assertEquals(getPoolItemsString(), getCreatedItems(), getPoolItems());
}
@Test
public void sendsLargeAmountOfDataFromServerToClient() {
final StringBuilder sentData = new StringBuilder();
EventHandlerDataAccumulator clientDataAccumulator = new EventHandlerDataAccumulator(clientSslSocket);
clientSslSocket.setEventHandler(clientDataAccumulator);
context.checking(new Expectations() {{
ignoring(serverEventHandler);
}});
eventloop.post(new Runnable() {
@Override
public void run() {
serverSslSocket.onRegistered();
clientSslSocket.onRegistered();
clientSslSocket.read();
// send large message
String largeMessage = generateLargeString(100_000);
sentData.append(largeMessage);
serverSslSocket.write(createByteBufFromString(largeMessage));
// send lots of small messages
String smallMsg = "data_012345";
for (int i = 0; i < 25_000; i++) {
sentData.append(smallMsg);
serverSslSocket.write(createByteBufFromString(smallMsg));
}
}
});
eventloop.run();
assertTrue(clientDataAccumulator.getAccumulatedData().length() > 0);
assertEquals(sentData.toString(), clientDataAccumulator.getAccumulatedData());
assertEquals(getPoolItemsString(), getCreatedItems(), getPoolItems());
}
@Test
public void getsSSLExceptionWhenOtherSideWasClosedWithoutSpecifiedHandshakeMessage() {
context.checking(new Expectations() {{
// check first messages
oneOf(clientEventHandler).onRead(with(bytebufOfMessage("World")));
oneOf(serverEventHandler).onRead(with(bytebufOfMessage("Hello")));
// check error
oneOf(clientEventHandler).onClosedWithError(with(any(SSLException.class)));
allowing(serverEventHandler).onWrite();
allowing(clientEventHandler).onWrite();
allowing(clientEventHandler).onRegistered();
allowing(serverEventHandler).onRegistered();
allowing(serverEventHandler).onClosedWithError(with(any(Exception.class)));
}});
eventloop.post(new Runnable() {
@Override
public void run() {
serverSslSocket.onRegistered();
clientSslSocket.onRegistered();
serverSslSocket.read();
clientSslSocket.read();
clientSslSocket.write(createByteBufFromString("Hello"));
serverSslSocket.write(createByteBufFromString("World"));
eventloop.schedule(eventloop.currentTimeMillis() + 100, new Runnable() {
@Override
public void run() {
// write endOfStream directly to client stub socket
clientSocketStub.onReadEndOfStream();
}
});
}
});
eventloop.run();
assertEquals(getPoolItemsString(), getCreatedItems(), getPoolItems());
}
@Test
public void otherSideEventHandler_ReceivesEndOfStream_InCaseOfProperClosing() {
context.checking(new Expectations() {{
// check first messages
oneOf(serverEventHandler).onRead(with(bytebufOfMessage("Hello")));
oneOf(clientEventHandler).onRead(with(bytebufOfMessage("World")));
// check error
oneOf(serverEventHandler).onReadEndOfStream();
allowing(clientEventHandler).onRegistered();
allowing(serverEventHandler).onRegistered();
allowing(clientEventHandler).onWrite();
allowing(serverEventHandler).onWrite();
}});
eventloop.post(new Runnable() {
@Override
public void run() {
serverSslSocket.onRegistered();
clientSslSocket.onRegistered();
serverSslSocket.read();
clientSslSocket.read();
clientSslSocket.write(createByteBufFromString("Hello"));
serverSslSocket.write(createByteBufFromString("World"));
eventloop.schedule(eventloop.currentTimeMillis() + 100, new Runnable() {
@Override
public void run() {
clientSslSocket.close();
}
});
}
});
eventloop.run();
assertEquals(getPoolItemsString(), getCreatedItems(), getPoolItems());
}
// </editor-fold>
// <editor-fold desc="stub classes">
public static final class AsyncTcpSocketStub implements AsyncTcpSocket {
private String desc;
private Eventloop eventloop;
private AsyncTcpSocketStub otherSide;
private EventHandler downstreamEventHandler;
private boolean writeEndOfStream = false;
public void connect(AsyncTcpSocketStub otherSide) {
this.otherSide = otherSide;
otherSide.otherSide = this;
}
public AsyncTcpSocketStub(String desc, Eventloop eventloop) {
this.desc = desc;
this.eventloop = eventloop;
}
public void onRead(ByteBuf buf) {
downstreamEventHandler.onRead(buf);
}
public void onReadEndOfStream() {
downstreamEventHandler.onReadEndOfStream();
}
@Override
public void setEventHandler(EventHandler eventHandler) {
this.downstreamEventHandler = eventHandler;
}
@Override
public void read() {
}
@Override
public void write(final ByteBuf buf) {
assert !writeEndOfStream;
if (otherSide == null) {
buf.recycle();
return;
}
final AsyncTcpSocketStub cached = otherSide;
eventloop.postLater(new Runnable() {
@Override
public void run() {
cached.onRead(buf);
}
});
downstreamEventHandler.onWrite();
}
@Override
public void writeEndOfStream() {
assert !writeEndOfStream;
final AsyncTcpSocketStub cached = otherSide;
writeEndOfStream = true;
eventloop.postLater(new Runnable() {
@Override
public void run() {
cached.onReadEndOfStream();
}
});
}
@Override
public void close() {
if (otherSide != null) {
otherSide.otherSide = null;
otherSide = null;
}
}
@Override
public InetSocketAddress getRemoteSocketAddress() {
return null;
}
@Override
public String toString() {
return "desc: " + desc;
}
}
public static final class ExecutorStub implements Executor {
@Override
public void execute(Runnable command) {
command.run();
}
}
public static final class EventHandlerDataAccumulator implements EventHandler {
StringBuilder data = new StringBuilder();
AsyncTcpSocket upstream;
public EventHandlerDataAccumulator(AsyncTcpSocket upstream) {
this.upstream = upstream;
}
@Override
public void onRegistered() {
}
@Override
public void onRead(ByteBuf buf) {
data.append(extractMessageFromByteBuf(buf));
upstream.read();
}
@Override
public void onReadEndOfStream() {
}
@Override
public void onWrite() {
}
@Override
public void onClosedWithError(Exception e) {
}
public String getAccumulatedData() {
return data.toString();
}
}
// </editor-fold>
// <editor-fold desc="helper methods">
public static TrustManager[] createTrustManagers(File path, String pass) throws Exception {
KeyStore trustStore = KeyStore.getInstance("JKS");
try (InputStream trustStoreIS = new FileInputStream(path)) {
trustStore.load(trustStoreIS, pass.toCharArray());
}
TrustManagerFactory trustFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustFactory.init(trustStore);
return trustFactory.getTrustManagers();
}
public static KeyManager[] createKeyManagers(File path, String storePass, String keyPass) throws Exception {
KeyStore store = KeyStore.getInstance("JKS");
try (InputStream is = new FileInputStream(path)) {
store.load(is, storePass.toCharArray());
}
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(store, keyPass.toCharArray());
return kmf.getKeyManagers();
}
public static SSLContext createSslContext(String algorithm, KeyManager[] keyManagers, TrustManager[] trustManagers,
SecureRandom secureRandom) throws NoSuchAlgorithmException, KeyManagementException {
SSLContext instance = SSLContext.getInstance(algorithm);
instance.init(keyManagers, trustManagers, secureRandom);
return instance;
}
public static ByteBuf createByteBufFromString(String message) {
return ByteBuf.wrapForReading(message.getBytes());
}
public static String extractMessageFromByteBuf(ByteBuf buf) {
String result = new String(buf.array(), buf.readPosition(), buf.readRemaining());
buf.recycle();
return result;
}
public static String generateLargeString(int size) {
StringBuilder builder = new StringBuilder();
Random random = new Random();
for (int i = 0; i < size; i++) {
int randNumber = random.nextInt(3);
if (randNumber == 0) {
builder.append('a');
} else if (randNumber == 1) {
builder.append('b');
} else if (randNumber == 2) {
builder.append('c');
}
}
return builder.toString();
}
// </editor-fold>
// <editor-fold desc="custom matchers">
public static Matcher<ByteBuf> bytebufOfMessage(final String message) {
return new BaseMatcher<ByteBuf>() {
@Override
public boolean matches(Object item) {
String extractedMessage = extractMessageFromByteBuf((ByteBuf) item);
return extractedMessage.equals(message);
}
@Override
public void describeTo(Description description) {
description.appendText("Message: " + message);
}
};
}
// </editor-fold>
}