/* * JBoss, Home of Professional Open Source. * Copyright 2014 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.conduits; import io.undertow.UndertowLogger; import org.xnio.Buffers; import org.xnio.channels.FixedLengthOverflowException; import org.xnio.channels.StreamSourceChannel; import org.xnio.conduits.AbstractStreamSinkConduit; import org.xnio.conduits.Conduits; import org.xnio.conduits.StreamSinkConduit; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.FileChannel; import java.util.concurrent.TimeUnit; import static java.lang.Math.min; import static org.xnio.Bits.allAreClear; import static org.xnio.Bits.allAreSet; import static org.xnio.Bits.anyAreSet; import static org.xnio.Bits.longBitMask; /** * A channel which writes a fixed amount of data. A listener is called once the data has been written. * * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a> */ public abstract class AbstractFixedLengthStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> { private int config; private long state; private boolean broken = false; private static final int CONF_FLAG_CONFIGURABLE = 1 << 0; private static final int CONF_FLAG_PASS_CLOSE = 1 << 1; private static final long FLAG_CLOSE_REQUESTED = 1L << 63L; private static final long FLAG_CLOSE_COMPLETE = 1L << 62L; private static final long FLAG_FINISHED_CALLED = 1L << 61L; private static final long MASK_COUNT = longBitMask(0, 60); /** * Construct a new instance. * * @param next the next channel * @param contentLength the content length * @param configurable {@code true} if this instance should pass configuration to the next * @param propagateClose {@code true} if this instance should pass close to the next */ public AbstractFixedLengthStreamSinkConduit(final StreamSinkConduit next, final long contentLength, final boolean configurable, final boolean propagateClose) { super(next); if (contentLength < 0L) { throw new IllegalArgumentException("Content length must be greater than or equal to zero"); } else if (contentLength > MASK_COUNT) { throw new IllegalArgumentException("Content length is too long"); } config = (configurable ? CONF_FLAG_CONFIGURABLE : 0) | (propagateClose ? CONF_FLAG_PASS_CLOSE : 0); this.state = contentLength; } protected void reset(long contentLength, boolean propagateClose) { this.state = contentLength; if (propagateClose) { config |= CONF_FLAG_PASS_CLOSE; } else { config &= ~CONF_FLAG_PASS_CLOSE; } } public int write(final ByteBuffer src) throws IOException { long val = state; final long remaining = val & MASK_COUNT; if (!src.hasRemaining()) { return 0; } if (allAreSet(val, FLAG_CLOSE_REQUESTED)) { throw new ClosedChannelException(); } int oldLimit = src.limit(); if (remaining == 0) { throw new FixedLengthOverflowException(); } else if (src.remaining() > remaining) { src.limit((int) (src.position() + remaining)); } int res = 0; try { return res = next.write(src); } catch (IOException e) { broken = true; throw e; } finally { src.limit(oldLimit); exitWrite(val, (long) res); } } public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException { if (length == 0) { return 0L; } else if (length == 1) { return write(srcs[offset]); } long val = state; final long remaining = val & MASK_COUNT; if (allAreSet(val, FLAG_CLOSE_REQUESTED)) { throw new ClosedChannelException(); } long toWrite = Buffers.remaining(srcs, offset, length); if (remaining == 0) { throw new FixedLengthOverflowException(); } int[] limits = null; if (toWrite > remaining) { limits = new int[length]; long r = remaining; for (int i = offset; i < offset + length; ++i) { limits[i - offset] = srcs[i].limit(); int br = srcs[i].remaining(); if(br < r) { r -= br; } else { srcs[i].limit((int) (srcs[i].position() + r)); r = 0; } } } long res = 0L; try { return res = next.write(srcs, offset, length); } catch (IOException e) { broken = true; throw e; } finally { if (limits != null) { for (int i = offset; i < offset + length; ++i) { srcs[i].limit(limits[i - offset]); } } exitWrite(val, res); } } @Override public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException { try { return Conduits.writeFinalBasic(this, srcs, offset, length); } catch (IOException e) { broken = true; throw e; } } @Override public int writeFinal(ByteBuffer src) throws IOException { try { return Conduits.writeFinalBasic(this, src); } catch (IOException e) { broken = true; throw e; } } public long transferFrom(final FileChannel src, final long position, final long count) throws IOException { if (count == 0L) return 0L; long val = state; if (allAreSet(val, FLAG_CLOSE_REQUESTED)) { throw new ClosedChannelException(); } if (allAreClear(val, MASK_COUNT)) { throw new FixedLengthOverflowException(); } long res = 0L; try { return res = next.transferFrom(src, position, min(count, (val & MASK_COUNT))); } catch (IOException e) { broken = true; throw e; } finally { exitWrite(val, res); } } public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException { if (count == 0L) return 0L; long val = state; if (allAreSet(val, FLAG_CLOSE_REQUESTED)) { throw new ClosedChannelException(); } if (allAreClear(val, MASK_COUNT)) { throw new FixedLengthOverflowException(); } long res = 0L; try { return res = next.transferFrom(source, min(count, (val & MASK_COUNT)), throughBuffer); } catch (IOException e) { broken = true; throw e; } finally { exitWrite(val, res); } } public boolean flush() throws IOException { long val = state; if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) { return true; } boolean flushed = false; try { return flushed = next.flush(); } catch (IOException e) { broken = true; throw e; } finally { exitFlush(val, flushed); } } public boolean isWriteResumed() { // not perfect but not provably wrong either... return allAreClear(state, FLAG_CLOSE_COMPLETE) && next.isWriteResumed(); } public void wakeupWrites() { long val = state; if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) { return; } next.wakeupWrites(); } public void terminateWrites() throws IOException { final long val = enterShutdown(); if (anyAreSet(val, MASK_COUNT) && !broken) { UndertowLogger.REQUEST_IO_LOGGER.debugf("Fixed length stream closed with with %s bytes remaining", val & MASK_COUNT); try { next.truncateWrites(); } finally { if (!anyAreSet(state, FLAG_FINISHED_CALLED)) { state |= FLAG_FINISHED_CALLED; channelFinished(); } } } else if (allAreSet(config, CONF_FLAG_PASS_CLOSE)) { next.terminateWrites(); } } @Override public void truncateWrites() throws IOException { try { if (!anyAreSet(state, FLAG_FINISHED_CALLED)) { state |= FLAG_FINISHED_CALLED; channelFinished(); } } finally { super.truncateWrites(); } } public void awaitWritable() throws IOException { next.awaitWritable(); } public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException { next.awaitWritable(time, timeUnit); } /** * Get the number of remaining bytes in this fixed length channel. * * @return the number of remaining bytes */ public long getRemaining() { return state & MASK_COUNT; } private void exitWrite(long oldVal, long consumed) { long newVal = oldVal - consumed; state = newVal; } private void exitFlush(long oldVal, boolean flushed) { long newVal = oldVal; boolean callFinish = false; if ((anyAreSet(oldVal, FLAG_CLOSE_REQUESTED) || (newVal & MASK_COUNT) == 0L) && flushed) { newVal |= FLAG_CLOSE_COMPLETE; if (!anyAreSet(oldVal, FLAG_FINISHED_CALLED) && (newVal & MASK_COUNT) == 0L) { newVal |= FLAG_FINISHED_CALLED; callFinish = true; } state = newVal; if (callFinish) { channelFinished(); } } } protected void channelFinished() { } private long enterShutdown() { long oldVal, newVal; oldVal = state; if (anyAreSet(oldVal, FLAG_CLOSE_REQUESTED | FLAG_CLOSE_COMPLETE)) { // no action necessary return oldVal; } newVal = oldVal | FLAG_CLOSE_REQUESTED; if (anyAreSet(oldVal, MASK_COUNT)) { // error: channel not filled. set both close flags. newVal |= FLAG_CLOSE_COMPLETE; } state = newVal; return oldVal; } }