/*
 * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0
 * (the "License"). You may not use this work except in compliance with the License, which is
 * available at www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied, as more fully set forth in the License.
 *
 * See the NOTICE file distributed with this work for information regarding copyright ownership.
 */

package alluxio.client.block.stream;

import alluxio.Constants;
import alluxio.EmbeddedChannels;
import alluxio.client.file.FileSystemContext;
import alluxio.network.protocol.RPCProtoMessage;
import alluxio.network.protocol.databuffer.DataBuffer;
import alluxio.network.protocol.databuffer.DataNettyBufferV2;
import alluxio.proto.dataserver.Protocol;
import alluxio.util.CommonUtils;
import alluxio.util.WaitForOptions;
import alluxio.util.io.BufferUtils;
import alluxio.wire.WorkerNetAddress;

import com.google.common.base.Function;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Tests for the packet reader created by {@link NettyPacketReader.Factory}.
 *
 * Each test creates a reader against an embedded channel handed out by a mocked
 * {@link FileSystemContext}, writes read responses into the channel from a background thread, and
 * then validates both the data that comes out of the reader and the read requests that were
 * written to the channel.
 */
@RunWith(PowerMockRunner.class)
@PrepareForTest({FileSystemContext.class, WorkerNetAddress.class})
public final class NettyPacketReaderTest {
  private static final int PACKET_SIZE = 1024;
  /** Runs the tasks that write the read responses into the channel. */
  private static final ExecutorService EXECUTOR = Executors.newFixedThreadPool(4);
  private static final Random RANDOM = new Random();

  private static final long BLOCK_ID = 1L;
  private static final long SESSION_ID = 2L;
  private static final long LOCK_ID = 3L;

  private FileSystemContext mContext;
  private WorkerNetAddress mAddress;
  private EmbeddedChannels.EmbeddedEmptyCtorChannel mChannel;
  private NettyPacketReader.Factory mFactory;

  /**
   * Sets up the mocked context so that acquiring a netty channel for the mocked worker address
   * returns the embedded test channel.
   */
  @Before
  public void before() throws Exception {
    mContext = PowerMockito.mock(FileSystemContext.class);
    mAddress = Mockito.mock(WorkerNetAddress.class);
    mFactory = new NettyPacketReader.Factory(mContext, mAddress, BLOCK_ID, LOCK_ID, SESSION_ID,
        false, Protocol.RequestType.ALLUXIO_BLOCK, PACKET_SIZE);

    mChannel = new EmbeddedChannels.EmbeddedEmptyCtorChannel();
    PowerMockito.when(mContext.acquireNettyChannel(mAddress)).thenReturn(mChannel);
    PowerMockito.doNothing().when(mContext).releaseNettyChannel(mAddress, mChannel);
  }

  /**
   * Closes the embedded test channel.
   */
  @After
  public void after() throws Exception {
    mChannel.close();
  }

  /**
   * Reads an empty file.
   */
  @Test(timeout = 1000 * 60)
  public void readEmptyFile() throws Exception {
    try (PacketReader reader = create(0, 10)) {
      Future<Long> checksum = sendReadResponses(mChannel, 0, 0, 0);
      Assert.assertNull(reader.readPacket());
      // Wait for the sender so that a failure on the sender thread fails this test as well.
      // No data bytes are sent, so the sender's checksum must be zero.
      Assert.assertEquals(0L, checksum.get().longValue());
    }
    validateReadRequestSent(mChannel, 0, 10, false, PACKET_SIZE);
  }

  /**
   * Reads all contents in a file.
   */
  @Test(timeout = 1000 * 60)
  public void readFullFile() throws Exception {
    long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3;
    try (PacketReader reader = create(0, length)) {
      Future<Long> checksum = sendReadResponses(mChannel, length, 0, length - 1);
      long checksumActual = checkPackets(reader, 0, length);
      Assert.assertEquals(checksum.get().longValue(), checksumActual);
    }
    validateReadRequestSent(mChannel, 0, length, false, PACKET_SIZE);
  }

  /**
   * Reads part of a file and checks the checksum of the part that is read.
   */
  @Test(timeout = 1000 * 60)
  public void readPartialFile() throws Exception {
    long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3;
    long offset = 10;
    long checksumStart = 100;
    long bytesToRead = length / 3;

    try (PacketReader reader = create(offset, length)) {
      Future<Long> checksum =
          sendReadResponses(mChannel, length, checksumStart, bytesToRead - 1);
      long checksumActual = checkPackets(reader, checksumStart, bytesToRead);
      Assert.assertEquals(checksum.get().longValue(), checksumActual);
    }
    validateReadRequestSent(mChannel, offset, length, false, PACKET_SIZE);
    // A second request with the cancel flag set is expected after the initial read request
    // (presumably because the reader is closed before the EOF is consumed - confirm against
    // NettyPacketReader).
    validateReadRequestSent(mChannel, 0, 0, true, 0);
  }

