/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.protocol;
import com.github.ambry.clustermap.MockClusterMap;
import com.github.ambry.clustermap.MockPartitionId;
import com.github.ambry.clustermap.PartitionId;
import com.github.ambry.commons.BlobId;
import com.github.ambry.commons.ServerErrorCode;
import com.github.ambry.messageformat.BlobProperties;
import com.github.ambry.messageformat.BlobType;
import com.github.ambry.messageformat.MessageFormatFlags;
import com.github.ambry.store.FindToken;
import com.github.ambry.store.FindTokenFactory;
import com.github.ambry.store.MessageInfo;
import com.github.ambry.utils.ByteBufferChannel;
import com.github.ambry.utils.ByteBufferInputStream;
import com.github.ambry.utils.ByteBufferOutputStream;
import com.github.ambry.utils.Utils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
/**
 * Test {@link FindTokenFactory} that produces {@link MockFindToken} instances.
 */
class MockFindTokenFactory implements FindTokenFactory {

  /**
   * Reads a {@link MockFindToken} from {@code stream}.
   * @param stream the stream positioned at a serialized {@link MockFindToken}.
   * @return the deserialized token.
   * @throws IOException if reading from {@code stream} fails.
   */
  @Override
  public FindToken getFindToken(DataInputStream stream) throws IOException {
    FindToken deserialized = new MockFindToken(stream);
    return deserialized;
  }

  /**
   * @return a fresh token with index 0 and 0 bytes read.
   */
  @Override
  public FindToken getNewFindToken() {
    final int startIndex = 0;
    final long noBytesRead = 0L;
    return new MockFindToken(startIndex, noBytesRead);
  }
}
/**
 * Minimal test {@link FindToken} made of an index and a byte count. It serializes as a 4-byte int (index) followed
 * by an 8-byte long (bytes read).
 */
class MockFindToken implements FindToken {
  int index;
  long bytesRead;

  /**
   * @param index the index this token points at.
   * @param bytesRead the number of bytes this token records as read.
   */
  public MockFindToken(int index, long bytesRead) {
    this.index = index;
    this.bytesRead = bytesRead;
  }

  /**
   * Deserializes a token written by {@link #toBytes()}.
   * @param stream source positioned at the int index, followed by the long byte count.
   * @throws IOException if reading from {@code stream} fails.
   */
  public MockFindToken(DataInputStream stream) throws IOException {
    // Arguments are evaluated left to right, so the int is consumed before the long.
    this(stream.readInt(), stream.readLong());
  }

  /**
   * @return the index followed by the byte count, in {@link ByteBuffer}'s default (big-endian) order.
   */
  @Override
  public byte[] toBytes() {
    // 4 bytes for the int index + 8 bytes for the long byte count.
    return ByteBuffer.allocate(4 + 8).putInt(index).putLong(bytesRead).array();
  }

  public int getIndex() {
    return this.index;
  }

  public long getBytesRead() {
    return bytesRead;
  }
}
/**
 * A {@link PutRequest} whose {@code versionId} is forced to {@link #Put_Request_Invalid_version} (0). Tests in this
 * file use it to check that deserializing such a request throws.
 */
class InvalidVersionPutRequest extends PutRequest {
  static final short Put_Request_Invalid_version = 0;

  public InvalidVersionPutRequest(int correlationId, String clientId, BlobId blobId, BlobProperties properties,
      ByteBuffer usermetadata, ByteBuffer blob, long blobSize, BlobType blobType) {
    super(correlationId, clientId, blobId, properties, usermetadata, blob, blobSize, blobType);
    // Replace whatever version the superclass constructor assigned (presumably the current one) with the invalid one.
    this.versionId = InvalidVersionPutRequest.Put_Request_Invalid_version;
  }
}
public class RequestResponseTest {
  private final Random random = new Random();

