package net.jxta.impl.endpoint.netty;
import static net.jxta.impl.endpoint.netty.NettyTestUtils.*;
import static org.junit.Assert.*;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.List;
import net.jxta.document.MimeMediaType;
import net.jxta.endpoint.EndpointAddress;
import net.jxta.endpoint.Message;
import net.jxta.endpoint.StringMessageElement;
import net.jxta.endpoint.WireFormatMessage;
import net.jxta.endpoint.WireFormatMessageFactory;
import net.jxta.impl.endpoint.msgframing.MessagePackageHeader;
import net.jxta.impl.endpoint.msgframing.WelcomeMessage;
import net.jxta.peer.PeerID;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelState;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.DownstreamMessageEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.UpstreamChannelStateEvent;
import org.jboss.netty.channel.UpstreamMessageEvent;
import org.jboss.netty.handler.timeout.TimeoutException;
import org.junit.Before;
import org.junit.Test;
/**
 * Unit tests for {@link JxtaProtocolHandler}.
 * <p>
 * The handler is installed in a pipeline attached to a {@link FakeChannel}, followed by an
 * {@link UpstreamEventCatcher}. Tests drive the handler by firing events through {@link Channels};
 * anything the handler sends downstream is captured by a {@link FakeChannelSink}, and anything it
 * passes upstream is captured by the {@link UpstreamEventCatcher}.
 */
public class JxtaProtocolHandlerTest {

    public static final String TEST_PROTO_NAME = "htun";

    private static final InetAddress SERVER_BIND_ADDR = createAddress(new byte[] { 0, 0, 0, 0 });
    private static final InetAddress LOCAL_ADDR = createAddress(new byte[] { 10, 1, 1, 1 });
    private static final SocketAddress PARENT_SOCK_ADDR = new InetSocketAddress(SERVER_BIND_ADDR, 12345);
    private static final SocketAddress LOCAL_SOCK_ADDR = new InetSocketAddress(LOCAL_ADDR, 57043);
    private static final EndpointAddress LOCAL_ENDPOINT_ADDR = new EndpointAddress(TEST_PROTO_NAME, "10.1.1.1:12345", null, null);
    private static final SocketAddress REMOTE_SOCK_ADDR = InetSocketAddress.createUnresolved("remoteaddr", 54321);
    private static final EndpointAddress REMOTE_ENDPOINT_ADDR = new EndpointAddress(TEST_PROTO_NAME, "remoteaddr:54321", null, null);
    // NOTE(review): the host is spelled "remoteAddr" here but "remoteaddr" in REMOTE_ENDPOINT_ADDR;
    // confirm whether the difference in case is intentional.
    private static final EndpointAddress REMOTE_ENDPOINT_ADDR_WITH_PARAMS = new EndpointAddress(TEST_PROTO_NAME, "remoteAddr:54321", "TestService", "testparams");
    private static final PeerID LOCAL_PEER_ID = PeerID.create(URI.create("urn:jxta:uuid-59616261646162614E5047205032503304F8E1DEBB4942C0BF16DD923DEC949803"));
    private static final PeerID REMOTE_PEER_ID = PeerID.create(URI.create("urn:jxta:uuid-59616261646162614E50472050325033E7E1335996F44E38BD66B16349BB1F1E03"));

    /**
     * Builds an {@link InetAddress} from raw address bytes. Any failure is rethrown unchecked so
     * that this method can be used in static field initializers.
     *
     * @param b the raw address bytes
     * @return the corresponding address
     * @throws RuntimeException wrapping the {@link UnknownHostException} if the bytes are invalid
     */
    private static InetAddress createAddress(byte[] b) {
        try {
            return InetAddress.getByAddress(b);
        } catch (UnknownHostException e) {
            // Keep the original exception as the cause so the real failure is visible.
            throw new RuntimeException("Bad address in test", e);
        }
    }

