package io.craft.atom.protocol.ssl;

import io.craft.atom.io.AbstractIoHandler;
import io.craft.atom.io.Channel;
import io.craft.atom.io.IoAcceptor;
import io.craft.atom.io.IoHandler;
import io.craft.atom.nio.api.NioFactory;
import io.craft.atom.protocol.ssl.api.SslCodecFactory;
import io.craft.atom.test.AvailablePortFinder;
import io.craft.atom.test.CaseCounter;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.nio.charset.Charset;
import java.security.KeyStore;
import java.security.Security;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;
import javax.xml.ws.ProtocolException;

import junit.framework.Assert;

import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Round-trip test for the SSL codec: a TCP acceptor whose handler decodes and encodes
 * through an {@code SslCodec} is started on a local port, and a blocking
 * {@link javax.net.ssl.SSLSocket} client exchanges {@link #MSG_NUM} request/response
 * lines with it.
 *
 * @author mindwind
 * @version 1.0, Oct 18, 2013
 */
public class TestSslCodec {


    private static final Logger  LOG       = LoggerFactory.getLogger(TestSslCodec.class);
    private static final String  SSL_CODEC = "ssl.codec";
    /** Explicit charset for every byte/String conversion so the test does not depend on the platform default. */
    private static final Charset CHARSET   = Charset.forName("UTF-8");
    private static final String  ALGORITHM;
    private static final int     PORT;
    private static final int     MSG_NUM   = 100;

    /** Number of completed round trips; incremented only by the client thread, read by the test thread. */
    private static volatile int       count       = 0;
    /** Failure raised on the client thread; volatile because it is read from the test thread. */
    private static volatile Exception clientError = null;


    static {
        String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
        if (algorithm == null) {
            algorithm = KeyManagerFactory.getDefaultAlgorithm();
        }

        PORT      = AvailablePortFinder.getNextAvailable(5555);
        ALGORITHM = algorithm;
    }


    // ~ ---------------------------------------------------------------------------------------------------------


    /**
     * Builds the SSL context, converting any checked failure into an unchecked
     * {@link ProtocolException} so it can be called from handler callbacks.
     *
     * @return an initialized SSL context
     * @throws ProtocolException wrapping whatever {@link #createSSLContext0()} threw
     */
    private static SSLContext createSSLContext() {
        try {
            return createSSLContext0();
        } catch (Exception e) {
            LOG.error("[CRAFT-ATOM-PROTOCOL-SSL] Create ssl context failed", e);
            throw new ProtocolException(e);
        }
    }

    /**
     * Creates a TLS context from the {@code /ssl.keystore} and {@code /ssl.truststore}
     * classpath resources.
     *
     * @return an initialized SSL context
     * @throws Exception if a store cannot be found or loaded, or the context cannot be initialized
     */
    private static SSLContext createSSLContext0() throws Exception {
        char[] passphrase = "password".toCharArray();

        SSLContext          ctx = SSLContext.getInstance("TLS");
        KeyManagerFactory   kmf = KeyManagerFactory.getInstance(ALGORITHM);
        TrustManagerFactory tmf = TrustManagerFactory.getInstance(ALGORITHM);

        KeyStore ks = loadKeyStore("/ssl.keystore", passphrase);
        KeyStore ts = loadKeyStore("/ssl.truststore", passphrase);

        kmf.init(ks, passphrase);
        tmf.init(ts);
        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
        return ctx;
    }

    /**
     * Loads a JKS store from a classpath resource and closes the resource stream afterwards.
     *
     * @param resource   absolute classpath location of the store
     * @param passphrase store password
     * @return the loaded key store
     * @throws Exception if the resource is missing or cannot be loaded
     */
    private static KeyStore loadKeyStore(String resource, char[] passphrase) throws Exception {
        InputStream in = TestSslCodec.class.getResourceAsStream(resource);
        if (in == null) {
            // KeyStore.load(null, ..) would silently produce an EMPTY store; fail with the real cause instead.
            throw new FileNotFoundException("Classpath resource not found: " + resource);
        }

        try {
            KeyStore store = KeyStore.getInstance("JKS");
            store.load(in, passphrase);
            return store;
        } finally {
            in.close();
        }
    }

    /**
     * Binds a TCP acceptor driven by {@link TestIoHandler} on {@link #PORT}.
     *
     * @throws Exception if the acceptor cannot be created or bound
     */
    private static void startServer() throws Exception {
        IoHandler  handler  = new TestIoHandler();
        IoAcceptor acceptor = NioFactory.newTcpAcceptor(handler);
        acceptor.bind(PORT);
    }

    /**
     * Creates the client side SSL socket factory and runs the message exchange against localhost.
     *
     * @throws Exception on any connection, handshake or I/O failure
     */
    private static void startClient() throws Exception {
        InetAddress      address = InetAddress.getByName("localhost");
        SSLContext       context = createSSLContext();
        SSLSocketFactory factory = context.getSocketFactory();
        connectAndSend(address, factory);
    }

    /**
     * Connects to the server, layers an SSL socket over the plain connection and performs
     * {@link #MSG_NUM} request/response round trips, incrementing {@link #count} for each reply line.
     *
     * @param address server address
     * @param factory factory used to layer SSL over the plain socket
     * @throws Exception on connection or I/O failure, or if the server closes the connection early
     */
    private static void connectAndSend(InetAddress address, SSLSocketFactory factory) throws Exception {
        Socket parent = new Socket(address, PORT);
        try {
            Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), PORT, false);
            try {
                // Set before any I/O so that the handshake is covered by the read timeout as well.
                socket.setSoTimeout(10000);

                // Created once: a fresh BufferedReader per iteration could drop bytes an earlier one had buffered.
                OutputStream   out     = socket.getOutputStream();
                BufferedReader in      = new BufferedReader(new InputStreamReader(socket.getInputStream(), CHARSET));
                byte[]         request = "hello\n".getBytes(CHARSET);

                for (int i = 0; i < MSG_NUM; i++) {
                    LOG.debug("[CRAFT-ATOM-PROTOCOL-SSL] Client send: hello {}", i);
                    out.write(request);
                    out.flush();

                    String line = in.readLine();
                    if (line == null) {
                        // End of stream is not a reply; do not count it as a successful round trip.
                        throw new IOException("Connection closed by server after " + i + " replies");
                    }
                    LOG.debug("[CRAFT-ATOM-PROTOCOL-SSL] Client got {}", line);
                    ++count;
                }
            } finally {
                socket.close();
            }
        } finally {
            // The SSL socket is created with autoClose=false, so the underlying socket must be closed explicitly.
            parent.close();
        }
    }

    /**
     * Starts the server, runs the client on a separate thread and verifies that every
     * message was answered. A failure on the client thread is rethrown from the test.
     */
    @Test
    public void testSslCodec() throws Exception {
        startServer();

        Thread t = new Thread() {
            @Override
            public void run() {
                try {
                    startClient();
                } catch (Exception e) {
                    clientError = e;
                }
            }
        };
        t.start();
        t.join(3000);

        // Checked only AFTER the join: right after start() the client has not had a chance to fail yet.
        if (clientError != null) {
            throw clientError;
        }

        Assert.assertEquals(MSG_NUM, count);
        System.out.println(String.format("[CRAFT-ATOM-PROTOCOL-SSL] (^_^) <%s> Case -> test ssl codec. ", CaseCounter.incr(1)));
    }


    // ~ ---------------------------------------------------------------------------------------------------------


    /**
     * Server side handler: attaches a codec to each opened channel and answers every
     * decoded payload with a fixed, encoded reply.
     */
    private static class TestIoHandler extends AbstractIoHandler {

        @Override
        public void channelOpened(Channel<byte[]> channel) {
            SSLContext ctx = createSSLContext();
            io.craft.atom.protocol.ssl.api.SslCodec codec = SslCodecFactory.newSslCodec(ctx, new NioSslHandshakeHandler(channel));
            channel.setAttribute(SSL_CODEC, codec);
        }

        @Override
        public void channelRead(Channel<byte[]> channel, byte[] bytes) {
            io.craft.atom.protocol.ssl.api.SslCodec codec = (io.craft.atom.protocol.ssl.api.SslCodec) channel.getAttribute(SSL_CODEC);

            byte[] ddata = codec.decode(bytes);
            if (ddata == null) {
                // No application data decoded; presumably handshake traffic or an incomplete record - TODO confirm.
                return;
            }
            LOG.debug("[CRAFT-ATOM-PROTOCOL-SSL] Receive data={}", new String(ddata, CHARSET));

            byte[] edata = codec.encode("hi, how are you?\n".getBytes(CHARSET));
            if (edata == null) {
                // Nothing to send; never hand a null buffer to the channel.
                return;
            }
            channel.write(edata);
            LOG.debug("[CRAFT-ATOM-PROTOCOL-SSL] Sent data={}", new String(edata, CHARSET));
        }

    }

    /**
     * Handshake callback that writes the bytes passed to {@link #needWrite(byte[])}
     * straight to the associated channel.
     */
    private static class NioSslHandshakeHandler implements io.craft.atom.protocol.ssl.spi.SslHandshakeHandler {

        private final Channel<byte[]> channel;

        public NioSslHandshakeHandler(Channel<byte[]> channel) {
            this.channel = channel;
        }

        @Override
        public void needWrite(byte[] bytes) {
            channel.write(bytes);
        }

    }

}