package org.limewire.nio.ssl;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.ByteBuffer;

import javax.net.ssl.SSLException;

import junit.framework.Test;

import org.limewire.nio.NIOServerSocket;
import org.limewire.nio.NIOSocket;
import org.limewire.nio.ProtocolBandwidthTracker;
import org.limewire.util.BaseTestCase;
import org.limewire.util.BufferUtils;
import org.limewire.util.StringUtils;

/**
 * Tests for {@link SSLUtils}: detecting TLS-enabled and StartTLS-capable
 * sockets, upgrading a socket in place via StartTLS, and looking up a
 * socket's SSL bandwidth tracker.
 * <p>
 * Every socket a test opens is closed in a {@code finally} block so that a
 * failing assertion (or an accept timeout) does not leak file descriptors or
 * bound listeners into later tests.
 */
public class SSLUtilsTest extends BaseTestCase {

    public SSLUtilsTest(String name) {
        super(name);
    }

    public static Test suite() {
        return buildTestSuite(SSLUtilsTest.class);
    }

    /**
     * A plain {@link Socket} and a plain {@link NIOSocket} are not
     * TLS-enabled; a {@link TLSNIOSocket} is.
     */
    public void testIsTLSEnabled() throws Exception {
        Socket plain = new Socket();
        Socket nio = new NIOSocket();
        Socket tls = new TLSNIOSocket();
        try {
            assertFalse(SSLUtils.isTLSEnabled(plain));
            assertFalse(SSLUtils.isTLSEnabled(nio));
            assertTrue(SSLUtils.isTLSEnabled(tls));
        } finally {
            closeQuietly(plain);
            closeQuietly(nio);
            closeQuietly(tls);
        }
    }

    /**
     * A plain {@link Socket} cannot be upgraded with StartTLS; both
     * {@link NIOSocket} and {@link TLSNIOSocket} report that they can.
     */
    public void testIsStartTLSCapable() throws Exception {
        Socket plain = new Socket();
        Socket nio = new NIOSocket();
        Socket tls = new TLSNIOSocket();
        try {
            assertFalse(SSLUtils.isStartTLSCapable(plain));
            assertTrue(SSLUtils.isStartTLSCapable(nio));
            assertTrue(SSLUtils.isStartTLSCapable(tls));
        } finally {
            closeQuietly(plain);
            closeQuietly(nio);
            closeQuietly(tls);
        }
    }

    /**
     * Exercises {@link SSLUtils#startTLS}:
     * <ul>
     * <li>rejects a plain socket with {@link IllegalArgumentException};</li>
     * <li>turns an {@link NIOSocket} into a TLS-enabled socket;</li>
     * <li>fails with {@link SSLException} when the supplied initial buffer is
     *     not TLS data;</li>
     * <li>upgrades an accepted server-side socket, given the bytes already read
     *     from it, so that the client's encrypted payload decrypts correctly.</li>
     * </ul>
     */
    public void testStartTLS() throws Exception {
        Socket plain = new Socket();
        try {
            SSLUtils.startTLS(plain, BufferUtils.getEmptyBuffer());
            fail("expected exception");
        } catch (IllegalArgumentException expected) {
            // a plain java.net.Socket is not StartTLS-capable, so it must be rejected
        } finally {
            closeQuietly(plain);
        }

        Socket s = new NIOSocket();
        try {
            assertTrue(SSLUtils.isStartTLSCapable(s));
            assertFalse(SSLUtils.isTLSEnabled(s));
            s = SSLUtils.startTLS(s, BufferUtils.getEmptyBuffer());
            assertTrue(SSLUtils.isTLSEnabled(s));
        } finally {
            closeQuietly(s);
        }

        Socket notTls = new NIOSocket();
        try {
            SSLUtils.startTLS(notTls, ByteBuffer.wrap(new byte[] { 'N', 'O', 'T', 'T', 'L', 'S' }));
            fail("expected exception");
        } catch (SSLException expected) {
            // the initial buffer holds plaintext rather than TLS data, so the upgrade must fail
        } finally {
            closeQuietly(notTls);
        }

        ServerSocket ss = new NIOServerSocket();
        Socket tls = null;
        Socket accepted = null;
        Socket converted = null;
        try {
            ss.setSoTimeout(1000);
            ss.bind(new InetSocketAddress("localhost", 0));
            tls = new TLSSocketFactory().createSocket("localhost", ss.getLocalPort());
            tls.getOutputStream().write(StringUtils.toAsciiBytes("OUTPUT"));

            accepted = ss.accept();
            assertFalse(SSLUtils.isTLSEnabled(accepted));
            assertTrue(SSLUtils.isStartTLSCapable(accepted));

            // Before the upgrade the server side sees raw TLS bytes, not the plaintext.
            byte[] read = new byte[100];
            int amt = accepted.getInputStream().read(read);
            assertGreaterThan(0, amt);
            assertNotEquals("OUTPUT", StringUtils.getASCIIString(read, 0, amt));

            // Hand the already-read bytes to startTLS so the handshake can resume from them.
            converted = SSLUtils.startTLS(accepted, ByteBuffer.wrap(read, 0, amt));
            amt = converted.getInputStream().read(read);
            // length of string works, since ascii encoding ensures 1-1 mapping between chars and bytes
            assertEquals("OUTPUT".length(), amt);
            assertEquals("OUTPUT", StringUtils.getASCIIString(read, 0, amt));
        } finally {
            closeQuietly(converted);
            closeQuietly(accepted);
            closeQuietly(tls);
            closeQuietly(ss);
        }
    }