    private FakeChannel channel;
    private JxtaProtocolHandler handler;
    private FakeTimer timeoutTimer;
    private FakeChannelSink downstreamCatcher;
    private UpstreamEventCatcher upstreamCatcher;
    private Message testMessage;

    /**
     * Builds a fresh handler and a connected child channel (with a bound, unconnected parent)
     * whose pipeline is: handler, then upstream catcher.
     */
    @Before
    public void setUp() throws Exception {
        timeoutTimer = new FakeTimer();
        handler = new JxtaProtocolHandler(new InetSocketAddressTranslator(TEST_PROTO_NAME), LOCAL_PEER_ID, timeoutTimer, REMOTE_ENDPOINT_ADDR_WITH_PARAMS, LOCAL_ENDPOINT_ADDR);
        downstreamCatcher = new FakeChannelSink();

        ChannelPipeline pipeline = Channels.pipeline();
        upstreamCatcher = new UpstreamEventCatcher();
        pipeline.addLast(JxtaProtocolHandler.NAME, handler);
        pipeline.addLast(UpstreamEventCatcher.NAME, upstreamCatcher);

        FakeChannel parent = new FakeChannel(null, null, Channels.pipeline(), new FakeChannelSink());
        parent.localAddress = PARENT_SOCK_ADDR;
        parent.bound = true;
        parent.connected = false;

        channel = new FakeChannel(parent, null, pipeline, downstreamCatcher);
        channel.localAddress = LOCAL_SOCK_ADDR;
        channel.remoteAddress = REMOTE_SOCK_ADDR;
        channel.bound = true;
        channel.connected = true;

        testMessage = createTestMessage();
    }

    /**
     * On channel-connected, the handler must immediately write exactly one buffer downstream
     * containing its own welcome message.
     */
    @Test
    public void testSendsWelcomeMessageImmediately() throws Exception {
        final WelcomeMessage expectedWelcomeMessage = new WelcomeMessage(REMOTE_ENDPOINT_ADDR_WITH_PARAMS, LOCAL_ENDPOINT_ADDR, LOCAL_PEER_ID, false);
        Channels.fireChannelConnected(channel, REMOTE_SOCK_ADDR);

        assertEquals(1, downstreamCatcher.events.size());
        ChannelEvent event = downstreamCatcher.events.poll();
        assertTrue(event instanceof DownstreamMessageEvent);
        DownstreamMessageEvent msgEv = (DownstreamMessageEvent) event;
        assertTrue(msgEv.getMessage() instanceof ChannelBuffer);

        ChannelBuffer sentData = (ChannelBuffer) msgEv.getMessage();
        ByteBuffer data = convertReadable(sentData);
        WelcomeMessage sent = new WelcomeMessage();
        assertTrue(sent.read(data));
        assertEquals(expectedWelcomeMessage.getWelcomeString(), sent.getWelcomeString());
    }

    /**
     * Receiving the remote welcome message passes it upstream, then signals CONNECTED upstream.
     */
    @Test
    public void testSignalsConnectedOnceReceiveWelcomeMessage() throws Exception {
        emulateConnect();
        ChannelBuffer welcomeBytes = createRemoteWelcomeMessageBuffer();
        Channels.fireMessageReceived(channel, welcomeBytes);

        assertEquals(2, upstreamCatcher.events.size());
        checkIsWelcomeMessage(upstreamCatcher.events.poll());
        checkUpstreamChannelStateEvent(upstreamCatcher.events.poll(), ChannelState.CONNECTED, REMOTE_SOCK_ADDR);
    }

    /** Asserts that the event is an upstream message event carrying a {@link WelcomeMessage}. */
    private void checkIsWelcomeMessage(ChannelEvent ev) {
        assertTrue(ev instanceof UpstreamMessageEvent);
        assertTrue(((UpstreamMessageEvent) ev).getMessage() instanceof WelcomeMessage);
    }

