/* * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 * (the "License"). You may not use this work except in compliance with the License, which is * available at www.apache.org/licenses/LICENSE-2.0 * * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, * either express or implied, as more fully set forth in the License. * * See the NOTICE file distributed with this work for information regarding copyright ownership. */ package alluxio.client.block.stream; import alluxio.Constants; import alluxio.EmbeddedChannels; import alluxio.client.file.FileSystemContext; import alluxio.network.protocol.RPCProtoMessage; import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.network.protocol.databuffer.DataNettyBufferV2; import alluxio.proto.dataserver.Protocol; import alluxio.util.CommonUtils; import alluxio.util.ThreadFactoryUtils; import alluxio.util.WaitForOptions; import alluxio.util.io.BufferUtils; import alluxio.wire.WorkerNetAddress; import com.google.common.base.Function; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mockito; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @RunWith(PowerMockRunner.class) @PrepareForTest({FileSystemContext.class, WorkerNetAddress.class}) public final class NettyPacketWriterTest { private static final Logger LOG = LoggerFactory.getLogger(NettyPacketWriterTest.class); private static final int PACKET_SIZE = 1024; private static final ExecutorService EXECUTOR = Executors.newFixedThreadPool(4, ThreadFactoryUtils.build("test-executor-%d", true)); private static final Random RANDOM = new Random(); private static final long BLOCK_ID = 1L; private static final long SESSION_ID = 2L; private static final int TIER = 0; private FileSystemContext mContext; private WorkerNetAddress mAddress; private EmbeddedChannels.EmbeddedEmptyCtorChannel mChannel; @Before public void before() throws Exception { mContext = PowerMockito.mock(FileSystemContext.class); mAddress = Mockito.mock(WorkerNetAddress.class); mChannel = new EmbeddedChannels.EmbeddedEmptyCtorChannel(); PowerMockito.when(mContext.acquireNettyChannel(mAddress)).thenReturn(mChannel); PowerMockito.doNothing().when(mContext).releaseNettyChannel(mAddress, mChannel); } @After public void after() throws Exception { mChannel.close(); } /** * Writes an empty file. */ @Test(timeout = 1000 * 60) public void writeEmptyFile() throws Exception { Future<Long> checksumActual; try (PacketWriter writer = create(10)) { checksumActual = verifyWriteRequests(mChannel, 0, 10); } Assert.assertEquals(0, checksumActual.get().longValue()); } /** * Writes a file with file length matches what is given and verifies the checksum of the whole * file. */ @Test(timeout = 1000 * 60) public void writeFullFile() throws Exception { Future<Long> checksumActual; Future<Long> checksumExpected; long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3; try (PacketWriter writer = create(length)) { checksumExpected = writeFile(writer, length, 0, length - 1); checksumActual = verifyWriteRequests(mChannel, 0, length - 1); checksumExpected.get(); } Assert.assertEquals(checksumExpected.get(), checksumActual.get()); } /** * Writes a file with file length matches what is given and verifies the checksum of the whole * file. */ @Test(timeout = 1000 * 60) public void writeFileChecksumOfPartialFile() throws Exception { Future<Long> checksumActual; Future<Long> checksumExpected; long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3; try (PacketWriter writer = create(length)) { checksumExpected = writeFile(writer, length, 10, length / 3); checksumActual = verifyWriteRequests(mChannel, 10, length / 3); checksumExpected.get(); } Assert.assertEquals(checksumExpected.get(), checksumActual.get()); } /** * Writes a file with unknown length. */ @Test(timeout = 1000 * 60) public void writeFileUnknownLength() throws Exception { Future<Long> checksumActual; Future<Long> checksumExpected; long length = PACKET_SIZE * 1024; try (PacketWriter writer = create(Long.MAX_VALUE)) { checksumExpected = writeFile(writer, length, 10, length / 3); checksumActual = verifyWriteRequests(mChannel, 10, length / 3); checksumExpected.get(); } Assert.assertEquals(checksumExpected.get(), checksumActual.get()); } /** * Writes lots of packets. */ @Test(timeout = 1000 * 60) public void writeFileManyPackets() throws Exception { Future<Long> checksumActual; Future<Long> checksumExpected; long length = PACKET_SIZE * 30000 + PACKET_SIZE / 3; try (PacketWriter writer = create(Long.MAX_VALUE)) { checksumExpected = writeFile(writer, length, 10, length / 3); checksumActual = verifyWriteRequests(mChannel, 10, length / 3); checksumExpected.get(); } Assert.assertEquals(checksumExpected.get(), checksumActual.get()); } /** * Creates a {@link PacketWriter}. * * @param length the length * @return the packet writer instance */ private PacketWriter create(long length) throws Exception { PacketWriter writer = new NettyPacketWriter(mContext, mAddress, BLOCK_ID, length, SESSION_ID, TIER, Protocol.RequestType.ALLUXIO_BLOCK, PACKET_SIZE); mChannel.finishChannelCreation(); return writer; } /** * Writes packets via the given packet writer and returns a checksum for a region of the data * written. * * @param length the length * @param start the start position to calculate the checksum * @param end the end position to calculate the checksum * @return the checksum */ private Future<Long> writeFile(final PacketWriter writer, final long length, final long start, final long end) throws Exception { return EXECUTOR.submit(new Callable<Long>() { @Override public Long call() throws IOException { try { long checksum = 0; long pos = 0; long remaining = length; while (remaining > 0) { int bytesToWrite = (int) Math.min(remaining, PACKET_SIZE); byte[] data = new byte[bytesToWrite]; RANDOM.nextBytes(data); ByteBuf buf = Unpooled.wrappedBuffer(data); try { writer.writePacket(buf); } catch (Exception e) { Assert.fail(e.getMessage()); throw e; } remaining -= bytesToWrite; for (int i = 0; i < data.length; i++) { if (pos >= start && pos <= end) { checksum += BufferUtils.byteToInt(data[i]); } pos++; } } return checksum; } catch (Throwable throwable) { LOG.error("Failed to write file.", throwable); Assert.fail(); throw throwable; } } }); } /** * Verifies the packets written. After receiving the last packet, it will also send an EOF to * the channel. * * @param checksumStart the start position to calculate the checksum * @param checksumEnd the end position to calculate the checksum * @return the checksum of the data read starting from checksumStart */ private Future<Long> verifyWriteRequests(final EmbeddedChannel channel, final long checksumStart, final long checksumEnd) { return EXECUTOR.submit(new Callable<Long>() { @Override public Long call() { try { long checksum = 0; long pos = 0; while (true) { RPCProtoMessage request = (RPCProtoMessage) CommonUtils .waitForResult("wrtie request", new Function<Void, Object>() { @Override public Object apply(Void v) { return channel.readOutbound(); } }, WaitForOptions.defaults().setTimeout(Constants.MINUTE_MS)); validateWriteRequest(request.getMessage().asWriteRequest(), pos); DataBuffer buffer = request.getPayloadDataBuffer(); // Last packet. if (buffer == null) { channel.writeInbound(RPCProtoMessage.createOkResponse(null)); return checksum; } try { Assert.assertTrue(buffer instanceof DataNettyBufferV2); ByteBuf buf = (ByteBuf) buffer.getNettyOutput(); while (buf.readableBytes() > 0) { if (pos >= checksumStart && pos <= checksumEnd) { checksum += BufferUtils.byteToInt(buf.readByte()); } else { buf.readByte(); } pos++; } } finally { buffer.release(); } } } catch (Throwable throwable) { LOG.error("Failed to verify write requests.", throwable); Assert.fail(); throw throwable; } } }); } /** * Validates the read request sent. * * @param request the request * @param offset the offset */ private void validateWriteRequest(Protocol.WriteRequest request, long offset) { Assert.assertEquals(BLOCK_ID, request.getId()); Assert.assertEquals(offset, request.getOffset()); } }