/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.undertow.io;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.Headers;
import org.xnio.Buffers;
import org.xnio.ChannelExceptionHandler;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSinkChannel;

/**
 * {@link Sender} implementation that writes to the exchange's response channel without blocking.
 * <p>
 * Each send first tries to write directly. When the channel accepts no bytes, the remaining data is
 * parked in this object and a write listener finishes the job once the channel becomes writable.
 * At most one piece of data (buffers or a file channel) may be pending at a time; a second send while
 * data is pending fails with the "data already queued" error.
 * <p>
 * Sends issued from inside an {@link IoCallback#onComplete} invocation are only recorded; they are
 * written by the loop in {@link #invokeOnComplete()} after the callback returns, which avoids
 * unbounded recursion.
 *
 * @author Stuart Douglas
 */
public class AsyncSenderImpl implements Sender {

    /** Response channel, obtained lazily from the exchange on first use. */
    private StreamSinkChannel channel;
    private final HttpServerExchange exchange;
    /** Buffers still waiting to be written, or {@code null} when no buffer data is pending. */
    private ByteBuffer[] buffer;
    /** Pooled buffers backing a string send; closed when the send completes or fails. */
    private PooledByteBuffer[] pooledBuffers = null;
    /** File waiting to be transferred, or {@code null} when no file transfer is pending. */
    private FileChannel fileChannel;
    /** Callback for the data currently pending. */
    private IoCallback callback;
    /** {@code true} while an {@code onComplete} callback is running from {@link #invokeOnComplete()}. */
    private boolean inCallback;

    private ChannelListener<StreamSinkChannel> writeListener;

    /**
     * Transfers the pending {@link #fileChannel} to the response channel. It doubles as the write
     * listener that re-dispatches itself when the channel becomes writable again.
     */
    public class TransferTask implements Runnable, ChannelListener<StreamSinkChannel> {

        /**
         * Transfers as much of the pending file as the response channel accepts.
         *
         * @param complete when {@code true}, {@link #invokeOnComplete()} is called once the whole
         *                 file has been transferred; when {@code false} the caller is expected to do so
         * @return {@code true} if the file was transferred completely; {@code false} if the transfer
         *         is waiting for the channel to become writable, or if it failed (in which case the
         *         callback's {@code onException} has already been invoked)
         */
        public boolean run(boolean complete) {
            try {
                FileChannel source = fileChannel;
                long pos = source.position();
                long size = source.size();

                StreamSinkChannel dest = channel;
                if (dest == null) {
                    if (callback == IoCallback.END_EXCHANGE) {
                        if (exchange.getResponseContentLength() == -1 && !exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                            exchange.setResponseContentLength(size);
                        }
                    }
                    channel = dest = exchange.getResponseChannel();
                    if (dest == null) {
                        throw UndertowMessages.MESSAGES.responseChannelAlreadyProvided();
                    }
                }

                while (size - pos > 0) {
                    long ret = dest.transferFrom(source, pos, size - pos);
                    pos += ret;
                    if (ret == 0) {
                        // Nothing accepted: remember where we stopped and wait for writability.
                        source.position(pos);
                        dest.getWriteSetter().set(this);
                        dest.resumeWrites();
                        return false;
                    }
                }

                if (complete) {
                    invokeOnComplete();
                }
            } catch (IOException e) {
                invokeOnException(callback, e);
                // Report "not completed" so invokeOnComplete() does not go on to call onComplete
                // for a transfer that has just been reported as failed.
                return false;
            }
            return true;
        }

        @Override
        public void handleEvent(StreamSinkChannel channel) {
            channel.suspendWrites();
            channel.getWriteSetter().set(null);
            exchange.dispatch(this);
        }

        @Override
        public void run() {
            run(true);
        }
    }

    private TransferTask transferTask;

    public AsyncSenderImpl(final HttpServerExchange exchange) {
        this.exchange = exchange;
    }

