package org.dcache.pool.movers;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.DiscreteDomain;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import javax.annotation.concurrent.GuardedBy;

import diskCacheV111.util.ChecksumFactory;

import org.dcache.pool.repository.RepositoryChannel;
import org.dcache.util.Checksum;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Lists.newArrayList;
import static org.dcache.util.ByteUnit.KiB;

/**
 * A wrapper for RepositoryChannel that computes a digest on the fly during
 * write as long as all writes are sequential.
 *
 * <p>Writes are tracked as ranges in {@link #_dataRangeSet}. As long as the
 * data written so far forms a single contiguous region starting at offset 0,
 * the digest is kept up to date incrementally; out-of-order writes that later
 * close a gap are folded in by reading the data back from the inner channel.
 * Overlapping writes, {@link #transferFrom}, or an interrupted digest update
 * make the on-the-fly checksum non-viable, in which case {@link #getChecksum}
 * returns {@code null}.
 */
public class ChecksumChannel implements RepositoryChannel
{
    private static final Logger _log =
            LoggerFactory.getLogger(ChecksumChannel.class);

    /**
     * Inner channel to which all operations are delegated.
     */
    @VisibleForTesting
    RepositoryChannel _channel;

    /**
     * Factory object for creating digests.
     */
    private final ChecksumFactory _checksumFactory;

    /**
     * Digest used for computing the checksum during write.
     */
    private final MessageDigest _digest;

    /**
     * Cached checksum after getChecksum is called the first time.
     * Guarded by synchronization on {@code this} (see getChecksum).
     */
    private Checksum _finalChecksum;

    /**
     * RangeSet to keep track of written bytes.
     */
    private final RangeSet<Long> _dataRangeSet = TreeRangeSet.create();

    /**
     * The offset up to which the checksum has been calculated.
     */
    @GuardedBy("_digest")
    private long _nextChecksumOffset = 0L;

    /**
     * Flag to indicate whether it is still possible to calculate a checksum.
     */
    private volatile boolean _isChecksumViable = true;

    /**
     * Flag to indicate whether we still allow writing to the channel.
     * This flag is set to false after getChecksum has been called.
     */
    @GuardedBy("_ioStateLock")
    private boolean _isWritable = true;

    /**
     * Lock to protect the _isWritable field.
     */
    private final ReentrantReadWriteLock _ioStateLock = new ReentrantReadWriteLock();
    private final Lock _ioStateReadLock = _ioStateLock.readLock();
    private final Lock _ioStateWriteLock = _ioStateLock.writeLock();

    /**
     * Buffer to be used for reading data back from the inner channel for
     * checksum calculations.
     */
    @VisibleForTesting
    ByteBuffer _readBackBuffer = ByteBuffer.allocate(KiB.toBytes(256));

    /*
     * Static buffer with zeros shared among all instances of ChecksumChannel.
     */
    private static final ByteBuffer ZERO_BUFFER = ByteBuffer
            .allocate(KiB.toBytes(256))
            .asReadOnlyBuffer();

    /**
     * Buffer to be used for feeding the checksum digester with 0s to fill up
     * gaps in ranges.
     */
    @VisibleForTesting
    ByteBuffer _zerosBuffer = ZERO_BUFFER.duplicate();

    public ChecksumChannel(RepositoryChannel inner, ChecksumFactory checksumFactory)
    {
        _channel = inner;
        _checksumFactory = checksumFactory;
        _digest = _checksumFactory.create();
    }

    @Override
    public long position() throws IOException
    {
        return _channel.position();
    }

    @Override
    public SeekableByteChannel position(long position) throws IOException
    {
        return _channel.position(position);
    }

    @Override
    public long size() throws IOException
    {
        return _channel.size();
    }

    @Override
    public int write(ByteBuffer buffer, long position) throws IOException
    {
        _ioStateReadLock.lock();
        try {
            checkState(_isWritable, "ChecksumChannel must not be written to after getChecksum");

            int bytes;
            if (_isChecksumViable) {
                // Snapshot the buffer before the write consumes it; the
                // read-only view is fed to the digest afterwards.
                ByteBuffer readOnly = buffer.asReadOnlyBuffer();
                bytes = _channel.write(buffer, position);
                updateChecksum(readOnly, position, bytes);
            } else {
                bytes = _channel.write(buffer, position);
            }
            return bytes;
        } finally {
            _ioStateReadLock.unlock();
        }
    }

    @Override
    public int read(ByteBuffer buffer, long position) throws IOException
    {
        return _channel.read(buffer, position);
    }

    @Override
    public SeekableByteChannel truncate(long size) throws IOException
    {
        return _channel.truncate(size);
    }

