/** * Copyright 2015 Palantir Technologies, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.palantir.giraffe.internal; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkPositionIndex; import java.io.IOException; import java.io.InputStream; import java.io.InterruptedIOException; import java.io.OutputStream; import java.util.EnumSet; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import com.google.common.annotations.VisibleForTesting; /** * Provides output and input streams that read and write to the same array. * Automatically increases size as needed. * * @author jchien * @author bkeyes */ @ThreadSafe final class SharedByteArrayStream { // from JDK ArrayList implementation private static final int MAX_BUFFER_SIZE = Integer.MAX_VALUE - 8; private static final int DEFAULT_BUFFER_SIZE = 1024; private enum Mode { READ, WRITE } private final Object lock = new Object(); private final SharedOutputStream outputStream; private final SharedInputStream inputStream; private final int windowSize; // startPosition is the first byte of the window // readPosition is the first byte unread byte // writePosition is the first empty element // empty when readPosition == writePosition // full when (writePosition + 1) % size == startPosition // if READ is in modes, the input stream is open // if WRITE is in modes, the output stream is open @GuardedBy("lock") private final EnumSet<Mode> modes = EnumSet.of(Mode.READ, Mode.WRITE); @GuardedBy("lock") private final byte[] oneByte = new byte[1]; @GuardedBy("lock") private byte[] buffer; @GuardedBy("lock") private int startPosition = 0; @GuardedBy("lock") private int readPosition = 0; @GuardedBy("lock") private int writePosition = 0; public SharedByteArrayStream() { this(Integer.MAX_VALUE); } public SharedByteArrayStream(int windowSize) { this(windowSize, DEFAULT_BUFFER_SIZE); } @VisibleForTesting SharedByteArrayStream(int windowSize, int bufferSize) { checkArgument(windowSize >= 0, "windowSize must be non-negative"); checkArgument(bufferSize > 0, "bufferSize must be positive"); this.windowSize = windowSize; this.buffer = new byte[bufferSize]; outputStream = new SharedOutputStream(); inputStream = new SharedInputStream(); } final class SharedInputStream extends InputStream { private SharedInputStream() {} @Override public int available() throws IOException { synchronized (lock) { checkOpen(Mode.READ); return readSize(); } } @Override public int read() throws IOException { synchronized (lock) { checkOpen(Mode.READ); if (read(oneByte, 0, 1) == -1) { return -1; } else { return oneByte[0]; } } } @Override public int read(byte[] b, int off, int len) throws IOException { checkArray(b, off, len); synchronized (lock) { checkOpen(Mode.READ); if (len == 0) { return 0; } if (isMode(Mode.WRITE)) { waitForInput(); } int total = readSize(); if (total == 0) { return -1; } else { return readFromBuffer(b, off, Math.min(len, total)); } } } @GuardedBy("lock") private void waitForInput() throws IOException { while (readSize() == 0 && isMode(Mode.READ) && isMode(Mode.WRITE)) { try { lock.wait(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new InterruptedIOException(); } } } @Override public void close() { synchronized (lock) { modes.remove(Mode.READ); lock.notifyAll(); } } } final class SharedOutputStream extends OutputStream { private SharedOutputStream() {} @Override public void write(int b) throws IOException { synchronized (lock) { oneByte[0] = (byte) b; write(oneByte, 0, 1); } } @Override public void write(byte[] b, int off, int len) throws IOException { checkArray(b, off, len); synchronized (lock) { checkOpen(Mode.WRITE); if (len == 0 || windowSize == 0) { return; } // truncate writes that exceed the window size int length = len; int offset = off; if (len > windowSize) { offset = off + len - windowSize; length = windowSize; } if (length > writeSize()) { makeSpace(length); } writeToBuffer(b, offset, length); // discard data now outside of the window int bufferedSize = bufferedSize(); if (bufferedSize > windowSize) { advanceStartPosition(bufferedSize - windowSize); } lock.notifyAll(); } } @Override public void close() { synchronized (lock) { modes.remove(Mode.WRITE); lock.notifyAll(); } } } @GuardedBy("lock") private void checkOpen(Mode requiredMode) throws IOException { if (!isMode(requiredMode)) { throw new IOException("Stream is closed"); } } @GuardedBy("lock") private boolean isMode(Mode mode) { return modes.contains(mode); } /** * Total amount of data currently in the buffer. */ @GuardedBy("lock") private int bufferedSize() { int count = writePosition - startPosition; return (count < 0) ? count + buffer.length : count; } /** * Amount of unread data in the buffer. */ @GuardedBy("lock") private int readSize() { int count = writePosition - readPosition; return (count < 0) ? count + buffer.length : count; } /** * Amount of space available for writing in the buffer. */ @GuardedBy("lock") private int writeSize() { int count = startPosition - (writePosition + 1); return (count < 0) ? count + buffer.length : count; } @GuardedBy("lock") private int readFromBuffer(byte[] b, int off, int len) { assert len <= readSize() : "len (" + len + ") > size (" + readSize() + ")"; copyOut(readPosition, b, off, len); advanceReadPosition(len); return len; } @GuardedBy("lock") private void writeToBuffer(byte[] b, int off, int len) { assert len <= writeSize() : "len (" + len + ") > size (" + writeSize() + ")"; copyIn(writePosition, b, off, len); advanceWritePosition(len); } @GuardedBy("lock") private int copyFromBuffer(byte[] b, int off, int len) { assert len <= bufferedSize() : "len (" + len + ") > size (" + bufferedSize() + ")"; copyOut(startPosition, b, off, len); return len; } /** * Copies {@code len} bytes from {@code b} at {@code off} into the buffer at * {@code position}. */ @GuardedBy("lock") private void copyOut(int position, byte[] b, int off, int len) { assert len <= buffer.length : "len (" + len + ") > size (" + buffer.length + ")"; int total = 0; int copyLen = Math.min(len, buffer.length - position); System.arraycopy(buffer, position, b, off, copyLen); total += copyLen; if (total < len) { copyLen = len - total; System.arraycopy(buffer, 0, b, off + total, copyLen); total += copyLen; } } /** * Copies {@code len} bytes from the buffer at {@code position} into * {@code b} at {@code off}. */ @GuardedBy("lock") private void copyIn(int position, byte[] b, int off, int len) { assert len <= buffer.length : "len (" + len + ") > size (" + buffer.length + ")"; int total = 0; int copyLen = Math.min(len, buffer.length - position); System.arraycopy(b, off, buffer, position, copyLen); total += copyLen; if (total < len) { copyLen = len - total; System.arraycopy(b, off + total, buffer, 0, copyLen); total += copyLen; } } /** * Creates space in the buffer, either by resizing or discarding data. */ @GuardedBy("lock") private void makeSpace(int needed) throws IOException { assert needed <= windowSize : "need (" + needed + ") > window size (" + windowSize + ")"; int capacity = buffer.length - 1; if (needed < capacity && capacity > windowSize) { // the new data should fit if we discard data outside the window advanceStartPosition(needed - writeSize()); } else { resize(needed); } } /** * Resizes the buffer until at least {@code needed} positions are available. */ @GuardedBy("lock") private void resize(int needed) throws IOException { int newLength = computeResize(writeSize(), needed, buffer.length); if (newLength < 0) { throw new IOException("maximum buffer size exceeded"); } byte[] newBuffer = new byte[newLength]; int readSize = readSize(); int bufferedSize = bufferedSize(); copyFromBuffer(newBuffer, 0, bufferedSize); buffer = newBuffer; startPosition = 0; readPosition = bufferedSize - readSize; writePosition = bufferedSize; } @GuardedBy("lock") private void advanceStartPosition(int n) { // if we will drop unread data, move the read position if (isCross(startPosition, readPosition, n, buffer.length)) { advanceReadPosition(n); } startPosition = advance(startPosition, n, buffer.length); } @GuardedBy("lock") private void advanceReadPosition(int n) { readPosition = advance(readPosition, n, buffer.length); } @GuardedBy("lock") private void advanceWritePosition(int n) { writePosition = advance(writePosition, n, buffer.length); } private static int advance(int position, int n, int length) { return (int) (((long) position + n) % length); } /** * Determines if advancing {@code position} by {@code n} crosses (ends * strictly ahead of) {@code fixed}. Assumes {@code position} starts behind * or equal with {@code fixed} in the buffer. */ private static boolean isCross(int position, int fixed, int n, int length) { if (position < fixed) { return n > fixed - position; } else if (position > fixed) { return (n - (length - position)) > fixed; } else { return n > 0; } } /** * Computes the new size of the buffer during a resize. * * @param available the amount currently in the buffer * @param needed the amount of space needed for the write * @param length the current buffer length * * @return the size of the new buffer or -1 if the maximum size is exceeded */ @VisibleForTesting int computeResize(int available, int needed, int length) { assert available < needed : "available (" + available + ") >= needed (" + needed + ")"; assert available < length : "available (" + available + ") >= length (" + length + ")"; // long to avoid int overflow during computation long additional = 0; for (int i = 1; additional + available <= needed; i++) { // double length each iteration (2^i - 1, to account for initial length) additional = (long) length * ((1 << i) - 1); } // min() constrains to int range int newLength = (int) Math.min(length + additional, MAX_BUFFER_SIZE); if (newLength - (length - available) - 1 < needed) { return -1; } else { return newLength; } } @VisibleForTesting int capacity() { synchronized (lock) { return buffer.length - 1; } } SharedOutputStream getOutputStream() { return outputStream; } SharedInputStream getInputStream() { return inputStream; } /** * Returns the buffered data from the input stream as a byte array. */ public byte[] getBufferedData() { synchronized (lock) { int bufferedSize = bufferedSize(); byte[] data = new byte[bufferedSize]; copyFromBuffer(data, 0, bufferedSize); return data; } } private static void checkArray(byte[] b, int off, int len) { checkNotNull(b); checkElementIndex(off, b.length); checkPositionIndex(len, b.length - off); } }