package com.limegroup.gnutella.connection;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.*;
import java.net.*;
import junit.framework.Test;
import com.limegroup.gnutella.Response;
import com.limegroup.gnutella.messages.*;
import com.limegroup.gnutella.routing.*;
import com.limegroup.gnutella.util.*;
/**
 * Tests that MessageWriter serializes the messages sent through it and writes
 * them, in order, to the sink channel it was constructed with.
 *
 * Covers writes that fit into the sink in one go, writes that only partially
 * fit, and messages dropped by the outbound queue either while they are being
 * added or while the writer is draining it.
 */
public final class MessageWriterTest extends BaseTestCase {
    private static final byte[] IP = new byte[] { 1, 1, 1, 1 };

    // Fixtures are instance fields (not static) so each test method starts
    // with fresh stats/queue/handler/sink; every test asserts zero counts
    // up front and relies on this.
    private final ConnectionStats STATS = new ConnectionStats();
    private final StubQueue QUEUE = new StubQueue();
    private final StubSentHandler SENT = new StubSentHandler();
    private final WriteBufferChannel SINK = new WriteBufferChannel(1024 * 1024);
    private final MessageWriter WRITER = new MessageWriter(STATS, QUEUE, SENT, SINK);

    /** Creates the test case with the given test-method name. */
    public MessageWriterTest(String name) {
        super(name);
    }

    /** Builds the suite containing every test method of this class. */
    public static Test suite() {
        return buildTestSuite(MessageWriterTest.class);
    }

    /** Runs this suite with the JUnit text runner. */
    public static void main(String[] args) {
        junit.textui.TestRunner.run(suite());
    }

    /**
     * Three messages that fit into the sink are fully written by a single
     * handleWrite() call, in order, and each is reported to the sent handler.
     */
    public void testSimpleWrite() throws Exception {
        Message one = q("query one");
        Message two = g(7123);
        Message three = s(8134);

        assertEquals(0, STATS.getSent());
        assertFalse(SINK.interested());

        WRITER.send(one);
        WRITER.send(two);
        WRITER.send(three);
        // send() counts the messages and raises write interest, but nothing
        // reaches the sent handler until handleWrite() runs.
        assertEquals(3, STATS.getSent());
        assertTrue(SINK.interested());
        assertEquals(0, SENT.size());

        assertFalse(WRITER.handleWrite()); // nothing left to write.
        assertEquals(3, SENT.size());

        ByteBuffer buffer = SINK.getBuffer();
        assertEquals(one.getTotalLength() + two.getTotalLength() + three.getTotalLength(),
                     buffer.limit());
        Message[] in = read(buffer, 3);
        assertFalse(buffer.hasRemaining()); // no trailing bytes after the 3 messages.
        assertEquals(buffer(one), buffer(in[0]));
        assertEquals(buffer(two), buffer(in[1]));
        assertEquals(buffer(three), buffer(in[2]));

        // The sent handler sees the same messages in the same order.
        assertEquals(buffer(one), buffer(SENT.next()));
        assertEquals(buffer(two), buffer(SENT.next()));
        assertEquals(buffer(three), buffer(SENT.next()));
        assertFalse(SINK.interested());
        assertEquals(3, STATS.getSent());
    }

    /**
     * A message larger than the sink's capacity is written across two
     * handleWrite() calls; write interest stays on until it is fully written.
     */
    public void testWritePartialMsg() throws Exception {
        assertEquals(0, SENT.size());
        assertEquals(0, STATS.getSent());

        Message m = q("reaalllllllllly long query");
        SINK.resize(m.getTotalLength() - 20); // leave 20 bytes that cannot fit.
        WRITER.send(m);
        assertEquals(1, STATS.getSent());

        assertTrue(WRITER.handleWrite()); // still stuff left to write.
        assertTrue(SINK.interested());
        // It's reported as sent even though the other side didn't receive it fully yet.
        assertEquals(1, SENT.size());
        assertEquals(buffer(m), buffer(SENT.next()));

        // Reassemble what the sink received across both writes.
        ByteBuffer buffer = ByteBuffer.allocate(m.getTotalLength());
        buffer.put(SINK.getBuffer());
        SINK.resize(100000);
        assertFalse(WRITER.handleWrite());
        assertFalse(SINK.interested());
        buffer.put(SINK.getBuffer());

        Message in = read((ByteBuffer) buffer.flip());
        assertEquals(buffer(m), buffer(in));
    }