    @Override
    public void sync() throws IOException
    {
        _channel.sync();
    }

    @Override
    public long transferTo(long position, long count, WritableByteChannel target)
            throws IOException
    {
        return _channel.transferTo(position, count, target);
    }

    @Override
    public long transferFrom(ReadableByteChannel src, long position, long count)
            throws IOException
    {
        /* The transferred bytes never pass through this class, so the digest
         * cannot track them; give up on the on-the-fly checksum.
         *
         * NOTE(review): unlike the write methods, this path takes no
         * _ioStateReadLock and does not check _isWritable — confirm whether
         * callers can race this with getChecksum.
         */
        _isChecksumViable = false;
        return _channel.transferFrom(src, position, count);
    }

    @Override
    public int write(ByteBuffer src) throws IOException
    {
        _ioStateReadLock.lock();
        try {
            checkState(_isWritable, "ChecksumChannel must not be written to after getChecksum");

            int bytes;
            if (_isChecksumViable) {
                bytes = writeWithChecksumUpdate(src);
            } else {
                bytes = _channel.write(src);
            }
            return bytes;
        } finally {
            _ioStateReadLock.unlock();
        }
    }

    @Override
    public long write(ByteBuffer[] srcs, int offset, int length) throws IOException
    {
        _ioStateReadLock.lock();
        try {
            checkState(_isWritable, "ChecksumChannel must not be written to after getChecksum");

            long bytes = 0;
            if (_isChecksumViable) {
                for (int i = offset; i < offset + length; i++) {
                    bytes += writeWithChecksumUpdate(srcs[i]);
                }
            } else {
                bytes = _channel.write(srcs, offset, length);
            }
            return bytes;
        } finally {
            _ioStateReadLock.unlock();
        }
    }

    @Override
    public synchronized long write(ByteBuffer[] srcs) throws IOException
    {
        return write(srcs, 0, srcs.length);
    }

    @Override
    public boolean isOpen()
    {
        return _channel.isOpen();
    }

    @Override
    public void close() throws IOException
    {
        _channel.close();
    }

    @Override
    public long read(ByteBuffer[] dsts, int offset, int length) throws IOException
    {
        return _channel.read(dsts, offset, length);
    }

    @Override
    public long read(ByteBuffer[] dsts) throws IOException
    {
        return _channel.read(dsts);
    }

    @Override
    public int read(ByteBuffer dst) throws IOException
    {
        return _channel.read(dst);
    }

    /**
     * Returns the final checksum of this channel, computing and caching it
     * on first call. Synchronized so that concurrent callers do not race on
     * the {@link #_finalChecksum} cache (the field is neither volatile nor
     * otherwise guarded).
     *
     * @return final checksum of this channel, or null if the on-the-fly
     *         checksum is no longer viable
     */
    public synchronized Checksum getChecksum()
    {
        if (!_isChecksumViable) {
            return null;
        }

        if (_finalChecksum == null) {
            _finalChecksum = finalizeChecksum();
        }
        return _finalChecksum;
    }

    /**
     * Returns the computed digest or null if overlapping writes have been
     * detected. Closes the channel for further writes and fills any remaining
     * gaps in the written ranges with zeros before finalizing the digest.
     *
     * @return Checksum
     */
    private Checksum finalizeChecksum()
    {
        _ioStateWriteLock.lock();
        try {
            _isWritable = false;
        } finally {
            _ioStateWriteLock.unlock();
        }

        // We need to synchronize on the range set and the digest to get
        // exclusive access to both.
        synchronized (_dataRangeSet) {
            synchronized (_digest) {
                try {
                    // More than one range, or a digest that never advanced,
                    // means the file has gaps that must be digested as zeros.
                    if (_dataRangeSet.asRanges().size() != 1 || _nextChecksumOffset == 0) {
                        feedZerosToDigesterForRangeGaps();
                    }
                    return _checksumFactory.create(_digest.digest());
                } catch (IOException e) {
                    _log.info("Unable to generate checksum of sparse file: {}", e.toString());
                    return null;
                }
            }
        }
    }