  /**
   * Serializes a {@link PutRequest} through channels with several different write limits. It checks that
   * {@code writeTo()} fills the channel on every call, then deserializes the bytes and verifies each field.
   * @param clusterMap the cluster map used to deserialize the blob id.
   * @param correlationId the correlation id of the request.
   * @param clientId the client id of the request.
   * @param blobId the blob id of the request.
   * @param blobProperties the blob properties; their blob size may differ from {@code blobSize}.
   * @param userMetadata the user metadata bytes.
   * @param blobType the type of the blob.
   * @param blob the blob content.
   * @param blobSize the number of blob bytes to send.
   * @throws IOException if serialization or deserialization fails.
   */
  private void testPutRequest(MockClusterMap clusterMap, int correlationId, String clientId, BlobId blobId,
      BlobProperties blobProperties, byte[] userMetadata, BlobType blobType, byte[] blob, int blobSize)
      throws IOException {
    // This PutRequest is created just to get the size.
    int sizeInBytes =
        (int) new PutRequest(correlationId, clientId, blobId, blobProperties, ByteBuffer.wrap(userMetadata),
            ByteBuffer.wrap(blob), blobSize, blobType).sizeInBytes();
    // Initialize channel write limits in such a way that writeTo() may or may not be able to write out all the
    // data at once.
    int[] channelWriteLimits =
        {sizeInBytes, 2 * sizeInBytes, sizeInBytes / 2, sizeInBytes / (random.nextInt(sizeInBytes - 1) + 1)};
    int sizeInBlobProperties = (int) blobProperties.getBlobSize();
    for (int allocationSize : channelWriteLimits) {
      PutRequest request =
          new PutRequest(correlationId, clientId, blobId, blobProperties, ByteBuffer.wrap(userMetadata),
              ByteBuffer.wrap(blob), blobSize, blobType);
      ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
      // Ceiling division: the number of channel-sized chunks needed to hold the whole request.
      int expectedWriteToCount = ((int) request.sizeInBytes() + allocationSize - 1) / allocationSize;
      int actualWriteToCount = 0;
      while (!request.isSendComplete()) {
        ByteBufferChannel channel = new ByteBufferChannel(ByteBuffer.allocate(allocationSize));
        request.writeTo(channel);
        ByteBuffer underlyingBuf = channel.getBuffer();
        underlyingBuf.flip();
        outputStream.write(underlyingBuf.array(), underlyingBuf.arrayOffset() + underlyingBuf.position(),
            underlyingBuf.remaining());
        actualWriteToCount++;
      }
      Assert.assertEquals("writeTo() should have written out as much as the channel could take in every call",
          expectedWriteToCount, actualWriteToCount);
      DataInputStream requestStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
      requestStream.readLong(); // read size
      Assert.assertEquals(RequestOrResponseType.PutRequest, RequestOrResponseType.values()[requestStream.readShort()]);
      PutRequest.ReceivedPutRequest deserializedPutRequest = PutRequest.readFrom(requestStream, clusterMap);
      Assert.assertEquals(blobId, deserializedPutRequest.getBlobId());
      Assert.assertEquals(sizeInBlobProperties, deserializedPutRequest.getBlobProperties().getBlobSize());
      Assert.assertArrayEquals(userMetadata, deserializedPutRequest.getUsermetadata().array());
      Assert.assertEquals(blobSize, deserializedPutRequest.getBlobSize());
      Assert.assertEquals(blobType, deserializedPutRequest.getBlobType());
      byte[] blobRead = new byte[blobSize];
      // A single InputStream.read() may legally return fewer bytes than requested. readFully() either fills the
      // array or throws EOFException.
      new DataInputStream(deserializedPutRequest.getBlobStream()).readFully(blobRead);
      Assert.assertArrayEquals(blob, blobRead);
    }
  }

  /**
   * Serializes an {@link InvalidVersionPutRequest} and asserts that deserializing it throws
   * {@link IllegalStateException}.
   * @throws IOException if serialization fails.
   */
  private void testPutRequestInvalidVersion(MockClusterMap clusterMap, int correlationId, String clientId,
      BlobId blobId, BlobProperties blobProperties, byte[] userMetadata, byte[] blob) throws IOException {
    int sizeInBlobProperties = (int) blobProperties.getBlobSize();
    PutRequest request =
        new InvalidVersionPutRequest(correlationId, clientId, blobId, blobProperties, ByteBuffer.wrap(userMetadata),
            ByteBuffer.wrap(blob), sizeInBlobProperties, BlobType.DataBlob);
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
    while (!request.isSendComplete()) {
      request.writeTo(writableByteChannel);
    }
    DataInputStream requestStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    requestStream.readLong(); // read size
    Assert.assertEquals(RequestOrResponseType.PutRequest, RequestOrResponseType.values()[requestStream.readShort()]);
    try {
      PutRequest.readFrom(requestStream, clusterMap);
      Assert.fail("Deserialization of PutRequest with invalid version should have thrown an exception.");
    } catch (IllegalStateException expected) {
      // Expected: this is the failure the test is checking for.
    }
  }