    /**
     * One full message plus part of a second fit into the sink; after the
     * sink grows, the rest of the second message and a newly sent third
     * message are written in order.
     */
    public void testWritePartialAndMore() throws Exception {
        Message out1 = q("first long query");
        Message out2 = q("second long query");
        Message out3 = q("third long query");
        assertEquals(0, STATS.getSent());

        SINK.resize(out1.getTotalLength() + 20); // room for out1 plus 20 bytes of out2.
        WRITER.send(out1);
        WRITER.send(out2);
        assertEquals(2, STATS.getSent());
        assertEquals(0, SENT.size());

        assertTrue(WRITER.handleWrite());
        assertTrue(SINK.interested());
        assertEquals(2, SENT.size()); // two were sent, one was received.
        assertEquals(buffer(out1), buffer(SENT.next()));
        assertEquals(buffer(out2), buffer(SENT.next()));

        ByteBuffer buffer = ByteBuffer.allocate(1000);
        buffer.put(SINK.getBuffer()).flip();
        SINK.resize(10000);
        Message in1 = read(buffer);
        assertEquals(buffer(out1), buffer(in1));
        // The partial prefix of out2 is left over; keep it for reassembly.
        assertTrue(buffer.hasRemaining());
        assertEquals(20, buffer.remaining());
        buffer.compact();

        WRITER.send(out3);
        assertEquals(3, STATS.getSent());
        assertFalse(WRITER.handleWrite());
        assertEquals(1, SENT.size());
        assertEquals(buffer(out3), buffer(SENT.next()));
        assertFalse(SINK.interested());

        buffer.put(SINK.getBuffer()).flip();
        Message in2 = read(buffer);
        Message in3 = read(buffer);
        assertFalse(buffer.hasRemaining());
        assertEquals(buffer(out2), buffer(in2));
        assertEquals(buffer(out3), buffer(in3));
    }

    /**
     * Messages dropped by the queue while messages are being added are
     * counted as dropped and never reach the sink.
     */
    public void testDroppingMessagesWhileAdded() throws Exception {
        assertEquals(0, STATS.getSent());
        assertEquals(0, STATS.getSentDropped());

        Message[] m = new Message[10];
        for (int i = 0; i < m.length; i++) {
            m[i] = g(i + 1);
        }

        // Have the stub queue drop 4 messages, starting after the 3rd is
        // added (exact semantics live in StubQueue).
        QUEUE.setNumToDrop(4);
        QUEUE.setStartDropIn(3);
        for (int i = 0; i < m.length; i++) {
            WRITER.send(m[i]);
        }
        assertEquals(4, STATS.getSentDropped());
        assertEquals(10, STATS.getSent());

        assertFalse(WRITER.handleWrite());
        ByteBuffer buffer = SINK.getBuffer();
        Message[] in = read(buffer, 6);
        assertFalse(buffer.hasRemaining());
        assertEquals(6, SENT.size());
        for (int i = 0; i < in.length; i++) {
            assertEquals(buffer(m[i + 4]), buffer(in[i]));
        }
    }

    /**
     * Messages dropped by the queue while the writer is draining it are
     * counted as dropped and skipped in the output.
     */
    public void testDroppingMessagesWhileSending() throws Exception {
        assertEquals(0, STATS.getSent());
        assertEquals(0, STATS.getSentDropped());

        Message[] m = new Message[10];
        for (int i = 0; i < m.length; i++) {
            m[i] = g(i + 1);
        }
        for (int i = 0; i < m.length; i++) {
            WRITER.send(m[i]);
        }
        assertEquals(0, STATS.getSentDropped());
        assertEquals(10, STATS.getSent());

        // Only now, with all 10 queued, have the stub queue drop 4 messages
        // (starting after the 3rd) while handleWrite() drains it.
        QUEUE.setNumToDrop(4);
        QUEUE.setStartDropIn(3);
        assertFalse(WRITER.handleWrite());
        assertEquals(4, STATS.getSentDropped());

        ByteBuffer buffer = SINK.getBuffer();
        Message[] in = read(buffer, 6);
        assertFalse(buffer.hasRemaining());
        assertEquals(6, SENT.size());
        assertEquals(buffer(m[0]), buffer(in[0]));
        assertEquals(buffer(m[1]), buffer(in[1]));
        assertEquals(buffer(m[2]), buffer(in[2])); // started dropping now.
        assertEquals(buffer(m[7]), buffer(in[3])); // finished dropping here.
        assertEquals(buffer(m[8]), buffer(in[4]));
        assertEquals(buffer(m[9]), buffer(in[5]));
    }

