package com.twitter.elephantbird.thrift;
import static org.easymock.EasyMock.anyInt;
import static org.easymock.EasyMock.createStrictMock;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.isA;
import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.verify;
import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.easymock.EasyMock;
import org.easymock.IAnswer;
import org.junit.Test;
public class TestThriftBinaryProtocol {
// helper method to set container size correctly in the supplied byte array
protected void setContainerSize(byte[] buf, int n) {
byte[] b = ByteBuffer.allocate(4).putInt(n).array();
for (int i = 0; i < 4; i++) {
buf[i] = b[i];
}
}
protected void setDataType(byte[] buf) {
buf[0] = TType.BYTE;
}
// mock transport for Set and List container types
protected TTransport getMockTransport(final int containerSize) throws TException {
TTransport transport = createStrictMock(TTransport.class);
// not using buffered mode for tests, so return -1 per the contract
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// first call, set data type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// second call, set container size
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setContainerSize(buf, containerSize);
return 4;
}
}
);
return transport;
}
// mock transport for Map container type
protected TTransport getMockMapTransport(final int containerSize) throws TException {
TTransport transport = createStrictMock(TTransport.class);
// not using buffered mode for tests, so return -1 per the contract
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// first call, set key type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// second call, set value type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// third call, set container size
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setContainerSize(buf, containerSize);
return 4;
}
}
);
return transport;
}
@Test
public void testCheckContainerSizeValid() throws TException {
// any non-negative value is considered valid when checkReadLength is not enabled
TTransport transport;
ThriftBinaryProtocol protocol;
transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readListBegin();
verify(transport);
transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readSetBegin();
verify(transport);
transport = getMockMapTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readMapBegin();
verify(transport);
}
@Test(expected=TProtocolException.class)
public void testCheckListContainerSizeInvalid() throws TException {
// any negative value is considered invalid when checkReadLength is not enabled
TTransport transport = getMockTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readListBegin();
verify(transport);
}
@Test(expected=TProtocolException.class)
public void testCheckSetContainerSizeInvalid() throws TException {
TTransport transport = getMockTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readSetBegin();
verify(transport);
}
@Test(expected=TProtocolException.class)
public void testCheckMapContainerSizeInvalid() throws TException {
TTransport transport = getMockMapTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readMapBegin();
verify(transport);
}
}