package org.jboss.pitbull.internal.nio.http;
import org.jboss.pitbull.PitbullChannel;
import org.jboss.pitbull.ReadTimeoutException;
import org.jboss.pitbull.internal.logging.Logger;
import org.jboss.pitbull.internal.nio.socket.ByteBuffers;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import static java.lang.Math.*;
/**
* An input stream which reads from a stream source channel with a buffer. It will only read a total fixed length
* set of total bytes. In addition, the
* {@link #available()} method can be used to determine whether the next read will or will not block.
*
* @apiviz.exclude
* @since 2.1
*/
/**
 * An input stream that reads at most a fixed number of bytes ({@code contentLength}) from a
 * {@link PitbullChannel}, going through a caller-supplied {@link ByteBuffer}. Once that many bytes
 * have been consumed the stream reports end-of-stream, even if the channel has more data.
 * <p>
 * The {@link #available()} method can be used to determine whether the next read will or will not block.
 *
 * @apiviz.exclude
 * @since 2.1
 */
public class ContentLengthInputStream extends ContentInputStream
{
   private final PitbullChannel channel;
   private final ByteBuffer buffer;
   /** Content bytes not yet handed to the caller (bytes sitting unread in {@link #buffer} are included). */
   private long remainingBytes;
   private volatile boolean closed;
   protected static final Logger log = Logger.getLogger(ContentLengthInputStream.class);

   /**
    * Construct a new instance.
    *
    * @param channel       the channel to wrap
    * @param buffer        the internal read buffer; its capacity bounds how much is read from the channel at once
    * @param contentLength the total number of bytes this stream will deliver before reporting end-of-stream
    * @throws NullPointerException if {@code channel} or {@code buffer} is null
    */
   public ContentLengthInputStream(final PitbullChannel channel, ByteBuffer buffer, long contentLength)
   {
      if (channel == null)
      {
         throw new NullPointerException("channel is null");
      }
      if (buffer == null)
      {
         throw new NullPointerException("buffer is null");
      }
      this.buffer = buffer;
      this.channel = channel;
      this.remainingBytes = contentLength;
   }

   /**
    * Clear the internal buffer for writing and cap its limit at the number of content bytes still
    * outstanding, so a channel read can never pull bytes beyond the end of this content.
    */
   protected void resetBufferLimit()
   {
      if (remainingBytes == 0)
      {
         log.trace("resetBufferLimit with 0 remaining bytes");
      }
      buffer.clear();
      if (buffer.capacity() > remainingBytes)
      {
         // clamp at 0: a negative content length must not make Buffer.limit() throw
         buffer.limit((int) max(0L, remainingBytes));
      }
   }

   /**
    * Refill the (fully consumed) internal buffer from the channel with a blocking read.
    * The buffer is always flipped back to read mode, even if the channel throws or reports
    * end-of-stream, so it never exposes stale bytes left over from an earlier fill.
    *
    * @param timeoutMillis read timeout in milliseconds, or a negative value to block without a timeout
    * @return the channel's result: number of bytes read, possibly 0, or -1 at end of stream
    * @throws IOException if the channel read fails
    */
   private int fillBufferFromChannel(final long timeoutMillis) throws IOException
   {
      resetBufferLimit();
      try
      {
         return timeoutMillis < 0L
               ? channel.readBlocking(buffer)
               : channel.readBlocking(buffer, timeoutMillis, TimeUnit.MILLISECONDS);
      }
      finally
      {
         buffer.flip();
      }
   }

   /**
    * Discard whatever content remains unread, in chunks of up to 1000 bytes.
    * Stops early if no further progress can be made (stream closed or channel at end-of-stream)
    * rather than spinning forever.
    *
    * @throws RuntimeException wrapping any {@link IOException} raised while skipping
    */
   @Override
   public void eat()
   {
      if (remainingBytes > 0)
      {
         log.trace("Need to skip bytes: {0}", remainingBytes);
      }
      while (remainingBytes > 0)
      {
         final long skipped;
         try
         {
            skipped = skip(1000);
         }
         catch (IOException e)
         {
            throw new RuntimeException(e);
         }
         if (skipped <= 0L)
         {
            // skip() returns 0 once closed or when the channel hits EOF; retrying would loop forever
            log.trace("Unable to skip remaining bytes: {0}", remainingBytes);
            break;
         }
      }
   }

   /**
    * Read a byte, blocking if necessary.
    * If the inherited {@code timeout} is non-zero it is treated as a relative timeout in milliseconds
    * for this call.
    *
    * @return the byte read, or -1 if the stream is closed, the content is exhausted, or the channel hit end-of-stream
    * @throws java.io.IOException if an I/O error occurs (a {@code ReadTimeoutException} when the timeout elapses)
    */
   public int read() throws IOException
   {
      if (closed) return -1;
      if (remainingBytes <= 0) return -1;
      final ByteBuffer buffer = this.buffer;
      final long timeout = this.timeout;
      if (timeout == 0L)
      {
         while (!buffer.hasRemaining())
         {
            if (fillBufferFromChannel(-1L) == -1)
            {
               return -1;
            }
         }
      }
      else if (!buffer.hasRemaining())
      {
         // Compute one absolute deadline up front (guarding against overflow) and re-check the
         // clock on every iteration so repeated empty reads eventually time out.
         final long start = System.currentTimeMillis();
         final long deadline = timeout > Long.MAX_VALUE - start ? Long.MAX_VALUE : start + timeout;
         do
         {
            final long left = deadline - System.currentTimeMillis();
            // check before touching the buffer so a timeout leaves it empty, not half-reset
            if (left <= 0L)
            {
               throw new ReadTimeoutException("Read timed out");
            }
            if (fillBufferFromChannel(left) == -1)
            {
               return -1;
            }
         } while (!buffer.hasRemaining());
      }
      remainingBytes--;
      return buffer.get() & 0xff;
   }