    /**
     * A welcome message split into 4-byte chunks produces no events until the final chunk, which
     * yields the welcome message and the CONNECTED state event.
     */
    @Test
    public void testReceiveWelcomeMessageInChunks() throws Exception {
        emulateConnect();
        List<ChannelBuffer> parts = splitIntoChunks(4, createRemoteWelcomeMessageBuffer());

        for (ChannelBuffer part : parts.subList(0, parts.size() - 1)) {
            Channels.fireMessageReceived(channel, part);
            checkQueuesEmpty();
        }
        Channels.fireMessageReceived(channel, parts.get(parts.size() - 1));

        assertEquals(2, upstreamCatcher.events.size());
        checkIsWelcomeMessage(upstreamCatcher.events.poll());
        checkUpstreamChannelStateEvent(upstreamCatcher.events.poll(), ChannelState.CONNECTED, REMOTE_SOCK_ADDR);
    }

    /**
     * Connecting registers one timeout; a subsequent {@link TimeoutException} makes the handler
     * emit a single downstream OPEN=false (close) state event.
     */
    @Test
    public void testWelcomeMessageTimeout() throws Exception {
        Channels.fireChannelConnected(channel, REMOTE_SOCK_ADDR);
        assertEquals(1, timeoutTimer.numRegisteredTimeouts());

        // we don't care about anything sent up or downstream at this point
        clearQueues();
        Channels.fireExceptionCaught(channel, new TimeoutException());

        assertEquals(1, downstreamCatcher.events.size());
        ChannelEvent ev = downstreamCatcher.events.poll();
        checkDownstreamChannelStateEvent(ev, ChannelState.OPEN, Boolean.FALSE);
    }

    /** Once the remote welcome message has been received, no timeouts remain registered. */
    @Test
    public void testTimerRemovedAfterWelcomeReceived() throws Exception {
        emulateEstablished();
        assertEquals(0, timeoutTimer.numRegisteredTimeouts());
    }

    /** A frame header declaring zero content length yields a serialized message with no bytes. */
    @Test
    public void testReceiveEmptyFramedMessage() throws IOException {
        emulateEstablished();
        MessagePackageHeader header = createHeader(ChannelBuffers.buffer(0));
        ChannelBuffer headerBuffer = ChannelBuffers.wrappedBuffer(header.getByteBuffer());
        Channels.fireMessageReceived(channel, headerBuffer);

        assertEquals(1, upstreamCatcher.events.size());
        ChannelEvent event = upstreamCatcher.events.poll();
        ChannelBuffer unwrappedMessage = checkIsUpstreamMessageEventContainingSerializedMessage(event, WireFormatMessageFactory.DEFAULT_WIRE_MIME);
        assertEquals(0, unwrappedMessage.readableBytes());
    }

    /** A complete frame delivered in one buffer is unwrapped to exactly its original contents. */
    @Test
    public void testReceiveFramedMessage() throws IOException {
        emulateEstablished();
        ChannelBuffer messageContents = serializeMessage(testMessage);
        Channels.fireMessageReceived(channel, createFramedMessage(messageContents));

        assertEquals(1, upstreamCatcher.events.size());
        ChannelEvent event = upstreamCatcher.events.poll();
        ChannelBuffer unwrappedMessage = checkIsUpstreamMessageEventContainingSerializedMessage(event, WireFormatMessageFactory.DEFAULT_WIRE_MIME);
        assertEquals(messageContents, unwrappedMessage);
    }

    /**
     * A frame split into 5-byte chunks produces no events until the last chunk arrives, which
     * yields the complete message.
     */
    @Test
    public void testReceiveFramedMessageInChunks() throws IOException {
        emulateEstablished();
        ChannelBuffer messageContents = serializeMessage(testMessage);
        MessagePackageHeader header = createHeader(messageContents);
        ChannelBuffer headerBuffer = ChannelBuffers.wrappedBuffer(header.getByteBuffer());
        List<ChannelBuffer> parts = splitIntoChunks(5, headerBuffer, messageContents);

        for (ChannelBuffer part : parts.subList(0, parts.size() - 1)) {
            Channels.fireMessageReceived(channel, part);
            checkQueuesEmpty();
        }
        Channels.fireMessageReceived(channel, parts.get(parts.size() - 1));

        assertEquals(1, upstreamCatcher.events.size());
        ChannelEvent event = upstreamCatcher.events.poll();
        ChannelBuffer unwrappedMessage = checkIsUpstreamMessageEventContainingSerializedMessage(event, WireFormatMessageFactory.DEFAULT_WIRE_MIME);
        assertEquals(messageContents, unwrappedMessage);
    }