  /**
   * Tests ser/de of {@link PutRequest} (with matching and mismatching blob sizes, both blob types, empty user
   * metadata and an invalid version) and of {@link PutResponse}.
   * @throws IOException
   */
  @Test
  public void putRequestResponseTest() throws IOException {
    MockClusterMap clusterMap = new MockClusterMap();
    int correlationId = 5;
    String clientId = "client";
    BlobId blobId = new BlobId(clusterMap.getWritablePartitionIds().get(0));
    byte[] userMetadata = new byte[50];
    random.nextBytes(userMetadata);
    int blobSize = 100;
    byte[] blob = new byte[blobSize];
    random.nextBytes(blob);
    BlobProperties blobProperties =
        new BlobProperties(blobSize, "serviceID", "memberId", "contentType", false, Utils.Infinite_Time);
    testPutRequest(clusterMap, correlationId, clientId, blobId, blobProperties, userMetadata, BlobType.DataBlob, blob,
        blobSize);
    // Put Request with size in blob properties different from the data size and blob type: Data blob.
    blobProperties =
        new BlobProperties(blobSize * 10, "serviceID", "memberId", "contentType", false, Utils.Infinite_Time);
    testPutRequest(clusterMap, correlationId, clientId, blobId, blobProperties, userMetadata, BlobType.DataBlob, blob,
        blobSize);
    // Put Request with size in blob properties different from the data size and blob type: Metadata blob.
    blobProperties =
        new BlobProperties(blobSize * 10, "serviceID", "memberId", "contentType", false, Utils.Infinite_Time);
    testPutRequest(clusterMap, correlationId, clientId, blobId, blobProperties, userMetadata, BlobType.MetadataBlob,
        blob, blobSize);
    // Put Request with empty user metadata.
    byte[] emptyUserMetadata = new byte[0];
    blobProperties = new BlobProperties(blobSize, "serviceID", "memberId", "contentType", false, Utils.Infinite_Time);
    testPutRequest(clusterMap, correlationId, clientId, blobId, blobProperties, emptyUserMetadata, BlobType.DataBlob,
        blob, blobSize);
    // Ensure a Put Request with an invalid version does not get deserialized correctly.
    testPutRequestInvalidVersion(clusterMap, correlationId, clientId, blobId, blobProperties, userMetadata, blob);

    // Response test
    PutResponse response = new PutResponse(1234, clientId, ServerErrorCode.No_Error);
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
    do {
      response.writeTo(writableByteChannel);
    } while (!response.isSendComplete());
    DataInputStream responseStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    responseStream.readLong(); // read size
    PutResponse deserializedPutResponse = PutResponse.readFrom(responseStream);
    Assert.assertEquals(1234, deserializedPutResponse.getCorrelationId());
    Assert.assertEquals(ServerErrorCode.No_Error, deserializedPutResponse.getError());
  }

  /**
   * Tests ser/de of {@link GetRequest} and {@link GetResponse}.
   * @throws IOException
   */
  @Test
  public void getRequestResponseTest() throws IOException {
    MockClusterMap clusterMap = new MockClusterMap();
    BlobId id1 = new BlobId(clusterMap.getWritablePartitionIds().get(0));
    ArrayList<BlobId> blobIdList = new ArrayList<BlobId>();
    blobIdList.add(id1);
    PartitionRequestInfo partitionRequestInfo1 = new PartitionRequestInfo(new MockPartitionId(), blobIdList);
    ArrayList<PartitionRequestInfo> partitionRequestInfoList = new ArrayList<PartitionRequestInfo>();
    partitionRequestInfoList.add(partitionRequestInfo1);
    GetRequest getRequest =
        new GetRequest(1234, "clientId", MessageFormatFlags.Blob, partitionRequestInfoList, GetOption.None);
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
    do {
      getRequest.writeTo(writableByteChannel);
    } while (!getRequest.isSendComplete());
    DataInputStream requestStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    requestStream.readLong(); // read size
    requestStream.readShort(); // read request type
    GetRequest deserializedGetRequest = GetRequest.readFrom(requestStream, clusterMap);
    Assert.assertEquals("clientId", deserializedGetRequest.getClientId());
    Assert.assertEquals(1, deserializedGetRequest.getPartitionInfoList().size());
    Assert.assertEquals(1, deserializedGetRequest.getPartitionInfoList().get(0).getBlobIds().size());
    Assert.assertEquals(id1, deserializedGetRequest.getPartitionInfoList().get(0).getBlobIds().get(0));

    MessageInfo messageInfo = new MessageInfo(id1, 1000, 1000);
    ArrayList<MessageInfo> messageInfoList = new ArrayList<MessageInfo>();
    messageInfoList.add(messageInfo);
    PartitionResponseInfo partitionResponseInfo =
        new PartitionResponseInfo(clusterMap.getWritablePartitionIds().get(0), messageInfoList);
    List<PartitionResponseInfo> partitionResponseInfoList = new ArrayList<PartitionResponseInfo>();
    partitionResponseInfoList.add(partitionResponseInfo);
    byte[] buf = new byte[1000];
    random.nextBytes(buf);
    ByteArrayInputStream byteStream = new ByteArrayInputStream(buf);
    GetResponse response =
        new GetResponse(1234, "clientId", partitionResponseInfoList, byteStream, ServerErrorCode.No_Error);
    outputStream.reset();
    do {
      response.writeTo(writableByteChannel);
    } while (!response.isSendComplete());
    DataInputStream responseStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    responseStream.readLong(); // read size
    GetResponse deserializedGetResponse = GetResponse.readFrom(responseStream, clusterMap);
    Assert.assertEquals(1234, deserializedGetResponse.getCorrelationId());
    Assert.assertEquals(ServerErrorCode.No_Error, deserializedGetResponse.getError());
    Assert.assertEquals(1, deserializedGetResponse.getPartitionResponseInfoList().size());
    Assert.assertEquals(1, deserializedGetResponse.getPartitionResponseInfoList().get(0).getMessageInfoList().size());
    MessageInfo deserializedMessageInfo =
        deserializedGetResponse.getPartitionResponseInfoList().get(0).getMessageInfoList().get(0);
    Assert.assertEquals(1000, deserializedMessageInfo.getSize());
    Assert.assertEquals(id1, deserializedMessageInfo.getStoreKey());
    Assert.assertEquals(1000, deserializedMessageInfo.getExpirationTimeInMs());
  }

