package water.network;
import org.junit.Test;
import water.util.FileUtils;
import water.util.StringUtils;

import javax.net.ssl.SSLException;
import java.io.EOFException;
import java.io.IOException;
import java.net.BindException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.*;
import static water.util.FileUtils.*;
/**
 * End-to-end test of {@link SSLSocketChannelFactory}. The JUnit thread plays the server and a
 * {@link ClientThread} plays the client. Both connect over loopback, wrap their sockets with the
 * factory, and exchange data in four phases kept in lock-step by {@link CyclicBarrier}s.
 */
public class SSLSocketChannelFactoryTest {

  /** Upper bound on how long either side waits for the other (barrier rendezvous or join). */
  private static final long PEER_TIMEOUT_SECONDS = 30;

  /** Number of buffers the client sends in the "big communication" phase. */
  private static final int BIG_CHUNKS = 5;

  /** Size in bytes of each buffer sent in the "big communication" phase. */
  private static final int BIG_CHUNK_SIZE = 64 * 1024;

  /** Length in bytes of each "hello, world<digit>!!!" message in the "big communication" phase. */
  private static final int BIG_MESSAGE_LENGTH = 16;

  /**
   * Port the server socket is bound to. The test thread writes it before the first
   * {@code barrier} rendezvous and the client reads it only afterwards, so the barrier
   * publishes it safely.
   */
  private int port;

  /**
   * Runs four phases over one connection:
   * (1) a small SSL to SSL message,
   * (2) a large SSL to SSL transfer,
   * (3) plaintext sent into the SSL reader, which must raise {@link SSLException},
   * (4) SSL-encrypted data read raw, which must not equal the plaintext.
   * Failures on either thread fail the test instead of being printed and ignored.
   */
  @Test(timeout = 120000)
  public void shouldHandshake() throws IOException, SSLContextException, BrokenBarrierException, InterruptedException {
    SSLProperties props = new SSLProperties();
    props.put("h2o_ssl_protocol", SecurityUtils.defaultTLSVersion());
    props.put("h2o_ssl_jks_internal", getFile("src/test/resources/keystore.jks").getPath());
    props.put("h2o_ssl_jks_password", "password");
    props.put("h2o_ssl_jts", getFile("src/test/resources/cacerts.jks").getPath());
    props.put("h2o_ssl_jts_password", "password");

    final SSLSocketChannelFactory factory = new SSLSocketChannelFactory(props);

    // `barrier` brackets the exchange: it marks "server bound" at the start and
    // "all data read" at the end. testOne..testThree separate the four phases.
    final CyclicBarrier barrier = new CyclicBarrier(2);
    final CyclicBarrier testOne = new CyclicBarrier(2);
    final CyclicBarrier testTwo = new CyclicBarrier(2);
    final CyclicBarrier testThree = new CyclicBarrier(2);

    ClientThread client = new ClientThread(factory, testOne, testTwo, testThree, barrier);
    // A client left behind by a failed run must not keep the JVM alive.
    client.setDaemon(true);
    client.start();

    try (ServerSocketChannel serverSocketChannel = ServerSocketChannel.open()) {
      serverSocketChannel.socket().setReceiveBufferSize(64 * 1024);
      // Let the OS pick a free port rather than probing upwards from a fixed one.
      serverSocketChannel.socket().bind(new InetSocketAddress(0));
      port = serverSocketChannel.socket().getLocalPort();
      awaitPeer(barrier);

      try (SocketChannel sock = serverSocketChannel.accept()) {
        SSLSocketChannel wrappedChannel = (SSLSocketChannel) factory.wrapServerChannel(sock);
        assertTrue(wrappedChannel.isHandshakeComplete());

        // FIRST TEST: SSL -> SSL SMALL COMMUNICATION
        ByteBuffer readBuffer = ByteBuffer.allocate(12);
        byte[] dst = new byte[12];
        readFully(wrappedChannel, readBuffer);
        readBuffer.flip();
        readBuffer.get(dst, 0, 12);
        readBuffer.clear();
        assertEquals("hello, world", new String(dst, StandardCharsets.UTF_8));
        awaitPeer(testOne);

        // SECOND TEST: SSL -> SSL BIG COMMUNICATION
        int read = 0;
        byte[] dstBig = new byte[BIG_MESSAGE_LENGTH];
        ByteBuffer readBufferBig = ByteBuffer.allocate(1024);
        while (read < BIG_CHUNKS * BIG_CHUNK_SIZE) {
          while (readBufferBig.position() < BIG_MESSAGE_LENGTH) {
            if (wrappedChannel.read(readBufferBig) < 0) {
              throw new EOFException("SSL channel closed after " + read + " bytes of the big payload");
            }
          }
          readBufferBig.flip();
          readBufferBig.get(dstBig, 0, BIG_MESSAGE_LENGTH);
          // Keeps any bytes already read for the next message; on an empty buffer this equals clear().
          readBufferBig.compact();
          // The digit is the message's absolute stream offset mod 9, mirroring the client.
          assertEquals("hello, world" + (read % 9) + "!!!", new String(dstBig, StandardCharsets.UTF_8));
          read += BIG_MESSAGE_LENGTH;
        }
        awaitPeer(testTwo);

        // THIRD TEST: NON-SSL -> SSL COMMUNICATION
        try {
          readFully(wrappedChannel, readBuffer);
          fail("Reading plaintext through the SSL channel should have raised an SSLException");
        } catch (SSLException expected) {
          // PASSED
        }
        assertTrue(wrappedChannel.getEngine().isInboundDone());
        awaitPeer(testThree);

        // FOURTH TEST: SSL -> NON-SSL COMMUNICATION (read the raw socket, bypassing SSL)
        readBuffer.clear();
        readFully(sock, readBuffer);
        readBuffer.flip();
        readBuffer.get(dst, 0, 12);
        readBuffer.clear();
        assertNotEquals("hello, world", new String(dst, StandardCharsets.UTF_8));

        // Final rendezvous: the client keeps its socket open until this point, so the
        // bytes read above cannot be lost to an early close on its side.
        awaitPeer(barrier);
        client.join(TimeUnit.SECONDS.toMillis(PEER_TIMEOUT_SECONDS));
        if (client.isAlive()) {
          fail("Client thread did not terminate");
        }
        Throwable clientFailure = client.failure;
        if (clientFailure != null) {
          throw new AssertionError("Client thread failed", clientFailure);
        }
      }
    }
  }

