/* * JBoss, Home of Professional Open Source. * Copyright 2014 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.server.protocol.framed; import static org.xnio.Bits.allAreClear; import static org.xnio.Bits.allAreSet; import static org.xnio.Bits.anyAreSet; import java.io.IOException; import java.io.InterruptedIOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.Deque; import java.util.LinkedList; import java.util.concurrent.TimeUnit; import io.undertow.UndertowLogger; import org.xnio.Buffers; import org.xnio.ChannelListener; import org.xnio.ChannelListeners; import org.xnio.IoUtils; import org.xnio.Option; import io.undertow.connector.PooledByteBuffer; import org.xnio.XnioExecutor; import org.xnio.XnioIoThread; import org.xnio.XnioWorker; import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSourceChannel; import io.undertow.UndertowMessages; /** * Source channel, used to receive framed messages. * * @author Stuart Douglas */ public abstract class AbstractFramedStreamSourceChannel<C extends AbstractFramedChannel<C, R, S>, R extends AbstractFramedStreamSourceChannel<C, R, S>, S extends AbstractFramedStreamSinkChannel<C, R, S>> implements StreamSourceChannel { private final ChannelListener.SimpleSetter<? extends R> readSetter = new ChannelListener.SimpleSetter(); private final ChannelListener.SimpleSetter<? extends R> closeSetter = new ChannelListener.SimpleSetter(); private final C framedChannel; private final Deque<FrameData> pendingFrameData = new LinkedList<>(); private int state = 0; private static final int STATE_DONE = 1 << 1; private static final int STATE_READS_RESUMED = 1 << 2; private static final int STATE_CLOSED = 1 << 3; private static final int STATE_LAST_FRAME = 1 << 4; private static final int STATE_IN_LISTENER_LOOP = 1 << 5; private static final int STATE_STREAM_BROKEN = 1 << 6; private static final int STATE_RETURNED_MINUS_ONE = 1 << 7; private static final int STATE_WAITNG_MINUS_ONE = 1 << 8; /** * The backing data for the current frame. */ private volatile PooledByteBuffer data; private int currentDataOriginalSize; /** * The amount of data left in the frame. If this is larger than the data in the backing buffer then */ private long frameDataRemaining; private final Object lock = new Object(); private int waiters; private volatile boolean waitingForFrame; private int readFrameCount = 0; private long maxStreamSize = -1; private long currentStreamSize; private ChannelListener[] closeListeners = null; public AbstractFramedStreamSourceChannel(C framedChannel) { this.framedChannel = framedChannel; this.waitingForFrame = true; } public AbstractFramedStreamSourceChannel(C framedChannel, PooledByteBuffer data, long frameDataRemaining) { this.framedChannel = framedChannel; this.waitingForFrame = data == null && frameDataRemaining <= 0; this.frameDataRemaining = frameDataRemaining; this.currentStreamSize = frameDataRemaining; if (data != null) { if (!data.getBuffer().hasRemaining()) { data.close(); this.data = null; this.waitingForFrame = frameDataRemaining <= 0; } else { dataReady(null, data); } } } @Override public long transferTo(long position, long count, FileChannel target) throws IOException { if (anyAreSet(state, STATE_DONE)) { return -1; } beforeRead(); if (waitingForFrame) { return 0; } try { if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; return -1; } } else if (data != null) { int old = data.getBuffer().limit(); try { if (count < data.getBuffer().remaining()) { data.getBuffer().limit((int) (data.getBuffer().position() + count)); } return target.write(data.getBuffer(), position); } finally { data.getBuffer().limit(old); decrementFrameDataRemaining(); } } return 0; } finally { exitRead(); } } private void decrementFrameDataRemaining() { if(!data.getBuffer().hasRemaining()) { frameDataRemaining -= currentDataOriginalSize; } } @Override public long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel streamSinkChannel) throws IOException { if (anyAreSet(state, STATE_DONE)) { return -1; } beforeRead(); if (waitingForFrame) { throughBuffer.position(throughBuffer.limit()); return 0; } try { if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; return -1; } } else if (data != null && data.getBuffer().hasRemaining()) { int old = data.getBuffer().limit(); try { if (count < data.getBuffer().remaining()) { data.getBuffer().limit((int) (data.getBuffer().position() + count)); } int written = streamSinkChannel.write(data.getBuffer()); if(data.getBuffer().hasRemaining()) { //we can still add more data //stick it it throughbuffer, otherwise transfer code will continue to attempt to use this method throughBuffer.clear(); Buffers.copy(throughBuffer, data.getBuffer()); throughBuffer.flip(); } else { throughBuffer.position(throughBuffer.limit()); } return written; } finally { data.getBuffer().limit(old); decrementFrameDataRemaining(); } } else { throughBuffer.position(throughBuffer.limit()); } return 0; } finally { exitRead(); } } public long getMaxStreamSize() { return maxStreamSize; } public void setMaxStreamSize(long maxStreamSize) { this.maxStreamSize = maxStreamSize; if(maxStreamSize > 0) { if(maxStreamSize < currentStreamSize) { handleStreamTooLarge(); } } } private void handleStreamTooLarge() { IoUtils.safeClose(this); } @Override public void suspendReads() { synchronized (lock) { state &= ~STATE_READS_RESUMED; } } /** * Method that is invoked when all data has been read. * * @throws IOException */ protected void complete() throws IOException { close(); } protected boolean isComplete() { return anyAreSet(state, STATE_DONE); } @Override public void resumeReads() { resumeReadsInternal(false); } @Override public boolean isReadResumed() { return anyAreSet(state, STATE_READS_RESUMED); } @Override public void wakeupReads() { resumeReadsInternal(true); } public void addCloseTask(ChannelListener<R> channelListener) { if(closeListeners == null) { closeListeners = new ChannelListener[]{channelListener}; } else { ChannelListener[] old = closeListeners; closeListeners = new ChannelListener[old.length + 1]; System.arraycopy(old, 0, closeListeners, 0, old.length); closeListeners[old.length] = channelListener; } } /** * For this class there is no difference between a resume and a wakeup */ void resumeReadsInternal(boolean wakeup) { synchronized (lock) { boolean alreadyResumed = anyAreSet(state, STATE_READS_RESUMED); state |= STATE_READS_RESUMED; if (!alreadyResumed || wakeup) { if (!anyAreSet(state, STATE_IN_LISTENER_LOOP)) { state |= STATE_IN_LISTENER_LOOP; getFramedChannel().runInIoThread(new Runnable() { @Override public void run() { try { boolean moreData; do { ChannelListener<? super R> listener = getReadListener(); if (listener == null || !isReadResumed()) { return; } ChannelListeners.invokeChannelListener((R) AbstractFramedStreamSourceChannel.this, listener); //if writes are shutdown or we become active then we stop looping //we stop when writes are shutdown because we can't flush until we are active //although we may be flushed as part of a batch moreData = (frameDataRemaining > 0 && data != null) || !pendingFrameData.isEmpty() || anyAreSet(state, STATE_WAITNG_MINUS_ONE); } while (allAreSet(state, STATE_READS_RESUMED) && allAreClear(state, STATE_CLOSED) && moreData); } finally { state &= ~STATE_IN_LISTENER_LOOP; } } }); } } } } private ChannelListener<? super R> getReadListener() { return (ChannelListener<? super R>) readSetter.get(); } @Override public void shutdownReads() throws IOException { close(); } protected void lastFrame() { synchronized (lock) { state |= STATE_LAST_FRAME; } waitingForFrame = false; if(data == null && pendingFrameData.isEmpty() && frameDataRemaining == 0) { state |= STATE_DONE | STATE_CLOSED; getFramedChannel().notifyFrameReadComplete(this); getFramedChannel().notifyClosed(this); IoUtils.safeClose(this); } } protected boolean isLastFrame() { return anyAreSet(state, STATE_LAST_FRAME); } @Override public void awaitReadable() throws IOException { if(Thread.currentThread() == getIoThread()) { throw UndertowMessages.MESSAGES.awaitCalledFromIoThread(); } if (data == null && pendingFrameData.isEmpty()) { synchronized (lock) { if (data == null && pendingFrameData.isEmpty()) { try { waiters++; lock.wait(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new InterruptedIOException(); } finally { waiters--; } } } } } @Override public void awaitReadable(long l, TimeUnit timeUnit) throws IOException { if(Thread.currentThread() == getIoThread()) { throw UndertowMessages.MESSAGES.awaitCalledFromIoThread(); } if (data == null) { synchronized (lock) { if (data == null) { try { waiters++; lock.wait(timeUnit.toMillis(l)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new InterruptedIOException(); } finally { waiters--; } } } } } /** * Called when data has been read from the underlying channel. * * @param headerData The frame header data. This may be null if the data is part of a an existing frame * @param frameData The frame data */ protected void dataReady(FrameHeaderData headerData, PooledByteBuffer frameData) { if(anyAreSet(state, STATE_STREAM_BROKEN | STATE_CLOSED)) { frameData.close(); return; } synchronized (lock) { boolean newData = pendingFrameData.isEmpty(); this.pendingFrameData.add(new FrameData(headerData, frameData)); if (newData) { if (waiters > 0) { lock.notifyAll(); } } waitingForFrame = false; } if (anyAreSet(state, STATE_READS_RESUMED)) { resumeReadsInternal(true); } if(headerData != null) { currentStreamSize += headerData.getFrameLength(); if(maxStreamSize > 0 && currentStreamSize > maxStreamSize) { handleStreamTooLarge(); } } } protected long updateFrameDataRemaining(PooledByteBuffer frameData, long frameDataRemaining) { return frameDataRemaining; } protected PooledByteBuffer processFrameData(PooledByteBuffer data, boolean lastFragmentOfFrame) throws IOException { return data; } protected void handleHeaderData(FrameHeaderData headerData) { } @Override public XnioExecutor getReadThread() { return framedChannel.getIoThread(); } @Override public ChannelListener.Setter<? extends R> getReadSetter() { return readSetter; } @Override public ChannelListener.Setter<? extends R> getCloseSetter() { return closeSetter; } @Override public XnioWorker getWorker() { return framedChannel.getWorker(); } @Override public XnioIoThread getIoThread() { return framedChannel.getIoThread(); } @Override public boolean supportsOption(Option<?> option) { return false; } @Override public <T> T getOption(Option<T> tOption) throws IOException { return null; } @Override public <T> T setOption(Option<T> tOption, T t) throws IllegalArgumentException, IOException { return null; } @Override public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { if (anyAreSet(state, STATE_DONE)) { return -1; } beforeRead(); if (waitingForFrame) { return 0; } try { if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; } return -1; } else if (data != null) { int old = data.getBuffer().limit(); try { long count = Buffers.remaining(dsts, offset, length); if (count < data.getBuffer().remaining()) { data.getBuffer().limit((int) (data.getBuffer().position() + count)); } else { count = data.getBuffer().remaining(); } return Buffers.copy((int) count, dsts, offset, length, data.getBuffer()); } finally { data.getBuffer().limit(old); decrementFrameDataRemaining(); } } return 0; } finally { exitRead(); } } @Override public long read(ByteBuffer[] dsts) throws IOException { return read(dsts, 0, dsts.length); } @Override public int read(ByteBuffer dst) throws IOException { if (anyAreSet(state, STATE_DONE)) { return -1; } if (!dst.hasRemaining()) { return 0; } beforeRead(); if (waitingForFrame) { return 0; } try { if (frameDataRemaining == 0 && anyAreSet(state, STATE_LAST_FRAME)) { synchronized (lock) { state |= STATE_RETURNED_MINUS_ONE; } return -1; } else if (data != null) { int old = data.getBuffer().limit(); try { int count = dst.remaining(); if (count < data.getBuffer().remaining()) { data.getBuffer().limit(data.getBuffer().position() + count); } else { count = data.getBuffer().remaining(); } return Buffers.copy(count, dst, data.getBuffer()); } finally { data.getBuffer().limit(old); decrementFrameDataRemaining(); } } return 0; } finally { try { exitRead(); } catch (Exception e) { markStreamBroken(); } } } private void beforeRead() throws IOException { if (anyAreSet(state, STATE_STREAM_BROKEN)) { throw UndertowMessages.MESSAGES.channelIsClosed(); } if (data == null) { synchronized (lock) { FrameData pending = pendingFrameData.poll(); if (pending != null) { PooledByteBuffer frameData = pending.getFrameData(); boolean hasData = true; if(!frameData.getBuffer().hasRemaining()) { frameData.close(); hasData = false; } if (pending.getFrameHeaderData() != null) { this.frameDataRemaining = pending.getFrameHeaderData().getFrameLength(); handleHeaderData(pending.getFrameHeaderData()); } if(hasData) { this.frameDataRemaining = updateFrameDataRemaining(frameData, frameDataRemaining); this.currentDataOriginalSize = frameData.getBuffer().remaining(); try { this.data = processFrameData(frameData, frameDataRemaining - currentDataOriginalSize == 0); } catch (Exception e) { frameData.close(); UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(e)); markStreamBroken(); } } } } } } private void exitRead() throws IOException { if (data != null && !data.getBuffer().hasRemaining()) { data.close(); data = null; } if (frameDataRemaining == 0) { try { synchronized (lock) { readFrameCount++; if (pendingFrameData.isEmpty()) { if (anyAreSet(state, STATE_RETURNED_MINUS_ONE)) { state |= STATE_DONE; getFramedChannel().notifyClosed(this); complete(); close(); } else if(anyAreSet(state, STATE_LAST_FRAME)) { state |= STATE_WAITNG_MINUS_ONE; } else { waitingForFrame = true; } } } } finally { if (pendingFrameData.isEmpty()) { framedChannel.notifyFrameReadComplete(this); } } } } @Override public boolean isOpen() { return allAreClear(state, STATE_CLOSED); } @Override public void close() { if(anyAreSet(state, STATE_CLOSED)) { return; } synchronized (lock) { state |= STATE_CLOSED; if (allAreClear(state, STATE_DONE | STATE_LAST_FRAME)) { state |= STATE_STREAM_BROKEN; getFramedChannel().notifyClosed(this); channelForciblyClosed(); } if (data != null) { data.close(); data = null; } while (!pendingFrameData.isEmpty()) { pendingFrameData.poll().frameData.close(); } ChannelListeners.invokeChannelListener(this, (ChannelListener<? super AbstractFramedStreamSourceChannel<C, R, S>>) closeSetter.get()); if (closeListeners != null) { for (int i = 0; i < closeListeners.length; ++i) { closeListeners[i].handleEvent(this); } } } } protected void channelForciblyClosed() { //TODO: what should be the default action? //we can probably just ignore it, as it does not affect the underlying protocol } protected C getFramedChannel() { return framedChannel; } protected int getReadFrameCount() { return readFrameCount; } /** * Called when this stream is no longer valid. Reads from the stream will result * in an exception. */ protected void markStreamBroken() { if(anyAreSet(state, STATE_STREAM_BROKEN)) { return; } synchronized (lock) { state |= STATE_STREAM_BROKEN; PooledByteBuffer data = this.data; if(data != null) { try { data.close(); //may have been closed by the read thread } catch (Exception e) { //ignore } this.data = null; } for(FrameData frame : pendingFrameData) { frame.frameData.close(); } pendingFrameData.clear(); getFramedChannel().notifyClosed(this); if(isReadResumed()) { resumeReadsInternal(true); } if (waiters > 0) { lock.notifyAll(); } } } private class FrameData { private final FrameHeaderData frameHeaderData; private final PooledByteBuffer frameData; FrameData(FrameHeaderData frameHeaderData, PooledByteBuffer frameData) { this.frameHeaderData = frameHeaderData; this.frameData = frameData; } FrameHeaderData getFrameHeaderData() { return frameHeaderData; } PooledByteBuffer getFrameData() { return frameData; } } }