    /**
     * Feeds zeros to the digest for every gap between the written ranges,
     * processing the gaps in ascending file order.
     */
    private void feedZerosToDigesterForRangeGaps() throws IOException
    {
        ArrayList<Range<Long>> complement = newArrayList(
                _dataRangeSet.complement().subRangeSet(Range.closed(0L, size())).asRanges());
        complement.sort((r1, r2) -> r1.lowerEndpoint().compareTo(r2.lowerEndpoint()));

        for (Range<Long> range : complement) {
            long bytesToWrite = range.upperEndpoint() - range.lowerEndpoint();
            long chunkOffset = range.lowerEndpoint();
            while (bytesToWrite > 0) {
                _zerosBuffer.clear();
                long chunkSize = Math.min(_zerosBuffer.capacity(), bytesToWrite);
                _zerosBuffer.limit((int) chunkSize);
                updateChecksum(_zerosBuffer, chunkOffset, _zerosBuffer.limit());
                chunkOffset += chunkSize;
                bytesToWrite -= chunkSize;
            }
        }
    }

    /**
     * Writes the buffer at the channel's current position and feeds the
     * written bytes to the digest.
     */
    private int writeWithChecksumUpdate(ByteBuffer src) throws IOException
    {
        int writtenBytes;
        ByteBuffer readOnly = src.asReadOnlyBuffer();
        long updatePosition = position();
        writtenBytes = _channel.write(src);
        updateChecksum(readOnly, updatePosition, writtenBytes);
        return writtenBytes;
    }

    /**
     * Records the written range and, if it extends the contiguous block that
     * starts at offset 0, feeds the data to the digest. Bytes that belong to
     * the contiguous block but were written earlier (out of order) are read
     * back from the inner channel.
     *
     * @param buffer buffer containing the data
     * @param position position of the data in the target file
     * @param bytes number of bytes to use from the input data
     * @throws IOException
     */
    @VisibleForTesting
    void updateChecksum(ByteBuffer buffer, long position, int bytes) throws IOException
    {
        if (bytes == 0) {
            return;
        }

        if (bytes < buffer.remaining()) {
            buffer.limit(buffer.position() + bytes);
        }

        Range<Long> writeRange =
                Range.closed(position, position + buffer.remaining() - 1)
                        .canonical(DiscreteDomain.longs());
        Range<Long> fileStartRange;
        synchronized (_dataRangeSet) {
            RangeSet<Long> overlappingRanges = _dataRangeSet.subRangeSet(writeRange);
            if (!overlappingRanges.isEmpty()) {
                _isChecksumViable = false;
                _log.info("On-transfer checksum aborted due to overlapping writes from client.");
                return;
            }

            fileStartRange = _dataRangeSet.rangeContaining(0L);
            boolean canCalculateChecksum =
                    position == 0
                            || (fileStartRange != null
                                        && fileStartRange.upperEndpoint() == position);

            _dataRangeSet.add(writeRange);

            if (!canCalculateChecksum) {
                return;
            }

            // Get it again as we may have merged two segments.
            fileStartRange = _dataRangeSet.rangeContaining(0L);
        }

        synchronized (_digest) {
            /*
             * We are one of the threads whose write merged into the contiguous
             * block. Nevertheless, there may be a different thread which needs
             * to update ahead of us. Wait for our turn.
             */
            while (_nextChecksumOffset != position) {
                try {
                    _digest.wait();
                } catch (InterruptedException e) {
                    /* Restore the interrupt status (wait() cleared it) and
                     * give up on the checksum: our range is already recorded
                     * in _dataRangeSet but the digest never consumed these
                     * bytes, so any later finalizeChecksum would be wrong.
                     */
                    Thread.currentThread().interrupt();
                    _isChecksumViable = false;
                    throw new InterruptedIOException();
                }
            }

            long bytesToRead = fileStartRange.upperEndpoint() - position;

            // Update with the current buffer, then keep processing the
            // following blocks, if any.
            bytesToRead -= buffer.remaining();

            // Update the offset prior to digest calculation as
            // MessageDigest#update will advance the buffer's position.
            _nextChecksumOffset += buffer.remaining();
            _digest.update(buffer);

            long expectedOffsetAfterRead = _nextChecksumOffset + bytesToRead;
            try {
                while (bytesToRead > 0) {
                    _readBackBuffer.clear();
                    long limit = Math.min(_readBackBuffer.capacity(), bytesToRead);
                    _readBackBuffer.limit((int) limit);
                    int lastBytesRead = _channel.read(_readBackBuffer, _nextChecksumOffset);
                    if (lastBytesRead < 0) {
                        throw new IOException("Checksum: Unexpectedly hit end-of-stream while reading data back from channel.");
                    }
                    _readBackBuffer.flip();
                    _digest.update(_readBackBuffer);
                    bytesToRead -= lastBytesRead;
                    _nextChecksumOffset += lastBytesRead;
                }
            } catch (IOException | RuntimeException e) {
                _isChecksumViable = false;
                throw e;
            } finally {
                _nextChecksumOffset = expectedOffsetAfterRead;
                _digest.notifyAll();
            }
        }
    }
}