  /**
   * Reads a file with unknown length.
   */
  @Test(timeout = 1000 * 60)
  public void fileLengthUnknown() throws Exception {
    long lengthActual = PACKET_SIZE * 1024 + PACKET_SIZE / 3;
    long checksumStart = 100;
    long bytesToRead = lengthActual / 3;

    try (PacketReader reader = create(0, Long.MAX_VALUE)) {
      Future<Long> checksum =
          sendReadResponses(mChannel, lengthActual, checksumStart, bytesToRead - 1);
      long checksumActual = checkPackets(reader, checksumStart, bytesToRead);
      Assert.assertEquals(checksum.get().longValue(), checksumActual);
    }
    validateReadRequestSent(mChannel, 0, Long.MAX_VALUE, false, PACKET_SIZE);
    // A second request with the cancel flag set is expected after the initial read request.
    validateReadRequestSent(mChannel, 0, 0, true, 0);
  }

  /**
   * Creates a {@link PacketReader} and finishes the creation of the embedded channel.
   *
   * @param offset the offset
   * @param length the length
   * @return the packet reader instance
   */
  private PacketReader create(long offset, long length) throws Exception {
    PacketReader reader = mFactory.create(offset, length);
    mChannel.finishChannelCreation();
    return reader;
  }

  /**
   * Reads the packets from the given {@link PacketReader}.
   *
   * Positions are counted from the first byte returned by the reader (starting at 0). Reading
   * stops when the reader returns null or once {@code bytesToRead} bytes have been consumed,
   * whichever comes first.
   *
   * @param reader the packet reader
   * @param checksumStart the start position (inclusive) to calculate the checksum
   * @param bytesToRead bytes to read
   * @return the checksum of the data read starting from checksumStart
   */
  private long checkPackets(PacketReader reader, long checksumStart, long bytesToRead)
      throws Exception {
    long pos = 0;
    long checksum = 0;

    while (true) {
      DataBuffer packet = reader.readPacket();
      if (packet == null) {
        break;
      }
      try {
        Assert.assertTrue(packet instanceof DataNettyBufferV2);
        ByteBuf buf = (ByteBuf) packet.getNettyOutput();
        byte[] bytes = new byte[buf.readableBytes()];
        buf.readBytes(bytes);
        for (byte b : bytes) {
          if (pos >= checksumStart) {
            checksum += BufferUtils.byteToInt(b);
          }
          pos++;
          if (pos >= bytesToRead) {
            // The packet is still released by the finally block below.
            return checksum;
          }
        }
      } finally {
        packet.release();
      }
    }
    return checksum;
  }

  /**
   * Validates the read request sent.
   *
   * Waits (up to one minute) for the next outbound message on the channel and checks that it is
   * a payload-free read request for {@link #BLOCK_ID} with the given fields.
   *
   * @param channel the channel
   * @param offset the offset
   * @param length the length
   * @param cancel whether it is a cancel request
   * @param packetSize the packet size
   */
  private void validateReadRequestSent(final EmbeddedChannel channel, long offset, long length,
      boolean cancel, long packetSize) {
    Object request = CommonUtils.waitForResult("read request", new Function<Void, Object>() {
      @Override
      public Object apply(Void v) {
        return channel.readOutbound();
      }
    }, WaitForOptions.defaults().setTimeout(Constants.MINUTE_MS));

    Assert.assertNotNull(request);
    Assert.assertTrue(request instanceof RPCProtoMessage);
    RPCProtoMessage message = (RPCProtoMessage) request;
    Assert.assertNull(message.getPayloadDataBuffer());

    Protocol.ReadRequest readRequest = message.getMessage().asReadRequest();
    Assert.assertEquals(BLOCK_ID, readRequest.getId());
    Assert.assertEquals(offset, readRequest.getOffset());
    Assert.assertEquals(length, readRequest.getLength());
    Assert.assertEquals(cancel, readRequest.getCancel());
    Assert.assertEquals(packetSize, readRequest.getPacketSize());
  }

  /**
   * Sends read responses to the channel.
   *
   * The responses are written from a task submitted to {@link #EXECUTOR}: random data is sent in
   * packets of at most {@link #PACKET_SIZE} bytes until {@code length} bytes have been written,
   * followed by one response without a payload that serves as the EOF.
   *
   * @param channel the channel
   * @param length the length
   * @param start the start position (inclusive) to calculate the checksum
   * @param end the end position (inclusive) to calculate the checksum
   * @return the checksum
   */
  private Future<Long> sendReadResponses(final EmbeddedChannel channel, final long length,
      final long start, final long end) {
    return EXECUTOR.submit(new Callable<Long>() {
      @Override
      public Long call() {
        long checksum = 0;
        long pos = 0;

        long remaining = length;
        while (remaining > 0) {
          int bytesToSend = (int) Math.min(remaining, PACKET_SIZE);
          byte[] data = new byte[bytesToSend];
          RANDOM.nextBytes(data);
          ByteBuf buf = Unpooled.wrappedBuffer(data);
          RPCProtoMessage message = RPCProtoMessage.createOkResponse(new DataNettyBufferV2(buf));
          channel.writeInbound(message);
          remaining -= bytesToSend;

          for (byte b : data) {
            if (pos >= start && pos <= end) {
              checksum += BufferUtils.byteToInt(b);
            }
            pos++;
          }
        }

        // send EOF.
        channel.writeInbound(RPCProtoMessage.createOkResponse(null));
        return checksum;
      }
    });
  }
}