package org.arquillian.cube.kubernetes.impl.portforward; import java.io.EOFException; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channel; import org.xnio.ChannelExceptionHandler; import org.xnio.ChannelListener; import org.xnio.ChannelListeners; import org.xnio.Pool; import org.xnio.Pooled; import org.xnio.channels.Channels; import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSourceChannel; import static org.xnio._private.Messages.msg; public final class ChannelUtils { private ChannelUtils() { } public static <I extends StreamSourceChannel, O extends StreamSinkChannel> void initiateTransfer(long count, final I source, final O sink, Pool<ByteBuffer> pool) { ChannelUtils.initiateTransfer( count, source, sink, ChannelListeners.closingChannelListener(), ChannelListeners.<StreamSinkChannel>writeShutdownChannelListener( ChannelListeners.closingChannelListener(), ChannelListeners.closingChannelExceptionHandler()), ChannelListeners.closingChannelExceptionHandler(), ChannelListeners.closingChannelExceptionHandler(), pool); } /** * This is basically a copy of ChannelListeners.initiateTransfer(), but invokes a flush() on the sink after writing is * complete. * <p> * Initiate a low-copy transfer between two stream channels. The pool should be a direct buffer pool for best * performance. * * @param count * the number of bytes to transfer, or {@link Long#MAX_VALUE} to transfer all remaining bytes * @param source * the source channel * @param sink * the target channel * @param sourceListener * the source listener to set and call when the transfer is complete, or {@code null} to clear the listener at * that time * @param sinkListener * the target listener to set and call when the transfer is complete, or {@code null} to clear the listener at * that time * @param readExceptionHandler * the read exception handler to call if an error occurs during a read operation * @param writeExceptionHandler * the write exception handler to call if an error occurs during a write operation * @param pool * the pool from which the transfer buffer should be allocated */ public static <I extends StreamSourceChannel, O extends StreamSinkChannel> void initiateTransfer(long count, final I source, final O sink, final ChannelListener<? super I> sourceListener, final ChannelListener<? super O> sinkListener, final ChannelExceptionHandler<? super I> readExceptionHandler, final ChannelExceptionHandler<? super O> writeExceptionHandler, Pool<ByteBuffer> pool) { if (pool == null) { throw msg.nullParameter("pool"); } final Pooled<ByteBuffer> allocated = pool.allocate(); boolean free = true; try { final ByteBuffer buffer = allocated.getResource(); long transferred; for (; ; ) { try { transferred = source.transferTo(count, buffer, sink); } catch (IOException e) { ChannelListeners.invokeChannelExceptionHandler(source, readExceptionHandler, e); return; } if (transferred == 0 && !buffer.hasRemaining()) { break; } if (transferred == -1) { if (count == Long.MAX_VALUE) { Channels.setReadListener(source, sourceListener); if (sourceListener == null) { source.suspendReads(); } else { source.wakeupReads(); } Channels.setWriteListener(sink, sinkListener); if (sinkListener == null) { sink.suspendWrites(); } else { sink.wakeupWrites(); } } else { source.suspendReads(); sink.suspendWrites(); ChannelListeners.invokeChannelExceptionHandler(source, readExceptionHandler, new EOFException()); } return; } if (count != Long.MAX_VALUE) { count -= transferred; } while (buffer.hasRemaining()) { final int res; try { res = sink.write(buffer); } catch (IOException e) { ChannelListeners.invokeChannelExceptionHandler(sink, writeExceptionHandler, e); return; } if (res == 0) { // write first listener final TransferListener<I, O> listener = new TransferListener<I, O>(count, allocated, source, sink, sourceListener, sinkListener, writeExceptionHandler, readExceptionHandler, 1); source.suspendReads(); source.getReadSetter().set(listener); // flush the write channel sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(listener, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { listener.writeFailed(exception); } })); sink.resumeWrites(); free = false; return; } else if (count != Long.MAX_VALUE) { count -= res; } } if (count == 0) { //we are done Channels.setReadListener(source, sourceListener); if (sourceListener == null) { source.suspendReads(); } else { source.wakeupReads(); } Channels.setWriteListener(sink, sinkListener); if (sinkListener == null) { sink.suspendWrites(); } else { sink.wakeupWrites(); } return; } } // flush the write channel try { // TODO: think about checking to see if a flush has already been issued sink.isReadyForFlush() if (!sink.flush()) { final TransferListener<I, O> listener = new TransferListener<I, O>(count, allocated, source, sink, sourceListener, sinkListener, writeExceptionHandler, readExceptionHandler, 1); source.suspendReads(); source.getReadSetter().set(listener); // flush the write channel sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(listener, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { listener.writeFailed(exception); } })); sink.resumeWrites(); free = false; return; } } catch (IOException e) { ChannelListeners.invokeChannelExceptionHandler(sink, writeExceptionHandler, e); return; } // read first listener final TransferListener<I, O> listener = new TransferListener<I, O>(count, allocated, source, sink, sourceListener, sinkListener, writeExceptionHandler, readExceptionHandler, 0); sink.suspendWrites(); sink.getWriteSetter().set(listener); source.getReadSetter().set(listener); source.resumeReads(); free = false; return; } finally { if (free) allocated.free(); } } static final class TransferListener<I extends StreamSourceChannel, O extends StreamSinkChannel> implements ChannelListener<Channel> { private final Pooled<ByteBuffer> pooledBuffer; private final I source; private final O sink; private final ChannelListener<? super I> sourceListener; private final ChannelListener<? super O> sinkListener; private final ChannelExceptionHandler<? super O> writeExceptionHandler; private final ChannelExceptionHandler<? super I> readExceptionHandler; private long count; private volatile int state; TransferListener(final long count, final Pooled<ByteBuffer> pooledBuffer, final I source, final O sink, final ChannelListener<? super I> sourceListener, final ChannelListener<? super O> sinkListener, final ChannelExceptionHandler<? super O> writeExceptionHandler, final ChannelExceptionHandler<? super I> readExceptionHandler, final int state) { this.count = count; this.pooledBuffer = pooledBuffer; this.source = source; this.sink = sink; this.sourceListener = sourceListener; this.sinkListener = sinkListener; this.writeExceptionHandler = writeExceptionHandler; this.readExceptionHandler = readExceptionHandler; this.state = state; } public void handleEvent(final Channel channel) { final ByteBuffer buffer = pooledBuffer.getResource(); int state = this.state; // always read after and write before state long count = this.count; long lres; int ires; switch (state) { case 0: { // read listener for (; ; ) { boolean needsFlush = false; try { lres = source.transferTo(count, buffer, sink); } catch (IOException e) { readFailed(e); return; } if (lres == 0 && !buffer.hasRemaining()) { this.count = count; source.resumeReads(); return; } if (lres == -1) { // possibly unexpected EOF if (count == Long.MAX_VALUE) { // it's OK; just be done done(); return; } else { readFailed(new EOFException()); return; } } if (count != Long.MAX_VALUE) { count -= lres; } needsFlush = needsFlush || lres > 0; while (buffer.hasRemaining()) { try { ires = sink.write(buffer); } catch (IOException e) { writeFailed(e); return; } if (count != Long.MAX_VALUE) { count -= ires; } needsFlush = needsFlush || ires > 0; if (ires == 0) { // flush the write channel try { if (needsFlush && !sink.flush()) { sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(this, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { writeFailed(exception); } })); } } catch (IOException e) { writeFailed(e); return; } this.count = count; this.state = 1; source.suspendReads(); sink.resumeWrites(); return; } } if (count == 0) { done(); return; } // flush the write channel try { if (needsFlush && !sink.flush()) { sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(this, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { writeFailed(exception); } })); this.count = count; this.state = 1; source.suspendReads(); sink.resumeWrites(); return; } } catch (IOException e) { writeFailed(e); return; } } } case 1: { // write listener boolean needsFlush = false; for (; ; ) { while (buffer.hasRemaining()) { try { ires = sink.write(buffer); } catch (IOException e) { writeFailed(e); return; } if (count != Long.MAX_VALUE) { count -= ires; } needsFlush = needsFlush || ires > 0; if (ires == 0) { // flush the write channel try { if (!sink.flush()) { sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(this, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { writeFailed(exception); } })); } } catch (IOException e) { writeFailed(e); } source.suspendReads(); sink.resumeWrites(); return; } } try { lres = source.transferTo(count, buffer, sink); } catch (IOException e) { readFailed(e); return; } needsFlush = needsFlush || lres > 0; if (lres == 0 && !buffer.hasRemaining()) { // flush the write channel try { if (needsFlush && !sink.flush()) { sink.getWriteSetter() .set(ChannelListeners.flushingChannelListener(this, new ChannelExceptionHandler<Channel>() { @Override public void handleException(Channel channel, IOException exception) { writeFailed(exception); } })); this.count = count; this.state = 1; source.suspendReads(); sink.resumeWrites(); return; } } catch (IOException e) { writeFailed(e); return; } // need more data this.count = count; this.state = 0; sink.suspendWrites(); source.resumeReads(); return; } if (lres == -1) { // possibly unexpected EOF if (count == Long.MAX_VALUE) { // it's OK; just be done done(); return; } else { readFailed(new EOFException()); return; } } if (count != Long.MAX_VALUE) { count -= lres; } if (count == 0) { done(); return; } } } } } private void writeFailed(final IOException e) { try { source.suspendReads(); sink.suspendWrites(); ChannelListeners.invokeChannelExceptionHandler(sink, writeExceptionHandler, e); } finally { pooledBuffer.free(); } } private void readFailed(final IOException e) { try { source.suspendReads(); sink.suspendWrites(); ChannelListeners.invokeChannelExceptionHandler(source, readExceptionHandler, e); } finally { pooledBuffer.free(); } } private void done() { try { final ChannelListener<? super I> sourceListener = this.sourceListener; final ChannelListener<? super O> sinkListener = this.sinkListener; final I source = this.source; final O sink = this.sink; Channels.setReadListener(source, sourceListener); if (sourceListener == null) { source.suspendReads(); } else { source.wakeupReads(); } Channels.setWriteListener(sink, sinkListener); if (sinkListener == null) { sink.suspendWrites(); } else { sink.wakeupWrites(); } } finally { pooledBuffer.free(); } } public String toString() { return "Transfer channel listener (" + source + " to " + sink + ") -> (" + sourceListener + " and " + sinkListener + ")"; } } }