    /**
     * Sends a single buffer, invoking {@code callback} when it has been written or has failed.
     *
     * @param buffer   the data to write
     * @param callback the completion callback; must not be {@code null}
     */
    @Override
    public void send(final ByteBuffer buffer, final IoCallback callback) {
        if (callback == null) {
            throw UndertowMessages.MESSAGES.argumentCannotBeNull("callback");
        }
        if (exchange.isResponseComplete()) {
            throw UndertowMessages.MESSAGES.responseComplete();
        }
        if (this.buffer != null || this.fileChannel != null) {
            throw UndertowMessages.MESSAGES.dataAlreadyQueued();
        }
        long responseContentLength = exchange.getResponseContentLength();
        if (responseContentLength > 0 && buffer.remaining() > responseContentLength) {
            invokeOnException(callback, UndertowLogger.ROOT_LOGGER.dataLargerThanContentLength(buffer.remaining(), responseContentLength));
            return;
        }
        StreamSinkChannel channel = this.channel;
        if (channel == null) {
            if (callback == IoCallback.END_EXCHANGE) {
                // The exchange ends after this buffer, so its size is the full content length.
                if (responseContentLength == -1 && !exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                    exchange.setResponseContentLength(buffer.remaining());
                }
            }
            this.channel = channel = exchange.getResponseChannel();
            if (channel == null) {
                throw UndertowMessages.MESSAGES.responseChannelAlreadyProvided();
            }
        }
        this.callback = callback;
        if (inCallback) {
            // Called from within onComplete: just queue, invokeOnComplete() will write it.
            this.buffer = new ByteBuffer[]{buffer};
            return;
        }
        try {
            do {
                if (buffer.remaining() == 0) {
                    callback.onComplete(exchange, this);
                    return;
                }
                int res = channel.write(buffer);
                if (res == 0) {
                    this.buffer = new ByteBuffer[]{buffer};
                    awaitWritable(channel);
                    return;
                }
            } while (buffer.hasRemaining());
            invokeOnComplete();
        } catch (IOException e) {
            invokeOnException(callback, e);
        }
    }

    /**
     * Sends several buffers, invoking {@code callback} when all have been written or the write failed.
     *
     * @param buffer   the buffers to write
     * @param callback the completion callback; must not be {@code null}
     */
    @Override
    public void send(final ByteBuffer[] buffer, final IoCallback callback) {
        if (callback == null) {
            throw UndertowMessages.MESSAGES.argumentCannotBeNull("callback");
        }
        if (exchange.isResponseComplete()) {
            throw UndertowMessages.MESSAGES.responseComplete();
        }
        // Reject a pending file transfer as well, matching send(ByteBuffer, ...) and transferFrom.
        if (this.buffer != null || this.fileChannel != null) {
            throw UndertowMessages.MESSAGES.dataAlreadyQueued();
        }
        this.callback = callback;
        if (inCallback) {
            // Called from within onComplete: just queue, invokeOnComplete() will write it.
            this.buffer = buffer;
            return;
        }

        long totalToWrite = Buffers.remaining(buffer);
        long responseContentLength = exchange.getResponseContentLength();
        if (responseContentLength > 0 && totalToWrite > responseContentLength) {
            invokeOnException(callback, UndertowLogger.ROOT_LOGGER.dataLargerThanContentLength(totalToWrite, responseContentLength));
            return;
        }

        StreamSinkChannel channel = this.channel;
        if (channel == null) {
            if (callback == IoCallback.END_EXCHANGE) {
                if (responseContentLength == -1 && !exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                    exchange.setResponseContentLength(totalToWrite);
                }
            }
            this.channel = channel = exchange.getResponseChannel();
            if (channel == null) {
                throw UndertowMessages.MESSAGES.responseChannelAlreadyProvided();
            }
        }

        final long total = totalToWrite;
        long written = 0;
        try {
            do {
                long res = channel.write(buffer);
                written += res;
                if (res == 0) {
                    this.buffer = buffer;
                    awaitWritable(channel);
                    return;
                }
            } while (written < total);
            invokeOnComplete();
        } catch (IOException e) {
            invokeOnException(callback, e);
        }
    }