   /**
    * Read bytes into an array, never more than the content bytes still outstanding.
    * Bytes already buffered are delivered first; the channel is then read directly into {@code b},
    * blocking only if nothing has been delivered yet.
    *
    * @param b   the destination array
    * @param off the offset into the array at which bytes should be filled
    * @param len the number of bytes to fill
    * @return the number of bytes read, or -1 if the end of the stream has been reached
    * @throws java.io.IOException if an I/O error occurs; an {@link InterruptedIOException} carries the
    *                             count of bytes already copied in {@code bytesTransferred}
    */
   public int read(final byte[] b, int off, int len) throws IOException
   {
      if (closed) return -1;
      if (remainingBytes <= 0) return -1;
      if (len < 1)
      {
         return 0;
      }
      if (len > remainingBytes)
      {
         len = (int) remainingBytes;
      }
      int total = 0;
      final ByteBuffer buffer = this.buffer;
      if (buffer.hasRemaining())
      {
         final int cnt = min(buffer.remaining(), len);
         buffer.get(b, off, cnt);
         total += cnt;
         off += cnt;
         len -= cnt;
         remainingBytes -= cnt;
      }
      // bytes already copied into b must be reported even if the stream was closed meanwhile
      if (closed) return total > 0 ? total : -1;
      if (len <= 0) return total;
      final long timeout = this.timeout;
      try
      {
         final ByteBuffer dst = ByteBuffer.wrap(b, off, len);
         final int res;
         if (total > 0)
         {
            // something was delivered already: don't block for more
            res = channel.read(dst);
         }
         else if (timeout == 0L)
         {
            res = channel.readBlocking(dst);
         }
         else
         {
            res = channel.readBlocking(dst, timeout, TimeUnit.MILLISECONDS);
            if (res == 0)
            {
               throw new ReadTimeoutException("Read timed out");
            }
         }
         if (res == -1)
         {
            return total == 0 ? -1 : total;
         }
         remainingBytes -= res;
         total += res;
         return total;
      }
      catch (InterruptedIOException e)
      {
         e.bytesTransferred = total;
         throw e;
      }
   }

   /**
    * Skip up to {@code n} bytes of content.
    * Buffered bytes are discarded first; further bytes are read from the channel into the internal
    * buffer and thrown away, blocking only if nothing has been skipped yet.
    *
    * @param n the number of bytes to skip
    * @return the number of bytes skipped (0 if the stream is closed or the end of stream has been reached)
    * @throws java.io.IOException if an I/O error occurs
    */
   public long skip(long n) throws IOException
   {
      if (closed) return 0L;
      if (remainingBytes <= 0) return 0L;
      if (n < 1L)
      {
         return 0L;
      }
      long total = 0L;
      final ByteBuffer buffer = this.buffer;
      if (buffer.hasRemaining())
      {
         final int cnt = (int) min(buffer.remaining(), n);
         ByteBuffers.skip(buffer, cnt);
         remainingBytes -= cnt;
         total += cnt;
         n -= cnt;
      }
      if (closed)
      {
         return total;
      }
      if (n > remainingBytes)
      {
         n = remainingBytes;
      }
      if (n > 0L)
      {
         // The buffer has been fully consumed, so it can be reused as scratch space.
         try
         {
            while (n > 0L)
            {
               resetBufferLimit();
               // never pull more than requested, otherwise the excess would be silently discarded
               if (buffer.remaining() > n)
               {
                  buffer.limit((int) n);
               }
               final int res = total > 0L ? channel.read(buffer) : channel.readBlocking(buffer);
               if (res <= 0)
               {
                  return total;
               }
               total += res;
               remainingBytes -= res;
               n -= res;
            }
         }
         finally
         {
            // leave the buffer empty in read mode: the scratch bytes were skipped, not buffered
            buffer.position(0).limit(0);
         }
      }
      return total;
   }

   /**
    * Return the number of bytes available to read, or 0 if a subsequent {@code read()} operation would block.
    * If the buffer is empty and the stream is open, performs one non-blocking channel read
    * (capped at the content bytes still outstanding) to find out.
    *
    * @return the number of ready bytes, or 0 for none
    * @throws java.io.IOException if an I/O error occurs
    */
   public int available() throws IOException
   {
      final ByteBuffer buffer = this.buffer;
      final int rem = buffer.remaining();
      if (rem > 0 || closed)
      {
         return rem;
      }
      resetBufferLimit();
      try
      {
         channel.read(buffer);
      }
      catch (IOException e)
      {
         // don't leave the buffer in write mode where stale bytes would look readable
         buffer.limit(0);
         throw e;
      }
      buffer.flip();
      return buffer.remaining();
   }

   /**
    * Close the stream. This only marks the stream closed: subsequent {@code read} calls return -1
    * and {@code skip} returns 0. The underlying channel is not touched.
    *
    * @throws java.io.IOException declared for {@link java.io.InputStream} compatibility; never thrown here
    */
   public void close() throws IOException
   {
      closed = true;
   }
}