    /**
     * Non-TLS sockets map to the shared empty tracker. Each end of a TLS
     * connection gets its own tracker, whose counters reflect the application
     * bytes exchanged.
     */
    public void testGetSSLBandwidthTracker() throws Exception {
        Socket plain = new Socket();
        Socket nio = new NIOSocket();
        try {
            ProtocolBandwidthTracker t = SSLUtils.getSSLBandwidthTracker(plain);
            assertSame(t, SSLUtils.EmptyTracker.instance()); // a little stricter check than necessary
            t = SSLUtils.getSSLBandwidthTracker(nio);
            assertSame(t, SSLUtils.EmptyTracker.instance()); // a little stricter check than necessary
        } finally {
            closeQuietly(plain);
            closeQuietly(nio);
        }

        ServerSocket listening = new TLSServerSocketFactory().createServerSocket();
        Socket outgoing = null;
        Socket incoming = null;
        try {
            listening.setSoTimeout(1000);
            listening.bind(new InetSocketAddress("localhost", 0));
            outgoing = new TLSSocketFactory().createSocket("localhost", listening.getLocalPort());
            incoming = listening.accept();

            ProtocolBandwidthTracker outTracker = SSLUtils.getSSLBandwidthTracker(outgoing);
            ProtocolBandwidthTracker inTracker = SSLUtils.getSSLBandwidthTracker(incoming);
            assertNotSame(outTracker, inTracker);
            assertNotSame(outTracker, SSLUtils.EmptyTracker.instance());
            assertNotSame(inTracker, SSLUtils.EmptyTracker.instance());

            outgoing.getOutputStream().write(StringUtils.toAsciiBytes("THIS IS OUTPUT"));
            incoming.getOutputStream().write(StringUtils.toAsciiBytes("INCOMING OUTPUT"));
            byte[] outRead = new byte[100];
            byte[] inRead = new byte[100];
            int outAmt = outgoing.getInputStream().read(outRead);
            int inAmt = incoming.getInputStream().read(inRead);
            assertEquals("THIS IS OUTPUT".length(), inAmt);
            assertEquals("INCOMING OUTPUT".length(), outAmt);

            // "Read produced" / "written consumed" must equal the application byte counts.
            assertEquals(inAmt, inTracker.getReadBytesProduced());
            assertEquals(outAmt, inTracker.getWrittenBytesConsumed());
            assertEquals(outAmt, outTracker.getReadBytesProduced());
            assertEquals(inAmt, outTracker.getWrittenBytesConsumed());
            // The opposite counters must be strictly larger than those byte counts,
            // presumably because they include TLS record overhead.
            assertGreaterThan(inAmt, inTracker.getReadBytesConsumed());
            assertGreaterThan(outAmt, inTracker.getWrittenBytesProduced());
            assertGreaterThan(outAmt, outTracker.getReadBytesConsumed());
            assertGreaterThan(inAmt, outTracker.getWrittenBytesProduced());
        } finally {
            closeQuietly(outgoing);
            closeQuietly(incoming);
            closeQuietly(listening);
        }
    }

    /**
     * Closes {@code socket} if non-null, ignoring any {@link IOException}.
     * Used from {@code finally} blocks so cleanup never masks the test's own
     * failure.
     */
    private static void closeQuietly(Socket socket) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (IOException ignored) {
            // best-effort cleanup; the test outcome has already been decided
        }
    }

    /**
     * Closes {@code serverSocket} if non-null, ignoring any {@link IOException}.
     * Used from {@code finally} blocks so cleanup never masks the test's own
     * failure.
     */
    private static void closeQuietly(ServerSocket serverSocket) {
        if (serverSocket == null) {
            return;
        }
        try {
            serverSocket.close();
        } catch (IOException ignored) {
            // best-effort cleanup; the test outcome has already been decided
        }
    }
}