  /**
   * Tests ser/de of {@link DeleteRequest} and {@link DeleteResponse}.
   * @throws IOException
   */
  @Test
  public void deleteRequestResponseTest() throws IOException {
    MockClusterMap clusterMap = new MockClusterMap();
    BlobId id1 = new BlobId(clusterMap.getWritablePartitionIds().get(0));
    DeleteRequest deleteRequest = new DeleteRequest(1234, "client", id1);
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
    do {
      deleteRequest.writeTo(writableByteChannel);
    } while (!deleteRequest.isSendComplete());
    DataInputStream requestStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    requestStream.readLong(); // read size
    requestStream.readShort(); // read request type
    DeleteRequest deserializedDeleteRequest = DeleteRequest.readFrom(requestStream, clusterMap);
    Assert.assertEquals("client", deserializedDeleteRequest.getClientId());
    Assert.assertEquals(id1, deserializedDeleteRequest.getBlobId());

    DeleteResponse response = new DeleteResponse(1234, "client", ServerErrorCode.No_Error);
    outputStream.reset();
    do {
      response.writeTo(writableByteChannel);
    } while (!response.isSendComplete());
    DataInputStream responseStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    responseStream.readLong(); // read size
    DeleteResponse deserializedDeleteResponse = DeleteResponse.readFrom(responseStream);
    Assert.assertEquals(1234, deserializedDeleteResponse.getCorrelationId());
    Assert.assertEquals(ServerErrorCode.No_Error, deserializedDeleteResponse.getError());
  }