    /**
     * A welcome message followed by two frames, delivered in 7-byte chunks whose boundaries cut
     * across message boundaries, yields welcome, CONNECTED and both messages in order.
     */
    @Test
    public void receiveWelcomeAndMessagesInOverlappingChunks() throws IOException {
        emulateConnect();
        ChannelBuffer welcomeBuffer = createRemoteWelcomeMessageBuffer();
        ChannelBuffer serializedMessage = serializeMessage(testMessage);
        ChannelBuffer messageBuffer = createFramedMessage(serializedMessage);
        ChannelBuffer messageBuffer2 = createFramedMessage(serializedMessage);

        List<ChannelBuffer> parts = splitIntoChunks(7, welcomeBuffer, messageBuffer, messageBuffer2);
        for (ChannelBuffer part : parts) {
            Channels.fireMessageReceived(channel, part);
        }

        assertEquals(4, upstreamCatcher.events.size());
        checkIsWelcomeMessage(upstreamCatcher.events.poll());
        checkUpstreamChannelStateEvent(upstreamCatcher.events.poll(), ChannelState.CONNECTED, REMOTE_SOCK_ADDR);
        assertEquals(serializedMessage, checkIsUpstreamMessageEventContainingSerializedMessage(upstreamCatcher.events.poll(), WireFormatMessageFactory.DEFAULT_WIRE_MIME));
        assertEquals(serializedMessage, checkIsUpstreamMessageEventContainingSerializedMessage(upstreamCatcher.events.poll(), WireFormatMessageFactory.DEFAULT_WIRE_MIME));
    }

    /**
     * A welcome message followed by two frames, delivered in a single buffer, yields welcome,
     * CONNECTED and both messages in order.
     */
    @Test
    public void receiveWelcomeAndMessagesInSingleChunk() throws IOException {
        emulateConnect();
        ChannelBuffer welcomeBuffer = createRemoteWelcomeMessageBuffer();
        ChannelBuffer serializedMessage = serializeMessage(testMessage);
        ChannelBuffer messageBuffer = createFramedMessage(serializedMessage);
        ChannelBuffer messageBuffer2 = createFramedMessage(serializedMessage);

        ChannelBuffer combined = ChannelBuffers.wrappedBuffer(welcomeBuffer, messageBuffer, messageBuffer2);
        Channels.fireMessageReceived(channel, combined);

        assertEquals(4, upstreamCatcher.events.size());
        checkIsWelcomeMessage(upstreamCatcher.events.poll());
        checkUpstreamChannelStateEvent(upstreamCatcher.events.poll(), ChannelState.CONNECTED, REMOTE_SOCK_ADDR);
        assertEquals(serializedMessage, checkIsUpstreamMessageEventContainingSerializedMessage(upstreamCatcher.events.poll(), WireFormatMessageFactory.DEFAULT_WIRE_MIME));
        assertEquals(serializedMessage, checkIsUpstreamMessageEventContainingSerializedMessage(upstreamCatcher.events.poll(), WireFormatMessageFactory.DEFAULT_WIRE_MIME));
    }

