package org.dcache.pool.movers;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import diskCacheV111.util.ChecksumFactory;
import org.dcache.pool.repository.FileRepositoryChannel;
import org.dcache.pool.repository.RepositoryChannel;
import org.dcache.util.Checksum;
import org.dcache.util.ChecksumType;
import static com.google.common.collect.Lists.newArrayList;
import static org.dcache.util.ByteUnit.KiB;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link ChecksumChannel}.
 *
 * <p>Most tests write a fixed payload ({@link #data}) through a
 * {@code ChecksumChannel} that wraps a real temporary file and then compare
 * {@code getChecksum()} with the MD5 digest computed directly over the
 * payload. The remaining tests replace the underlying channel with a Mockito
 * mock to simulate incomplete writes and offsets beyond 4 GiB.
 */
public class ChecksumChannelTest {

    /** Channel under test; by default it wraps {@link #fileChannel}. */
    private ChecksumChannel chksumChannel;

    /**
     * Payload used by the tests. It deliberately contains runs of zero bytes
     * so that it can also be written as "non-zero blocks only".
     * The trailing \12 is octal for decimal 10, i.e. a line feed.
     */
    private final byte[] data = "\0Just\0A\0Short\0TestString\0To\0Verify\0\0Checksumming\0\0Works\12".getBytes(StandardCharsets.ISO_8859_1);

    /** MD5 checksum of {@link #data}, computed in {@link #setUp()}. */
    private Checksum expectedChecksum;

    /** Size in bytes of each entry of {@link #buffers}. */
    private final int blocksize = 2;

    /**
     * Number of complete blocks in {@link #data}. This is an integer division:
     * the payload length has to be a multiple of {@link #blocksize}, otherwise
     * the trailing bytes would be missing from {@link #buffers}.
     */
    private final int blockcount = data.length / blocksize;

    /** {@link #data} split into {@link #blockcount} consecutive blocks. */
    private final ByteBuffer[] buffers = new ByteBuffer[blockcount];

    /** Temporary file backing {@link #fileChannel}. */
    private Path testFile;

    /**
     * The real, file-backed channel opened in {@link #setUp()}. A reference is
     * kept so that it can still be closed after a test has replaced
     * {@code chksumChannel._channel} with a mock.
     */
    private RepositoryChannel fileChannel;

    /**
     * Creates the temporary file, the channel under test, the expected
     * checksum and fresh block buffers.
     */
    @Before
    public void setUp() throws NoSuchAlgorithmException, IOException {
        testFile = Files.createTempFile("ChecksumChannelTest", ".tmp");
        fileChannel = new FileRepositoryChannel(testFile, "rw");
        ChecksumFactory checksumFactory = ChecksumFactory.getFactory(ChecksumType.MD5_TYPE);
        chksumChannel = new ChecksumChannel(fileChannel, checksumFactory);
        // NOTE(review): the tiny buffers are presumably chosen so that reading
        // back and zero filling need several iterations - confirm against
        // ChecksumChannel.
        chksumChannel._readBackBuffer = ByteBuffer.allocate(2);
        chksumChannel._zerosBuffer = ByteBuffer.allocate(1);
        expectedChecksum = new Checksum(ChecksumType.MD5_TYPE, checksumFactory.create().digest(data));
        for (int b = 0; b < blockcount; b++) {
            int from = b * blocksize;
            buffers[b] = ByteBuffer.wrap(Arrays.copyOfRange(data, from, from + blocksize));
        }
    }

    /**
     * Closes the channel under test and removes the temporary file. The file
     * is deleted even if closing fails, and the real file channel is closed
     * explicitly when a test swapped it for a mock.
     */
    @After
    public void tearDown() throws IOException {
        try {
            try {
                chksumChannel.close();
            } finally {
                if (chksumChannel._channel != fileChannel) {
                    fileChannel.close();
                }
            }
        } finally {
            Files.delete(testFile);
        }
    }

    @Test
    public void shouldSucceedIfWrittenAtOnce() throws IOException {
        chksumChannel.write(ByteBuffer.wrap(data), 0);

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldReturnSameChecksumOnSecondCall() throws IOException {
        chksumChannel.write(ByteBuffer.wrap(data), 0);

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldSucceedIfWrittenInOrder() throws IOException {
        for (int block = 0; block < blockcount; block++) {
            chksumChannel.write(buffers[block], block * blocksize);
        }

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldSucceedIfWrittenOutOfOrderWithPosition() throws IOException {
        for (int block : getRandomPermutationOfBlockOrder()) {
            chksumChannel.write(buffers[block], block * blocksize);
        }

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldSucceedIfWrittenOutOfOrderWithSingleBuffer() throws IOException {
        for (int block : getRandomPermutationOfBlockOrder()) {
            chksumChannel.position(block * blocksize);
            chksumChannel.write(buffers[block]);
        }

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldSucceedIfWrittenInOrderWithMultipleBuffers() throws IOException {
        chksumChannel.write(buffers);

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldSucceedIfWrittenInOrderWithMultipleBuffersAndOffset() throws IOException {
        // Surround the real blocks with one padding entry on each side; only
        // the entries 1..blockcount are handed to the channel.
        ByteBuffer[] padded = new ByteBuffer[blockcount + 2];
        padded[0] = buffers[blockcount - 1];
        System.arraycopy(buffers, 0, padded, 1, blockcount);
        padded[blockcount + 1] = buffers[0];

        chksumChannel.write(padded, 1, blockcount);

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test
    public void shouldReturnNullDigestOnDoubleWrites() throws IOException {
        chksumChannel.write(buffers[0], 0);
        buffers[0].rewind();
        chksumChannel.write(buffers[0], 0);

        assertThat(chksumChannel.getChecksum(), equalTo(null));
    }

    @Test
    public void shouldReturnNullDigestOnPartlyOverlappingWrites() throws IOException {
        // With a block size of 1 the second write would not overlap the first.
        if (blocksize == 1) {
            fail("Pick a blocksize > 1 for testing correct handling of partly overlapping writes!");
        }

        chksumChannel.write(buffers[1], blocksize);
        chksumChannel.write(buffers[0], blocksize - 1);

        assertThat(chksumChannel.getChecksum(), equalTo(null));
    }

    @Test
    public void shouldNotUpdateChecksumForIncompleteWritesWithZeroByteWritesToChannelWithSingleBuffer () throws IOException, NoSuchAlgorithmException {
        ChecksumChannel csc = md5ChannelOver(mockChannelWritingBytesPerCall(0));

        // No explicit assertion: the test passes if the write does not throw.
        csc.write(buffers[0]);
    }

    @Test
    public void shouldNotUpdateChecksumForIncompleteWritesWithZeroByteWritesToChannelWithSingleBufferAndPosition () throws IOException, NoSuchAlgorithmException {
        ChecksumChannel csc = md5ChannelOver(mockChannelWritingBytesPerCall(0));

        // No explicit assertion: the test passes if the write does not throw.
        csc.write(buffers[0], 0);
    }

    @Test
    public void shouldUpdateChecksumOnlyForWrittenBytesOnIncompleteWritesWithSingleBufferAndPosition () throws IOException, NoSuchAlgorithmException {
        // Every write reports a single written byte although two were offered.
        RepositoryChannel channel = mockChannelWritingBytesPerCall(1);
        when(channel.read(any(), eq(3L))).thenReturn(1);
        ChecksumChannel csc = md5ChannelOver(channel);

        csc.write(buffers[0], 0);
        csc.write(buffers[1], 1);
        csc.write(buffers[3], 3);
        csc.write(buffers[2], 2);

        assertNotNull(csc.getChecksum());
    }

    @Test
    public void shouldNotUpdateChecksumForIncompleteWritesWithZeroBytesWritesToChannelWithMultipleBuffers () throws IOException, NoSuchAlgorithmException {
        ChecksumChannel csc = md5ChannelOver(mockChannelWritingBytesPerCall(0));

        // No explicit assertion: the test passes if the write does not throw.
        csc.write(buffers);
    }

    @Test
    public void shouldNotUpdateChecksumForIncompleteWritesWithZeroByteWritesToChannelWithMultipleBuffersAndOffset () throws IOException, NoSuchAlgorithmException {
        ChecksumChannel csc = md5ChannelOver(mockChannelWritingBytesPerCall(0));

        // No explicit assertion: the test passes if the write does not throw.
        csc.write(buffers, 0, blockcount);
    }

    @Test
    public void shouldUpdateChecksumSynchronizedForMultiThreadedWrites() throws IOException, InterruptedException {
        /* Writes one block at a fixed position and remembers any failure. */
        class Writer implements Runnable {
            private final ChecksumChannel channel;
            private final ByteBuffer block;
            private final int position;
            /* Read by the test thread only after join(). */
            private Exception failure;

            Writer(ChecksumChannel channel, ByteBuffer block, int position) {
                this.channel = channel;
                this.block = block;
                this.position = position;
            }

            @Override
            public void run() {
                try {
                    channel.write(block, position);
                } catch (IOException | RuntimeException e) {
                    failure = e;
                }
            }
        }

        List<Writer> tasks = newArrayList();
        List<Thread> threads = newArrayList();
        for (int block : getRandomPermutationOfBlockOrder()) {
            Writer task = new Writer(chksumChannel, buffers[block], block * blocksize);
            tasks.add(task);
            threads.add(new Thread(task));
        }

        threads.forEach(Thread::start);
        for (Thread thread : threads) {
            thread.join();
        }

        // Surface a failure of a writer thread instead of only printing it.
        for (Writer task : tasks) {
            if (task.failure != null) {
                throw new AssertionError("Concurrent write failed", task.failure);
            }
        }
        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    @Test(expected = IllegalStateException.class)
    public void shouldThrowIllegalStateExceptionOnWritesAfterGetChecksum() throws IOException {
        chksumChannel.getChecksum();
        chksumChannel.write(buffers[0], 0);
    }

    @Test
    public void shouldBeAbleToReadBackRangesOfSizeGreater4Gb() throws IOException {
        /*
            Because of the way the methods in this test are mocked,
            readBackCapacity has to be a divisor of writeBufferCapacity.
        */
        int readBackCapacity = KiB.toBytes(256);
        int writeBufferCapacity = 4 * readBackCapacity;
        ByteBuffer writeBuffer = ByteBuffer.allocate(writeBufferCapacity);
        chksumChannel._readBackBuffer = ByteBuffer.allocate(readBackCapacity);
        chksumChannel._channel = mock(FileRepositoryChannel.class);
        when(chksumChannel._channel.write(any(), anyLong())).thenReturn(writeBufferCapacity);
        when(chksumChannel._channel.read(any(), longThat(lessThan(0L)))).thenThrow(new IllegalArgumentException("Negative Position"));
        when(chksumChannel._channel.read(any(), longThat(greaterThanOrEqualTo(0L)))).thenReturn(readBackCapacity);

        // Write everything except the first chunk, covering offsets up to
        // 2 * Integer.MAX_VALUE ...
        for (long i = writeBuffer.capacity(); i < 2L*Integer.MAX_VALUE; i += writeBufferCapacity) {
            chksumChannel.write(writeBuffer, i);
            writeBuffer.rewind();
        }
        // ... and only then the chunk at offset 0.
        chksumChannel.write(writeBuffer, 0);

        assertThat(chksumChannel.getChecksum(), notNullValue());
    }

    @Test
    public void shouldBeAbleToFillZeroRangesOfSizeGreater4Gb() throws IOException {
        chksumChannel._zerosBuffer = ByteBuffer.allocate(KiB.toBytes(256));
        chksumChannel._channel = mock(FileRepositoryChannel.class);
        when(chksumChannel._channel.write(any(), anyLong())).thenReturn(buffers[0].capacity());
        when(chksumChannel._channel.read(any(), anyLong())).thenReturn(2);
        when(chksumChannel._channel.size()).thenReturn(2L*Integer.MAX_VALUE + 2);

        // A single block at an offset of 2 * Integer.MAX_VALUE leaves a gap of
        // that size in front of it.
        chksumChannel.write(buffers[0], 2L*Integer.MAX_VALUE);

        assertThat(chksumChannel.getChecksum(), notNullValue());
    }

    @Test
    public void shouldFillUpRangeGapsWithZerosOnGetChecksum() throws IOException {
        for (Map.Entry<Long, ByteBuffer> block : getNonZeroBlocksFromByteArray(data).entrySet()) {
            chksumChannel.write(block.getValue(), block.getKey());
        }

        assertThat(chksumChannel.getChecksum(), equalTo(expectedChecksum));
    }

    /**
     * Creates a mocked {@link RepositoryChannel} whose positional
     * {@code write} reports the given number of written bytes for every call.
     *
     * @param bytesPerCall value returned by each positional write
     * @return the stubbed mock
     */
    private RepositoryChannel mockChannelWritingBytesPerCall(int bytesPerCall) throws IOException {
        RepositoryChannel channel = mock(RepositoryChannel.class);
        // The position parameter is a long, hence anyLong() rather than anyInt().
        when(channel.write(any(), anyLong())).thenReturn(bytesPerCall);
        return channel;
    }

    /**
     * Wraps the given channel in a new MD5 {@link ChecksumChannel}.
     *
     * @param channel the channel to wrap
     * @return a fresh checksum channel on top of {@code channel}
     */
    private ChecksumChannel md5ChannelOver(RepositoryChannel channel) throws NoSuchAlgorithmException, IOException {
        return new ChecksumChannel(channel, ChecksumFactory.getFactory(ChecksumType.MD5_TYPE));
    }

    /**
     * Splits a byte array into its maximal runs of non-zero bytes.
     *
     * @param bytes the array to split
     * @return the runs, keyed by their start offset in {@code bytes} and
     *         ordered by that offset
     */
    private Map<Long, ByteBuffer> getNonZeroBlocksFromByteArray(byte[] bytes) {
        Map<Long, ByteBuffer> result = new TreeMap<>();
        int position = 0;
        while (position < bytes.length) {
            // Compare against zero rather than "> 0": bytes are signed, so
            // values of 0x80 and above are negative but still non-zero.
            if (bytes[position] == 0) {
                position++;
                continue;
            }
            int blockEnd = position;
            while (blockEnd < bytes.length && bytes[blockEnd] != 0) {
                blockEnd++;
            }
            result.put((long) position, ByteBuffer.wrap(Arrays.copyOfRange(bytes, position, blockEnd)));
            position = blockEnd;
        }
        return result;
    }

    /**
     * Returns the block indices {@code 0..blockcount-1} in random order.
     * The shuffle is not seeded, so the order differs between runs.
     *
     * @return a random permutation of all block indices
     */
    private int[] getRandomPermutationOfBlockOrder() {
        List<Integer> blockNumbers = newArrayList();
        for (int i = 0; i < blockcount; i++) {
            blockNumbers.add(i);
        }
        Collections.shuffle(blockNumbers);

        int[] result = new int[blockcount];
        for (int i = 0; i < blockcount; i++) {
            result[i] = blockNumbers.get(i);
        }
        return result;
    }
}