  /**
   * Tests ser/de of {@link ReplicaMetadataRequest} and {@link ReplicaMetadataResponse}. It also checks that a null
   * request info list and a null find token are rejected with {@link IllegalArgumentException}.
   * @throws IOException
   */
  @Test
  public void replicaMetadataRequestTest() throws IOException {
    MockClusterMap clusterMap = new MockClusterMap();
    BlobId id1 = new BlobId(clusterMap.getWritablePartitionIds().get(0));
    List<ReplicaMetadataRequestInfo> replicaMetadataRequestInfoList = new ArrayList<ReplicaMetadataRequestInfo>();
    ReplicaMetadataRequestInfo replicaMetadataRequestInfo =
        new ReplicaMetadataRequestInfo(new MockPartitionId(), new MockFindToken(0, 1000), "localhost", "path");
    replicaMetadataRequestInfoList.add(replicaMetadataRequestInfo);
    ReplicaMetadataRequest request = new ReplicaMetadataRequest(1, "id", replicaMetadataRequestInfoList, 1000);
    ByteBuffer buffer = ByteBuffer.allocate((int) request.sizeInBytes());
    ByteBufferOutputStream byteBufferOutputStream = new ByteBufferOutputStream(buffer);
    do {
      request.writeTo(Channels.newChannel(byteBufferOutputStream));
    } while (!request.isSendComplete());
    buffer.flip();
    buffer.getLong(); // read size
    buffer.getShort(); // read request type
    ReplicaMetadataRequest replicaMetadataRequestFromBytes =
        ReplicaMetadataRequest.readFrom(new DataInputStream(new ByteBufferInputStream(buffer)), clusterMap,
            new MockFindTokenFactory());
    Assert.assertEquals(1000, replicaMetadataRequestFromBytes.getMaxTotalSizeOfEntriesInBytes());
    Assert.assertEquals(1, replicaMetadataRequestFromBytes.getReplicaMetadataRequestInfoList().size());

    try {
      ReplicaMetadataRequest nullInfoListRequest = new ReplicaMetadataRequest(1, "id", null, 12);
      ByteBuffer nullInfoListBuffer = ByteBuffer.allocate((int) nullInfoListRequest.sizeInBytes());
      ByteBufferOutputStream nullInfoListOutputStream = new ByteBufferOutputStream(nullInfoListBuffer);
      do {
        nullInfoListRequest.writeTo(Channels.newChannel(nullInfoListOutputStream));
      } while (!nullInfoListRequest.isSendComplete());
      Assert.fail("A ReplicaMetadataRequest with a null request info list should have been rejected");
    } catch (IllegalArgumentException expected) {
      // Expected.
    }
    try {
      new ReplicaMetadataRequestInfo(new MockPartitionId(), null, "localhost", "path");
      Assert.fail("A ReplicaMetadataRequestInfo with a null find token should have been rejected");
    } catch (IllegalArgumentException expected) {
      // Expected.
    }

    MessageInfo messageInfo = new MessageInfo(id1, 1000);
    List<MessageInfo> messageInfoList = new ArrayList<MessageInfo>();
    messageInfoList.add(messageInfo);
    ReplicaMetadataResponseInfo responseInfo =
        new ReplicaMetadataResponseInfo(clusterMap.getWritablePartitionIds().get(0), new MockFindToken(0, 1000),
            messageInfoList, 1000);
    List<ReplicaMetadataResponseInfo> replicaMetadataResponseInfoList = new ArrayList<ReplicaMetadataResponseInfo>();
    replicaMetadataResponseInfoList.add(responseInfo);
    ReplicaMetadataResponse response =
        new ReplicaMetadataResponse(1234, "clientId", ServerErrorCode.No_Error, replicaMetadataResponseInfoList);
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
    do {
      response.writeTo(writableByteChannel);
    } while (!response.isSendComplete());
    DataInputStream responseStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
    responseStream.readLong(); // read size
    ReplicaMetadataResponse deserializedReplicaMetadataResponse =
        ReplicaMetadataResponse.readFrom(responseStream, new MockFindTokenFactory(), clusterMap);
    Assert.assertEquals(1234, deserializedReplicaMetadataResponse.getCorrelationId());
    Assert.assertEquals(ServerErrorCode.No_Error, deserializedReplicaMetadataResponse.getError());
  }

  /**
   * Tests the ser/de of {@link AdminRequest} and {@link AdminResponse} and checks for equality of fields with
   * reference data.
   * @throws IOException
   */
  @Test
  public void adminRequestResponseTest() throws IOException {
    // The cluster map and partition do not depend on the request type, so build them once.
    MockClusterMap clusterMap = new MockClusterMap();
    PartitionId id = clusterMap.getWritablePartitionIds().get(0);
    for (AdminRequestOrResponseType type : AdminRequestOrResponseType.values()) {
      AdminRequest adminRequest = new AdminRequest(type, id, 1234, "client");
      ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
      WritableByteChannel writableByteChannel = Channels.newChannel(outputStream);
      do {
        adminRequest.writeTo(writableByteChannel);
      } while (!adminRequest.isSendComplete());
      DataInputStream requestStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
      requestStream.readLong(); // read size
      requestStream.readShort(); // read request type
      AdminRequest deserializedAdminRequest = AdminRequest.readFrom(requestStream, clusterMap);
      Assert.assertEquals(1234, deserializedAdminRequest.getCorrelationId());
      Assert.assertEquals("client", deserializedAdminRequest.getClientId());
      Assert.assertEquals(type, deserializedAdminRequest.getType());
      Assert.assertTrue(deserializedAdminRequest.getPartitionId().isEqual(id.toString()));

      AdminResponse response = new AdminResponse(1234, "client", ServerErrorCode.No_Error);
      outputStream.reset();
      do {
        response.writeTo(writableByteChannel);
      } while (!response.isSendComplete());
      DataInputStream responseStream = new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()));
      responseStream.readLong(); // read size
      AdminResponse deserializedAdminResponse = AdminResponse.readFrom(responseStream);
      Assert.assertEquals(1234, deserializedAdminResponse.getCorrelationId());
      Assert.assertEquals(ServerErrorCode.No_Error, deserializedAdminResponse.getError());
    }
  }
}