    /**
     * Writing a {@link SerializedMessage} produces a single downstream buffer equal to the
     * encoded header followed by the message contents.
     */
    @Test
    public void testSendMessageWrapsWithFrame() throws IOException {
        emulateEstablished();
        ChannelBuffer messageContents = serializeMessage(testMessage);
        MessagePackageHeader header = createHeader(messageContents);
        SerializedMessage message = new SerializedMessage(header, messageContents);
        ChannelBuffer headerBuffer = ChannelBuffers.wrappedBuffer(header.getByteBuffer());
        ChannelBuffer fullFrame = ChannelBuffers.wrappedBuffer(headerBuffer, messageContents);

        Channels.write(channel, message);

        assertEquals(1, downstreamCatcher.events.size());
        ChannelEvent event = downstreamCatcher.events.poll();
        assertTrue(event instanceof DownstreamMessageEvent);
        assertEquals(fullFrame, checkIsMessageEventContainingBuffer(event));
    }

    /**
     * An unterminated welcome line longer than {@link JxtaProtocolHandler#MAX_WELCOME_MESSAGE_SIZE}
     * makes the handler close the channel (downstream OPEN=false).
     */
    @Test
    public void testSendIllegallyLargeWelcomeMessage() throws Exception {
        emulateConnect();
        ChannelBuffer fakeTooLongWelcomeMessage = ChannelBuffers.buffer(JxtaProtocolHandler.MAX_WELCOME_MESSAGE_SIZE + 1);
        fakeTooLongWelcomeMessage.writeBytes("JXTAHELLO ".getBytes("UTF-8"));
        // Pad the rest of the buffer (zero bytes) so it exceeds the maximum size.
        fakeTooLongWelcomeMessage.writerIndex(fakeTooLongWelcomeMessage.capacity());
        Channels.fireMessageReceived(channel, fakeTooLongWelcomeMessage);

        assertEquals(1, downstreamCatcher.events.size());
        checkDownstreamChannelStateEvent(downstreamCatcher.events.poll(), ChannelState.OPEN, Boolean.FALSE);
    }

    /**
     * A terminated but malformed welcome line ("JXTAHELLO" with no fields) makes the handler close
     * the channel (downstream OPEN=false).
     */
    @Test
    public void testSendInvalidWelcomeMessage() throws Exception {
        emulateConnect();
        // NOTE(review): this buffer is also MAX_WELCOME_MESSAGE_SIZE + 1 bytes long, so the handler
        // may be rejecting it on size rather than on content; confirm the test isolates invalidity.
        ChannelBuffer invalidWelcomeMessage = ChannelBuffers.buffer(JxtaProtocolHandler.MAX_WELCOME_MESSAGE_SIZE + 1);
        invalidWelcomeMessage.writeBytes("JXTAHELLO\r\n".getBytes("UTF-8"));
        invalidWelcomeMessage.writerIndex(invalidWelcomeMessage.capacity());
        Channels.fireMessageReceived(channel, invalidWelcomeMessage);

        assertEquals(1, downstreamCatcher.events.size());
        checkDownstreamChannelStateEvent(downstreamCatcher.events.poll(), ChannelState.OPEN, Boolean.FALSE);
    }

    /**
     * Prefixes serialized message bytes with a {@link MessagePackageHeader} describing them.
     *
     * @param serializedMessage the wire-format message bytes
     * @return a buffer containing the encoded header followed by {@code serializedMessage}
     */
    private ChannelBuffer createFramedMessage(ChannelBuffer serializedMessage) {
        MessagePackageHeader header = createHeader(serializedMessage);
        ChannelBuffer headerBuffer = ChannelBuffers.wrappedBuffer(header.getByteBuffer());
        return ChannelBuffers.wrappedBuffer(headerBuffer, serializedMessage);
    }

    /**
     * Builds a frame header whose content length is the readable size of {@code messageContents}
     * and whose content type is {@link WireFormatMessageFactory#DEFAULT_WIRE_MIME}.
     */
    private MessagePackageHeader createHeader(ChannelBuffer messageContents) {
        MessagePackageHeader header = new MessagePackageHeader();
        header.setContentLengthHeader(messageContents.readableBytes());
        header.setContentTypeHeader(WireFormatMessageFactory.DEFAULT_WIRE_MIME);
        return header;
    }