  /**
   * Waits at {@code b} for the other thread. Fails instead of hanging forever when the peer
   * has died or is stuck. On timeout the barrier is broken, so a peer arriving later is
   * released with a {@link BrokenBarrierException}.
   */
  private static void awaitPeer(CyclicBarrier b) throws InterruptedException, BrokenBarrierException {
    try {
      b.await(PEER_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    } catch (TimeoutException e) {
      throw new AssertionError("Timed out after " + PEER_TIMEOUT_SECONDS + "s waiting for the peer thread", e);
    }
  }

  /**
   * Reads from {@code channel} until {@code buffer} has no space remaining.
   *
   * @throws EOFException if the channel reports end-of-stream (-1) first, instead of spinning forever
   * @throws IOException  any error raised by the channel itself (e.g. {@link SSLException})
   */
  private static void readFully(ReadableByteChannel channel, ByteBuffer buffer) throws IOException {
    while (buffer.hasRemaining()) {
      if (channel.read(buffer) < 0) {
        throw new EOFException("Channel closed with " + buffer.remaining() + " bytes still expected");
      }
    }
  }

  /**
   * Client side of {@link #shouldHandshake()}. It connects to {@link #port}, wraps the socket
   * with the factory, and sends the data each server phase expects. Any throwable is recorded
   * in {@link #failure} so the test thread can report it.
   */
  private class ClientThread extends Thread {
    private final SSLSocketChannelFactory factory;
    private final CyclicBarrier testOne;
    private final CyclicBarrier testTwo;
    private final CyclicBarrier testThree;
    private final CyclicBarrier barrier;

    /** First throwable raised by {@link #run()}, or {@code null} if it completed normally. */
    private volatile Throwable failure;

    public ClientThread(SSLSocketChannelFactory factory,
                        CyclicBarrier testOne,
                        CyclicBarrier testTwo,
                        CyclicBarrier testThree,
                        CyclicBarrier barrier) {
      this.factory = factory;
      this.testOne = testOne;
      this.testTwo = testTwo;
      this.testThree = testThree;
      this.barrier = barrier;
    }

    @Override
    public void run() {
      try {
        // Wait until the server is bound and `port` has been published.
        awaitPeer(barrier);
        try (SocketChannel sock = SocketChannel.open()) {
          sock.socket().setReuseAddress(true);
          sock.socket().setSendBufferSize(64 * 1024);
          InetSocketAddress isa = new InetSocketAddress("127.0.0.1", port);
          sock.connect(isa);
          sock.configureBlocking(true);
          sock.socket().setTcpNoDelay(true);
          SSLSocketChannel wrappedChannel = (SSLSocketChannel) factory.wrapClientChannel(sock, "127.0.0.1", port);

          // FIRST TEST: SSL -> SSL SMALL COMMUNICATION
          ByteBuffer write = ByteBuffer.allocate(1024);
          write.put(StringUtils.bytesOf("hello, world"));
          write.flip();
          wrappedChannel.write(write);
          awaitPeer(testOne);

          // SECOND TEST: SSL -> SSL BIG COMMUNICATION
          ByteBuffer toWriteBig = ByteBuffer.allocate(BIG_CHUNK_SIZE);
          for (int i = 0; i < BIG_CHUNKS; i++) {
            toWriteBig.clear();
            while (toWriteBig.hasRemaining()) {
              // The digit is the message's absolute stream offset mod 9, which the server re-derives.
              toWriteBig.put(
                  StringUtils.bytesOf("hello, world" + ((i * BIG_CHUNK_SIZE + toWriteBig.position()) % 9) + "!!!")
              );
            }
            toWriteBig.flip();
            wrappedChannel.write(toWriteBig);
          }
          awaitPeer(testTwo);

          // THIRD TEST: NON-SSL -> SSL COMMUNICATION (write plaintext to the raw socket)
          write.clear();
          write.put(StringUtils.bytesOf("hello, world"));
          write.flip();
          sock.write(write);
          awaitPeer(testThree);

          // FOURTH TEST: SSL -> NON-SSL COMMUNICATION
          write.clear();
          write.put(StringUtils.bytesOf("hello, world"));
          // Flip so that exactly the 12 message bytes are encrypted and sent.
          write.flip();
          wrappedChannel.write(write);

          // Keep the socket open until the server has finished reading the bytes above.
          awaitPeer(barrier);
        }
      } catch (Throwable t) {
        failure = t;
      }
    }
  }
}