    /**
     * Transfers the remainder of {@code source} (from its current position to its size) to the
     * response. When called on the IO thread the transfer is dispatched via the exchange instead of
     * being run inline.
     *
     * @param source   the file to transfer
     * @param callback the completion callback; must not be {@code null}
     */
    @Override
    public void transferFrom(FileChannel source, IoCallback callback) {
        if (callback == null) {
            throw UndertowMessages.MESSAGES.argumentCannotBeNull("callback");
        }
        if (exchange.isResponseComplete()) {
            throw UndertowMessages.MESSAGES.responseComplete();
        }
        if (this.fileChannel != null || this.buffer != null) {
            throw UndertowMessages.MESSAGES.dataAlreadyQueued();
        }

        this.callback = callback;
        this.fileChannel = source;
        if (inCallback) {
            return;
        }
        if (transferTask == null) {
            transferTask = new TransferTask();
        }
        if (exchange.isInIoThread()) {
            exchange.dispatch(transferTask);
            return;
        }
        transferTask.run();
    }

    @Override
    public void send(final ByteBuffer buffer) {
        send(buffer, IoCallback.END_EXCHANGE);
    }

    @Override
    public void send(final ByteBuffer[] buffer) {
        send(buffer, IoCallback.END_EXCHANGE);
    }

    @Override
    public void send(final String data, final IoCallback callback) {
        send(data, StandardCharsets.UTF_8, callback);
    }

    /**
     * Encodes {@code data} with {@code charset}, copies it into pooled buffers and sends those.
     * An empty string completes the callback immediately without touching the channel.
     *
     * @param data     the text to send
     * @param charset  the charset used to encode {@code data}
     * @param callback the completion callback; must not be {@code null}
     */
    @Override
    public void send(final String data, final Charset charset, final IoCallback callback) {
        if (callback == null) {
            throw UndertowMessages.MESSAGES.argumentCannotBeNull("callback");
        }
        if (exchange.isResponseComplete()) {
            throw UndertowMessages.MESSAGES.responseComplete();
        }
        ByteBuffer bytes = ByteBuffer.wrap(data.getBytes(charset));
        if (bytes.remaining() == 0) {
            callback.onComplete(exchange, this);
            return;
        }
        // Validate before allocating: failing later would leak the freshly allocated pooled buffers
        // and overwrite the pooledBuffers belonging to the send that is still in flight.
        if (this.buffer != null || this.fileChannel != null) {
            throw UndertowMessages.MESSAGES.dataAlreadyQueued();
        }
        int i = 0;
        ByteBuffer[] bufs = null;
        while (bytes.hasRemaining()) {
            PooledByteBuffer pooled = exchange.getConnection().getByteBufferPool().allocate();
            if (bufs == null) {
                int noBufs = (bytes.remaining() + pooled.getBuffer().remaining() - 1) / pooled.getBuffer().remaining(); //round up division trick
                pooledBuffers = new PooledByteBuffer[noBufs];
                bufs = new ByteBuffer[noBufs];
            }
            pooledBuffers[i] = pooled;
            bufs[i] = pooled.getBuffer();
            Buffers.copy(pooled.getBuffer(), bytes);
            pooled.getBuffer().flip();
            ++i;
        }
        send(bufs, callback);
    }

    @Override
    public void send(final String data) {
        send(data, IoCallback.END_EXCHANGE);
    }

    @Override
    public void send(final String data, final Charset charset) {
        send(data, charset, IoCallback.END_EXCHANGE);
    }

