package org.webpieces.nio.test;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManagerFactory;
import junit.framework.TestCase;
public class TestSSLEngine extends TestCase {
//private static final Logger log = LoggerFactory.getLogger(TestSSLEngine.class);
/**
* Sunny day scenaior of normal two normal SSLEngines...no split packets,
* no failures, etc.
* @throws Exception
*/
public void testRawSSLEngine() throws Exception {
SSLEngine server = getServerEngine();
SSLEngine client = getClientEngine();
SSLSession s = client.getSession();
ByteBuffer unencrPacket = ByteBuffer.allocate(s.getApplicationBufferSize());
ByteBuffer encPacket = ByteBuffer.allocate(s.getPacketBufferSize());
encPacket.clear();
unencrPacket.clear();
doHandshakeAndVerify(client, server, unencrPacket, encPacket);
doRehandshake();
//connection is now established....
encPacket.clear();
client.closeOutbound();
SSLEngineResult result = client.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
assertEquals(Status.CLOSED, result.getStatus());
encPacket.flip();
result = server.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
assertEquals(Status.CLOSED, result.getStatus());
encPacket.clear();
result = server.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus());
assertEquals(Status.CLOSED, result.getStatus());
encPacket.flip();
result = client.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus());
assertEquals(Status.CLOSED, result.getStatus());
}
private void doRehandshake() {
// /*****************************************************
// * REHANDSHAKE BEGINS HERE.............
// *****************************************************/
//
// client.beginHandshake();
// assertEquals(HandshakeStatus.NEED_WRAP, client.getHandshakeStatus());
//
// encPacket.clear();
// SSLEngineResult result = client.wrap(unencrPacket, encPacket); //CLIENT HANDSHAKE MSG
// assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
// assertEquals(Status.OK, result.getStatus());
//
// String expected = "abc";
// ByteBuffer encData = ByteBuffer.allocate(s.getPacketBufferSize());
// ByteBuffer data = ByteBuffer.allocate(10);
// putString(data, expected);
// data.flip();
// log.trace("data1="+data+" encData="+encData);
// result = client.wrap(data, encData); //CLIENT WRAP DATA
// log.trace("data2="+data+" encData="+encData);
// assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
// assertEquals(Status.OK, result.getStatus());
//
// unencrPacket.clear();
// encPacket.flip();
// result = server.unwrap(encPacket, unencrPacket); //SERVER UNWRAP HANDSHAKE MSG
// assertEquals(HandshakeStatus.NEED_TASK, result.getHandshakeStatus());
// assertEquals(Status.OK, result.getStatus());
// server.getDelegatedTask(); //get task but don't run it yet...wait until after decrypt of real data
//
// /********************************************************
// * Found out this is expected behavior....until runnable is run SSLEngine can't be used.
// * BIG NOTE: Notice, I did not run the Runnable yet. If I put
// * r.run() right here, and change the assert statements below from NEED_TASK
// * to NEED_WRAP, the test will then pass!!!!!
// *
// *******************************************************/
//
// ByteBuffer dataOut = ByteBuffer.allocate(server.getSession().getApplicationBufferSize());
// dataOut.clear();
// encData.flip();
// log.trace("datain1="+encData+" out="+dataOut);
// result = server.unwrap(encData, dataOut); //SERVER UNWRAP DATA
// log.trace("datain2="+encData+" out="+dataOut);
// assertEquals(HandshakeStatus.NEED_TASK, result.getHandshakeStatus());
// assertEquals(Status.OK, result.getStatus());
//
// dataOut.flip();
// String actual = readString(dataOut, dataOut.remaining());
// assertEquals(expected, actual);
}
private void doHandshakeAndVerify(SSLEngine client, SSLEngine server, ByteBuffer unencrPacket, ByteBuffer encPacket) throws Exception {
startOfHandshake(client, server, unencrPacket, encPacket);
continueHandshake(client, server, unencrPacket, encPacket);
}
private void startOfHandshake(SSLEngine client, SSLEngine server, ByteBuffer unencrPacket, ByteBuffer encPacket) throws SSLException {
client.beginHandshake();
SSLEngineResult result = client.wrap(unencrPacket, encPacket);
assertEquals(result.getHandshakeStatus(), HandshakeStatus.NEED_UNWRAP);
assertEquals(result.getStatus(), Status.OK);
encPacket.flip();
result = server.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_TASK, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
Runnable r = server.getDelegatedTask();
r.run();
assertEquals(HandshakeStatus.NEED_WRAP, server.getHandshakeStatus());
encPacket.clear();
result = server.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = client.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_TASK, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
r = client.getDelegatedTask();
r.run();
assertEquals(HandshakeStatus.NEED_WRAP, client.getHandshakeStatus());
encPacket.clear();
result = client.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = server.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_TASK, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
r = server.getDelegatedTask();
r.run();
assertEquals(HandshakeStatus.NEED_UNWRAP, server.getHandshakeStatus());
}
private void continueHandshake(SSLEngine client, SSLEngine server, ByteBuffer unencrPacket, ByteBuffer encPacket) throws SSLException {
SSLEngineResult result;
encPacket.clear();
result = client.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = server.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.clear();
result = client.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = server.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.clear();
result = server.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = client.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.clear();
result = server.wrap(unencrPacket, encPacket);
assertEquals(HandshakeStatus.FINISHED, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
encPacket.flip();
result = client.unwrap(encPacket, unencrPacket);
assertEquals(HandshakeStatus.FINISHED, result.getHandshakeStatus());
assertEquals(Status.OK, result.getStatus());
}
private String password = "root01";
private String clientKeystore = "src/test/resources/client.keystore";
private String serverKeystore = "src/test/resources/server.keystore";
private SSLEngine getServerEngine() throws Exception {
char[] passphrase = password.toCharArray();
// First initialize the key and trust material.
KeyStore ks = KeyStore.getInstance("JKS");
ks.load(new FileInputStream(serverKeystore), passphrase);
SSLContext sslContext = SSLContext.getInstance("TLS");
//****************Server side specific*********************
// KeyManager's decide which key material to use.
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, passphrase);
sslContext.init(kmf.getKeyManagers(), null, null);
//****************Server side specific*********************
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(false);
SSLEngine server = engine;
return server;
}
private SSLEngine getClientEngine() throws Exception {
// Create/initialize the SSLContext with key material
char[] passphrase = password.toCharArray();
// First initialize the key and trust material.
KeyStore ks = KeyStore.getInstance("JKS");
ks.load(new FileInputStream(clientKeystore), passphrase);
SSLContext sslContext = SSLContext.getInstance("TLS");
//****************Client side specific*********************
// TrustManager's decide whether to allow connections.
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
tmf.init(ks);
sslContext.init(null, tmf.getTrustManagers(), null);
//****************Client side specific*********************
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
SSLEngine client = engine;
return client;
}
// private void putString(ByteBuffer b, String fullString) {
// if(b == null)
// throw new IllegalArgumentException("Cannot pass in a null buffer");
// else if(fullString == null)
// throw new IllegalArgumentException("Cannot pass in a null string");
// byte[] encodedString;
// try {
// ByteArrayOutputStream out = new ByteArrayOutputStream();
// OutputStreamWriter writer = new OutputStreamWriter(out);
// writer.write(fullString);
// writer.flush();
// encodedString = out.toByteArray();
// } catch(IOException e) {
// throw new RuntimeException("Should never happen", e);
// }
//
// b.put(encodedString);
// }
public String readString(ByteBuffer b, int numBytesToRead) {
byte[] buffer = new byte[numBytesToRead];
b.get(buffer);
String s = new String(buffer);
return s;
}
}