package org.limewire.nio.ssl;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import junit.framework.Test;
import org.limewire.nio.NIODispatcher;
import org.limewire.nio.NIOTestUtils;
import org.limewire.nio.channel.ChannelReadObserver;
import org.limewire.nio.channel.InterestReadableByteChannel;
import org.limewire.nio.channel.WriteBufferChannel;
import org.limewire.util.BaseTestCase;
import org.limewire.util.StringUtils;
public class TLSNIOSocketTest extends BaseTestCase {
public TLSNIOSocketTest(String name) {
super(name);
}
public static Test suite() {
return buildTestSuite(TLSNIOSocketTest.class);
}
public void testConnectAndReadWriteBlocking() throws Exception {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, null, null);
SSLServerSocketFactory factory = context.getServerSocketFactory();
SSLServerSocket server = (SSLServerSocket)factory.createServerSocket(9999);
server.setNeedClientAuth(false);
server.setWantClientAuth(false);
server.setEnabledCipherSuites(new String[] {"TLS_DH_anon_WITH_AES_128_CBC_SHA"});
TLSNIOSocket socket = new TLSNIOSocket("127.0.0.1", 9999);
Socket accepted = server.accept();
OutputStream clientOut = socket.getOutputStream();
clientOut.write(StringUtils.toAsciiBytes("TEST TEST\r\n"));
clientOut.write(StringUtils.toAsciiBytes("\r\n"));
byte[] serverB = new byte[1000];
int serverRead = accepted.getInputStream().read(serverB);
assertEquals(13, serverRead);
assertEquals("TEST TEST\r\n\r\n", StringUtils.getASCIIString(serverB, 0, 13));
accepted.getOutputStream().write(StringUtils.toAsciiBytes("HELLO THIS IS A TEST!"));
InputStream clientIn = socket.getInputStream();
byte[] clientB = new byte[2048];
int clientRead = clientIn.read(clientB);
assertEquals(21, clientRead);
assertEquals("HELLO THIS IS A TEST!", StringUtils.getASCIIString(clientB, 0, 21));
socket.close();
accepted.close();
server.close();
}
public void testConnectAndReadWriteNonBlocking() throws Exception {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, null, null);
SSLServerSocketFactory factory = context.getServerSocketFactory();
SSLServerSocket server = (SSLServerSocket)factory.createServerSocket(9999);
server.setNeedClientAuth(false);
server.setWantClientAuth(false);
server.setEnabledCipherSuites(new String[] {"TLS_DH_anon_WITH_AES_128_CBC_SHA"});
TLSNIOSocket socket = new TLSNIOSocket("127.0.0.1", 9999);
Socket accepted = server.accept();
WriteBufferChannel clientOut = new WriteBufferChannel();
socket.setWriteObserver(clientOut);
NIODispatcher.instance().getScheduledExecutorService().submit(new Runnable() {public void run() {}}).get(); //wait for write to set
clientOut.setBuffer(ByteBuffer.wrap(StringUtils.toAsciiBytes("TEST TEST\r\n\r\n")));
byte[] serverB = new byte[1000];
int serverRead = accepted.getInputStream().read(serverB);
assertEquals(13, serverRead);
assertEquals("TEST TEST\r\n\r\n", StringUtils.getASCIIString(serverB, 0, 13));
accepted.getOutputStream().write(StringUtils.toAsciiBytes("HELLO THIS IS A TEST!"));
ReadTester reader = new ReadTester();
socket.setReadObserver(reader);
Thread.sleep(500);
ByteBuffer read = reader.getRead();
assertEquals(21, read.limit());
assertEquals("HELLO THIS IS A TEST!", StringUtils.getASCIIString(read.array(), 0, 21));
socket.close();
accepted.close();
server.close();
}
public void testConnectAndReadWriteSwitchBlockingMode() throws Exception {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, null, null);
SSLServerSocketFactory factory = context.getServerSocketFactory();
SSLServerSocket server = (SSLServerSocket)factory.createServerSocket(9999);
server.setNeedClientAuth(false);
server.setWantClientAuth(false);
server.setEnabledCipherSuites(new String[] {"TLS_DH_anon_WITH_AES_128_CBC_SHA"});
TLSNIOSocket socket = new TLSNIOSocket("127.0.0.1", 9999);
Socket accepted = server.accept();
OutputStream clientOutB = socket.getOutputStream();
clientOutB.write(StringUtils.toAsciiBytes("TEST TEST\r\n"));
clientOutB.write("\r\n".getBytes());
byte[] serverB = new byte[16];
int serverRead = accepted.getInputStream().read(serverB);
assertEquals(13, serverRead);
assertEquals("TEST TEST\r\n\r\n", StringUtils.getASCIIString(serverB, 0, 13));
accepted.getOutputStream().write(StringUtils.toAsciiBytes("HELLO THIS IS A TEST!"));
InputStream clientInB = socket.getInputStream();
byte[] clientReadB = new byte[15];
int clientRead = clientInB.read(clientReadB);
assertEquals(15, clientRead);
assertEquals("HELLO THIS IS A", StringUtils.getASCIIString(clientReadB, 0, 15));
WriteBufferChannel clientOutNB = new WriteBufferChannel();
socket.setWriteObserver(clientOutNB);
NIODispatcher.instance().getScheduledExecutorService().submit(new Runnable() {public void run() {}}).get(); //wait for write to set
clientOutNB.setBuffer(ByteBuffer.wrap(StringUtils.toAsciiBytes("MORE TEST\r\n")));
serverB = new byte[16];
serverRead = accepted.getInputStream().read(serverB);
assertEquals(11, serverRead);
assertEquals("MORE TEST\r\n", StringUtils.getASCIIString(serverB, 0, 11));
ReadTester reader = new ReadTester();
socket.setReadObserver(reader);
Thread.sleep(500);
ByteBuffer read = reader.getRead();
assertEquals(6, read.limit());
assertEquals(" TEST!", StringUtils.getASCIIString(read.array(), 0, 6));
socket.close();
accepted.close();
server.close();
}
public void testRead10KBNonBlocking() throws Exception {
// 10 kb
final byte[] serverBuffer = new byte[10 * 1024];
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, null, null);
SSLServerSocketFactory factory = context.getServerSocketFactory();
SSLServerSocket acceptor = (SSLServerSocket)factory.createServerSocket(9999);
try {
acceptor.setNeedClientAuth(false);
acceptor.setWantClientAuth(false);
acceptor.setEnabledCipherSuites(new String[] {"TLS_DH_anon_WITH_AES_128_CBC_SHA"});
TLSNIOSocket client = new TLSNIOSocket("127.0.0.1", 9999);
client.setSoTimeout(500);
try {
final Socket server = acceptor.accept();
try {
// if the server is
client.getOutputStream().write(1);
server.getOutputStream().write(serverBuffer);
InputStream in = client.getInputStream();
byte[] buffer = new byte[512];
int total = 0;
while (total < serverBuffer.length) {
int read = in.read(buffer);
assertNotEquals(-1, read);
total += read;
}
assertEquals(serverBuffer.length, total);
} finally {
server.close();
}
} finally {
client.close();
}
} finally {
acceptor.close();
}
}
/**
* Test that if output on the remote side is shutdown after we write
* something, but before handshaking finishes, that we detect that and close
* the socket.
*/
public void testWriteOnlyReadShuts() throws Exception {
TLSNIOServerSocket server = new TLSNIOServerSocket(9999);
TLSNIOSocket socket = new TLSNIOSocket("127.0.0.1", 9999);
Socket accepted = server.accept();
OutputStream output = socket.getOutputStream();
output.write(StringUtils.toAsciiBytes("Bugfinder"));
// IMPORTANT: do not tell accepted to read here, otherwise
// handshaking could finish before we shutdown output.
accepted.shutdownOutput();
NIOTestUtils.waitForNIO();
assertTrue(socket.isClosed());
accepted.close();
server.close();
}
private static class ReadTester implements ChannelReadObserver {
private InterestReadableByteChannel source;
private ByteBuffer readData = ByteBuffer.allocate(128 * 1024);
// ChannelReader methods.
public InterestReadableByteChannel getReadChannel() { return source; }
public void setReadChannel(InterestReadableByteChannel channel) { source = channel; }
// IOErrorObserver methods.
public void handleIOException(IOException x) { fail(x); }
// ReadObserver methods.
public void handleRead() throws IOException {
source.read(readData);
assertEquals(0, source.read(readData)); // must have finish on first read.
}
// Shutdownable methods.
public void shutdown() {}
public ByteBuffer getRead() { return (ByteBuffer)readData.flip(); }
}
}