    /**
     * Serializes a message using the default wire format.
     *
     * @param message the message to serialize
     * @return a buffer wrapping the serialized byte buffers
     */
    private ChannelBuffer serializeMessage(Message message) {
        WireFormatMessage serializedMessage = WireFormatMessageFactory.toWire(message, WireFormatMessageFactory.DEFAULT_WIRE_MIME, null);
        ByteBuffer[] messageBody = serializedMessage.getByteBuffers();
        return ChannelBuffers.wrappedBuffer(messageBody);
    }

    /** Creates a small message with three string elements: a=b, c=d, e=f. */
    private Message createTestMessage() {
        Message message = new Message();
        message.addMessageElement(new StringMessageElement("a", "b", null));
        message.addMessageElement(new StringMessageElement("c", "d", null));
        message.addMessageElement(new StringMessageElement("e", "f", null));
        return message;
    }

    /**
     * Asserts that the event is an upstream message event carrying a {@link SerializedMessage}
     * with the expected content type, and returns that message's contents.
     */
    private ChannelBuffer checkIsUpstreamMessageEventContainingSerializedMessage(ChannelEvent event, MimeMediaType expectedMime) {
        assertTrue(event instanceof UpstreamMessageEvent);
        UpstreamMessageEvent messageEv = (UpstreamMessageEvent) event;
        assertTrue(messageEv.getMessage() instanceof SerializedMessage);
        SerializedMessage message = (SerializedMessage) messageEv.getMessage();
        assertEquals(expectedMime, message.getMessageHeader().getContentTypeHeader());
        return message.getMessageContents();
    }

    /** Asserts that the event is a message event carrying a {@link ChannelBuffer}, and returns it. */
    private ChannelBuffer checkIsMessageEventContainingBuffer(ChannelEvent event) {
        assertTrue(event instanceof MessageEvent);
        MessageEvent messageEv = (MessageEvent) event;
        assertTrue(messageEv.getMessage() instanceof ChannelBuffer);
        return (ChannelBuffer) messageEv.getMessage();
    }

    /**
     * Encodes the welcome message the remote peer would send: destined for the local endpoint,
     * from the remote endpoint and remote peer ID.
     */
    private ChannelBuffer createRemoteWelcomeMessageBuffer() throws IOException {
        final WelcomeMessage receivedWelcomeMessage = new WelcomeMessage(LOCAL_ENDPOINT_ADDR, REMOTE_ENDPOINT_ADDR, REMOTE_PEER_ID, false);
        return ChannelBuffers.wrappedBuffer(receivedWelcomeMessage.getByteBuffer());
    }

    /**
     * Fires channel-connected and discards everything the handler emitted in response (such as
     * its own outgoing welcome message).
     */
    private void emulateConnect() {
        Channels.fireChannelConnected(channel, REMOTE_SOCK_ADDR);
        clearQueues();
    }

    /**
     * Connects and then delivers the remote welcome message, discarding all resulting events, so
     * tests start from a fully established connection.
     */
    private void emulateEstablished() throws IOException {
        emulateConnect();
        Channels.fireMessageReceived(channel, createRemoteWelcomeMessageBuffer());
        clearQueues();
    }

    /** Asserts that the event is an upstream state event with the given state and value. */
    private void checkUpstreamChannelStateEvent(ChannelEvent event, ChannelState expectedState, SocketAddress expectedRemoteAddr) {
        assertTrue(event instanceof UpstreamChannelStateEvent);
        UpstreamChannelStateEvent stateEv = (UpstreamChannelStateEvent) event;
        assertEquals(expectedState, stateEv.getState());
        assertEquals(expectedRemoteAddr, stateEv.getValue());
    }

    /** Asserts that no events have been captured in either direction. */
    private void checkQueuesEmpty() {
        assertTrue(upstreamCatcher.events.isEmpty());
        assertTrue(downstreamCatcher.events.isEmpty());
    }

    /** Discards all captured upstream and downstream events. */
    private void clearQueues() {
        upstreamCatcher.events.clear();
        downstreamCatcher.events.clear();
    }
}