/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.cassandra.io.util; import java.io.Closeable; import java.io.File; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.io.RandomAccessFile; import java.util.concurrent.atomic.AtomicBoolean; import static org.apache.cassandra.utils.Throwables.maybeFail; import static org.apache.cassandra.utils.Throwables.merge; /** * Adds mark/reset functionality to another input stream by caching read bytes to a memory buffer and * spilling to disk if necessary. * * When the stream is marked via {@link #mark()} or {@link #mark(int)}, up to * <code>maxMemBufferSize</code> will be cached in memory (heap). If more than * <code>maxMemBufferSize</code> bytes are read while the stream is marked, the * following bytes are cached on the <code>spillFile</code> for up to <code>maxDiskBufferSize</code>. * * Please note that successive calls to {@link #mark()} and {@link #reset()} will write * sequentially to the same <code>spillFile</code> until <code>maxDiskBufferSize</code> is reached. * At this point, if less than <code>maxDiskBufferSize</code> bytes are currently cached on the * <code>spillFile</code>, the remaining bytes are written to the beginning of the file, * treating the <code>spillFile</code> as a circular buffer. * * If more than <code>maxMemBufferSize + maxDiskBufferSize</code> are cached while the stream is marked, * the following {@link #reset()} invocation will throw a {@link IllegalStateException}. * */ public class RewindableDataInputStreamPlus extends FilterInputStream implements RewindableDataInput, Closeable { private boolean marked = false; private boolean exhausted = false; private AtomicBoolean closed = new AtomicBoolean(false); protected int memAvailable = 0; protected int diskTailAvailable = 0; protected int diskHeadAvailable = 0; private final File spillFile; private final int initialMemBufferSize; private final int maxMemBufferSize; private final int maxDiskBufferSize; private volatile byte memBuffer[]; private int memBufferSize; private RandomAccessFile spillBuffer; private final DataInputPlus dataReader; public RewindableDataInputStreamPlus(InputStream in, int initialMemBufferSize, int maxMemBufferSize, File spillFile, int maxDiskBufferSize) { super(in); dataReader = new DataInputStreamPlus(this); this.initialMemBufferSize = initialMemBufferSize; this.maxMemBufferSize = maxMemBufferSize; this.spillFile = spillFile; this.maxDiskBufferSize = maxDiskBufferSize; } /* RewindableDataInput methods */ /** * Marks the current position of a stream to return to this position later via the {@link #reset(DataPosition)} method. * @return An empty @link{DataPosition} object */ public DataPosition mark() { mark(0); return new RewindableDataInputPlusMark(); } /** * Rewinds to the previously marked position via the {@link #mark()} method. * @param mark it's not possible to return to a custom position, so this parameter is ignored. * @throws IOException if an error ocurs while resetting */ public void reset(DataPosition mark) throws IOException { reset(); } public long bytesPastMark(DataPosition mark) { return maxMemBufferSize - memAvailable + (diskTailAvailable == -1? 0 : maxDiskBufferSize - diskHeadAvailable - diskTailAvailable); } protected static class RewindableDataInputPlusMark implements DataPosition { } /* InputStream methods */ public boolean markSupported() { return true; } /** * Marks the current position of a stream to return to this position * later via the {@link #reset()} method. * @param readlimit the maximum amount of bytes to cache */ public synchronized void mark(int readlimit) { if (marked) throw new IllegalStateException("Cannot mark already marked stream."); if (memAvailable > 0 || diskHeadAvailable > 0 || diskTailAvailable > 0) throw new IllegalStateException("Can only mark stream after reading previously marked data."); marked = true; memAvailable = maxMemBufferSize; diskHeadAvailable = -1; diskTailAvailable = -1; } public synchronized void reset() throws IOException { if (!marked) throw new IOException("Must call mark() before calling reset()."); if (exhausted) throw new IOException(String.format("Read more than capacity: %d bytes.", maxMemBufferSize + maxDiskBufferSize)); memAvailable = maxMemBufferSize - memAvailable; memBufferSize = memAvailable; if (diskTailAvailable == -1) { diskHeadAvailable = 0; diskTailAvailable = 0; } else { int initialPos = diskTailAvailable > 0 ? 0 : (int)getIfNotClosed(spillBuffer).getFilePointer(); int diskMarkpos = initialPos + diskHeadAvailable; getIfNotClosed(spillBuffer).seek(diskMarkpos); diskHeadAvailable = diskMarkpos - diskHeadAvailable; diskTailAvailable = (maxDiskBufferSize - diskTailAvailable) - diskMarkpos; } marked = false; } public int available() throws IOException { return super.available() + (marked? 0 : memAvailable + diskHeadAvailable + diskTailAvailable); } public int read() throws IOException { int read = readOne(); if (read == -1) return read; if (marked) { //mark exhausted if (isExhausted(1)) { exhausted = true; return read; } writeOne(read); } return read; } public int read(byte[] b, int off, int len) throws IOException { int readBytes = readMulti(b, off, len); if (readBytes == -1) return readBytes; if (marked) { //check we have space on buffer if (isExhausted(readBytes)) { exhausted = true; return readBytes; } writeMulti(b, off, readBytes); } return readBytes; } private void maybeCreateDiskBuffer() throws IOException { if (spillBuffer == null) { if (!spillFile.getParentFile().exists()) spillFile.getParentFile().mkdirs(); spillFile.createNewFile(); this.spillBuffer = new RandomAccessFile(spillFile, "rw"); } } private int readOne() throws IOException { if (!marked) { if (memAvailable > 0) { int pos = memBufferSize - memAvailable; memAvailable--; return getIfNotClosed(memBuffer)[pos] & 0xff; } if (diskTailAvailable > 0 || diskHeadAvailable > 0) { int read = getIfNotClosed(spillBuffer).read(); if (diskTailAvailable > 0) diskTailAvailable--; else if (diskHeadAvailable > 0) diskHeadAvailable++; if (diskTailAvailable == 0) spillBuffer.seek(0); return read; } } return getIfNotClosed(in).read(); } private boolean isExhausted(int readBytes) { return exhausted || readBytes > memAvailable + (long)(diskTailAvailable == -1? maxDiskBufferSize : diskTailAvailable + diskHeadAvailable); } private int readMulti(byte[] b, int off, int len) throws IOException { int readBytes = 0; if (!marked) { if (memAvailable > 0) { readBytes += memAvailable < len ? memAvailable : len; int pos = memBufferSize - memAvailable; System.arraycopy(memBuffer, pos, b, off, readBytes); memAvailable -= readBytes; off += readBytes; len -= readBytes; } if (len > 0 && diskTailAvailable > 0) { int readFromTail = diskTailAvailable < len? diskTailAvailable : len; readFromTail = getIfNotClosed(spillBuffer).read(b, off, readFromTail); readBytes += readFromTail; diskTailAvailable -= readFromTail; off += readFromTail; len -= readFromTail; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (len > 0 && diskHeadAvailable > 0) { int readFromHead = diskHeadAvailable < len? diskHeadAvailable : len; readFromHead = getIfNotClosed(spillBuffer).read(b, off, readFromHead); readBytes += readFromHead; diskHeadAvailable -= readFromHead; off += readFromHead; len -= readFromHead; } } if (len > 0) readBytes += getIfNotClosed(in).read(b, off, len); return readBytes; } private void writeMulti(byte[] b, int off, int len) throws IOException { if (memAvailable > 0) { if (memBuffer == null) memBuffer = new byte[initialMemBufferSize]; int pos = maxMemBufferSize - memAvailable; int memWritten = memAvailable < len? memAvailable : len; if (pos + memWritten >= getIfNotClosed(memBuffer).length) growMemBuffer(pos, memWritten); System.arraycopy(b, off, memBuffer, pos, memWritten); off += memWritten; len -= memWritten; memAvailable -= memWritten; } if (len > 0) { if (diskTailAvailable == -1) { maybeCreateDiskBuffer(); diskHeadAvailable = (int)spillBuffer.getFilePointer(); diskTailAvailable = maxDiskBufferSize - diskHeadAvailable; } if (len > 0 && diskTailAvailable > 0) { int diskTailWritten = diskTailAvailable < len? diskTailAvailable : len; getIfNotClosed(spillBuffer).write(b, off, diskTailWritten); off += diskTailWritten; len -= diskTailWritten; diskTailAvailable -= diskTailWritten; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (len > 0 && diskTailAvailable > 0) { int diskHeadWritten = diskHeadAvailable < len? diskHeadAvailable : len; getIfNotClosed(spillBuffer).write(b, off, diskHeadWritten); } } } private void writeOne(int value) throws IOException { if (memAvailable > 0) { if (memBuffer == null) memBuffer = new byte[initialMemBufferSize]; int pos = maxMemBufferSize - memAvailable; if (pos == getIfNotClosed(memBuffer).length) growMemBuffer(pos, 1); getIfNotClosed(memBuffer)[pos] = (byte)value; memAvailable--; return; } if (diskTailAvailable == -1) { maybeCreateDiskBuffer(); diskHeadAvailable = (int)spillBuffer.getFilePointer(); diskTailAvailable = maxDiskBufferSize - diskHeadAvailable; } if (diskTailAvailable > 0 || diskHeadAvailable > 0) { getIfNotClosed(spillBuffer).write(value); if (diskTailAvailable > 0) diskTailAvailable--; else if (diskHeadAvailable > 0) diskHeadAvailable--; if (diskTailAvailable == 0) spillBuffer.seek(0); return; } } public int read(byte[] b) throws IOException { return read(b, 0, b.length); } private void growMemBuffer(int pos, int writeSize) { int newSize = Math.min(2 * (pos + writeSize), maxMemBufferSize); byte newBuffer[] = new byte[newSize]; System.arraycopy(memBuffer, 0, newBuffer, 0, pos); memBuffer = newBuffer; } public long skip(long n) throws IOException { long skipped = 0; if (marked) { //if marked, we need to cache skipped bytes while (n-- > 0 && read() != -1) { skipped++; } return skipped; } if (memAvailable > 0) { skipped += memAvailable < n ? memAvailable : n; memAvailable -= skipped; n -= skipped; } if (n > 0 && diskTailAvailable > 0) { int skipFromTail = diskTailAvailable < n? diskTailAvailable : (int)n; getIfNotClosed(spillBuffer).skipBytes(skipFromTail); diskTailAvailable -= skipFromTail; skipped += skipFromTail; n -= skipFromTail; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (n > 0 && diskHeadAvailable > 0) { int skipFromHead = diskHeadAvailable < n? diskHeadAvailable : (int)n; getIfNotClosed(spillBuffer).skipBytes(skipFromHead); diskHeadAvailable -= skipFromHead; skipped += skipFromHead; n -= skipFromHead; } if (n > 0) skipped += getIfNotClosed(in).skip(n); return skipped; } private <T> T getIfNotClosed(T in) throws IOException { if (closed.get()) throw new IOException("Stream closed"); return in; } public void close() throws IOException { close(true); } public void close(boolean closeUnderlying) throws IOException { if (closed.compareAndSet(false, true)) { Throwable fail = null; if (closeUnderlying) { try { super.close(); } catch (IOException e) { fail = merge(fail, e); } } try { if (spillBuffer != null) { this.spillBuffer.close(); this.spillBuffer = null; } } catch (IOException e) { fail = merge(fail, e); } try { if (spillFile.exists()) { spillFile.delete(); } } catch (Throwable e) { fail = merge(fail, e); } maybeFail(fail, IOException.class); } } /* DataInputPlus methods */ public void readFully(byte[] b) throws IOException { dataReader.readFully(b); } public void readFully(byte[] b, int off, int len) throws IOException { dataReader.readFully(b, off, len); } public int skipBytes(int n) throws IOException { return dataReader.skipBytes(n); } public boolean readBoolean() throws IOException { return dataReader.readBoolean(); } public byte readByte() throws IOException { return dataReader.readByte(); } public int readUnsignedByte() throws IOException { return dataReader.readUnsignedByte(); } public short readShort() throws IOException { return dataReader.readShort(); } public int readUnsignedShort() throws IOException { return dataReader.readUnsignedShort(); } public char readChar() throws IOException { return dataReader.readChar(); } public int readInt() throws IOException { return dataReader.readInt(); } public long readLong() throws IOException { return dataReader.readLong(); } public float readFloat() throws IOException { return dataReader.readFloat(); } public double readDouble() throws IOException { return dataReader.readDouble(); } public String readLine() throws IOException { return dataReader.readLine(); } public String readUTF() throws IOException { return dataReader.readUTF(); } }