    /**
     * Shuts down writes on the response channel and flushes it, notifying {@code callback} (if any)
     * once the flush has finished or failed.
     *
     * @param callback the callback to notify; may be {@code null}
     */
    @Override
    public void close(final IoCallback callback) {
        try {
            StreamSinkChannel channel = this.channel;
            if (channel == null) {
                // Nothing was ever sent: the response body is empty.
                if (exchange.getResponseContentLength() == -1 && !exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                    exchange.setResponseContentLength(0);
                }
                this.channel = channel = exchange.getResponseChannel();
                if (channel == null) {
                    throw UndertowMessages.MESSAGES.responseChannelAlreadyProvided();
                }
            }
            channel.shutdownWrites();
            if (!channel.flush()) {
                channel.getWriteSetter().set(ChannelListeners.flushingChannelListener(
                        new ChannelListener<StreamSinkChannel>() {
                            @Override
                            public void handleEvent(final StreamSinkChannel channel) {
                                if (callback != null) {
                                    callback.onComplete(exchange, AsyncSenderImpl.this);
                                }
                            }
                        }, new ChannelExceptionHandler<StreamSinkChannel>() {
                            @Override
                            public void handleException(final StreamSinkChannel channel, final IOException exception) {
                                try {
                                    if (callback != null) {
                                        invokeOnException(callback, exception);
                                    }
                                } finally {
                                    IoUtils.safeClose(channel);
                                }
                            }
                        }
                ));
                channel.resumeWrites();
            } else {
                if (callback != null) {
                    callback.onComplete(exchange, this);
                }
            }
        } catch (IOException e) {
            if (callback != null) {
                invokeOnException(callback, e);
            }
        }
    }

    @Override
    public void close() {
        close(null);
    }

    /**
     * Invokes the onComplete method. If send is called again in onComplete then
     * we loop and write it out. This prevents possible stack overflows due to recursion
     */
    private void invokeOnComplete() {
        for (; ; ) {
            releasePooledBuffers();
            IoCallback callback = this.callback;
            this.buffer = null;
            this.fileChannel = null;
            this.callback = null;
            inCallback = true;
            try {
                callback.onComplete(exchange, this);
            } finally {
                inCallback = false;
            }

            StreamSinkChannel channel = this.channel;
            if (this.buffer != null) {
                // onComplete queued more buffers: write them here rather than recursively.
                final long total = Buffers.remaining(buffer);
                long written = 0;
                try {
                    do {
                        long res = channel.write(buffer);
                        written += res;
                        if (res == 0) {
                            awaitWritable(channel);
                            return;
                        }
                    } while (written < total);
                    //we loop and invoke onComplete again
                } catch (IOException e) {
                    // The failure belongs to the data queued during onComplete, so notify that
                    // send's callback (this.callback), not the one that already completed, and
                    // stop: looping would report the failed write as completed.
                    invokeOnException(this.callback, e);
                    return;
                }
            } else if (this.fileChannel != null) {
                // onComplete queued a file transfer.
                if (transferTask == null) {
                    transferTask = new TransferTask();
                }
                if (!transferTask.run(false)) {
                    return;
                }
            } else {
                return;
            }
        }
    }

    /**
     * Releases any pooled buffers and reports {@code e} to {@code callback}.
     */
    private void invokeOnException(IoCallback callback, IOException e) {
        releasePooledBuffers();
        callback.onException(exchange, this, e);
    }

    /**
     * Closes the pooled buffers allocated for a string send, if any, and forgets them.
     */
    private void releasePooledBuffers() {
        if (pooledBuffers != null) {
            for (PooledByteBuffer pooled : pooledBuffers) {
                pooled.close();
            }
            pooledBuffers = null;
        }
    }

    /**
     * Installs the shared write listener on {@code target} (creating it on first use) and resumes
     * writes so the pending buffers are written once the channel is writable.
     */
    private void awaitWritable(final StreamSinkChannel target) {
        if (writeListener == null) {
            initWriteListener();
        }
        target.getWriteSetter().set(writeListener);
        target.resumeWrites();
    }

    /**
     * Creates the listener that drains {@link #buffer} whenever the channel becomes writable and
     * completes the send once everything has been written.
     */
    private void initWriteListener() {
        writeListener = new ChannelListener<StreamSinkChannel>() {
            @Override
            public void handleEvent(final StreamSinkChannel streamSinkChannel) {
                try {
                    long toWrite = Buffers.remaining(buffer);
                    long written = 0;
                    while (written < toWrite) {
                        long res = streamSinkChannel.write(buffer, 0, buffer.length);
                        written += res;
                        if (res == 0) {
                            // Still not writable; stay registered and wait for the next event.
                            return;
                        }
                    }
                    streamSinkChannel.suspendWrites();
                    invokeOnComplete();
                } catch (IOException e) {
                    streamSinkChannel.suspendWrites();
                    invokeOnException(callback, e);
                }
            }
        };
    }
}