    /**
     * Reads one message from the stream via Message.read, passing (byte)100
     * as its second argument (presumably a hops/TTL limit -- confirm).
     */
    private Message read(InputStream in) throws Exception {
        return Message.read(in, (byte) 100);
    }

    /**
     * Reads one message starting at the buffer's position and advances the
     * position past it. Only bytes up to the buffer's limit are visible.
     */
    private Message read(ByteBuffer buffer) throws Exception {
        Message m = read(remainingAsStream(buffer));
        buffer.position(buffer.position() + m.getTotalLength());
        return m;
    }

    /**
     * Reads {@code count} consecutive messages from the buffer, advancing its
     * position past all of them.
     */
    private Message[] read(ByteBuffer buffer, int count) throws Exception {
        Message[] m = new Message[count];
        for (int i = 0; i < count; i++) {
            m[i] = read(buffer);
        }
        return m;
    }

    /**
     * Wraps the bytes between a heap buffer's position and limit in a stream,
     * without changing the buffer's position.
     */
    private static ByteArrayInputStream remainingAsStream(ByteBuffer buffer) {
        // ByteArrayInputStream takes (offset, length), not (start, end), and
        // the backing array may start before the buffer's index 0.
        return new ByteArrayInputStream(buffer.array(),
                                        buffer.arrayOffset() + buffer.position(),
                                        buffer.remaining());
    }

    /** Serializes a message into a freshly wrapped buffer, for byte-wise comparison. */
    private ByteBuffer buffer(Message m) throws Exception {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        m.write(out);
        out.flush();
        return ByteBuffer.wrap(out.toByteArray());
    }

    /** Serializes the messages, in order, into a single buffer. */
    private ByteBuffer buffer(Message[] m) throws Exception {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int i = 0; i < m.length; i++) {
            m[i].write(out);
        }
        out.flush();
        return ByteBuffer.wrap(out.toByteArray());
    }

    /** Serializes a list of Message objects, in iteration order, into a single buffer. */
    private ByteBuffer buffer(List ms) throws Exception {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (Iterator i = ms.iterator(); i.hasNext(); ) {
            ((Message) i.next()).write(out);
        }
        out.flush();
        return ByteBuffer.wrap(out.toByteArray());
    }

    /**
     * Concatenates the remaining bytes of each buffer into one new flipped
     * buffer. The input buffers' positions are left untouched.
     */
    private ByteBuffer buffer(ByteBuffer[] bufs) throws Exception {
        int length = 0;
        for (int i = 0; i < bufs.length; i++) {
            length += bufs[i].remaining();
        }
        ByteBuffer combined = ByteBuffer.allocate(length);
        for (int i = 0; i < bufs.length; i++) {
            combined.put(bufs[i].duplicate()); // duplicate(): don't drain the caller's buffer.
        }
        combined.flip();
        return combined;
    }

    /** Wraps raw bytes in a buffer. */
    private ByteBuffer buffer(byte[] data) {
        return ByteBuffer.wrap(data);
    }

    /** Creates a query request for the given text, passing (byte)5 to createQuery. */
    private QueryRequest q(String query) {
        return QueryRequest.createQuery(query, (byte) 5);
    }

    /** Creates a ping reply for the given port, with a zeroed 16-byte GUID and IP 1.1.1.1. */
    private PingReply g(int port) {
        return PingReply.create(new byte[16], (byte) 5, port, IP);
    }

    /** Creates a push request for the given port, with zeroed 16-byte IDs and IP 1.1.1.1. */
    private PushRequest s(int port) {
        return new PushRequest(new byte[16], (byte) 5, new byte[16], 0, IP, port);
    }
}