package org.arquillian.cube.kubernetes.impl.portforward;
import io.undertow.UndertowMessages;
import io.undertow.client.ClientCallback;
import io.undertow.client.ClientConnection;
import io.undertow.client.ClientExchange;
import io.undertow.client.ClientRequest;
import io.undertow.connector.ByteBufferPool;
import io.undertow.protocols.spdy.SpdyStreamStreamSinkChannel;
import io.undertow.server.AbstractServerConnection;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.HttpUpgradeListener;
import io.undertow.server.SSLSessionInfo;
import io.undertow.util.HttpString;
import io.undertow.util.Methods;
import io.undertow.util.StringReadChannelListener;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.CountDownLatch;
import org.xnio.ChainedChannelListener;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.StreamConnection;
import org.xnio.channels.CloseableChannel;
import org.xnio.conduits.StreamSinkConduit;
/**
 * PortForwardServerConnection
 * <p>
 * A raw (non-HTTP) server connection whose bytes are forwarded over port-forward requests sent on a
 * {@link ClientConnection}. Two POST requests are issued: an "error" stream, whose response body is
 * printed to {@code System.err} when non-blank, and a "data" stream, whose request channel is fed
 * from this connection and whose response channel is copied back into this connection.
 * HTTP exchanges, SSL and protocol upgrades are not supported.
 *
 * @author Rob Cernich
 */
public class PortForwardServerConnection extends AbstractServerConnection {

    /** Released once the error stream's response has been read, or its read/response has failed. */
    private final CountDownLatch errorComplete = new CountDownLatch(1);

    /**
     * Released when the local channel closes, when the remote data response channel closes,
     * or when the data response fails.
     */
    private final CountDownLatch requestComplete = new CountDownLatch(1);

    /**
     * Create a new PortForwardServerConnection.
     *
     * @param channel         the local connection whose bytes are forwarded
     * @param bufferPool      buffer pool handed to the superclass and used for data transfers
     * @param undertowOptions Undertow options passed through to the superclass
     * @param bufferSize      buffer size passed through to the superclass
     */
    public PortForwardServerConnection(StreamConnection channel, ByteBufferPool bufferPool, OptionMap undertowOptions,
        int bufferSize) {
        // No root handler: this connection never dispatches HTTP exchanges.
        super(channel, bufferPool, null, undertowOptions, bufferSize);
    }

    @Override
    public HttpServerExchange sendOutOfBandResponse(HttpServerExchange exchange) {
        throw new UnsupportedOperationException("PortForward connection does not support HTTP!");
    }

    @Override
    public void terminateRequestChannel(HttpServerExchange exchange) {
        throw new UnsupportedOperationException("PortForward connection does not support HTTP!");
    }

    @Override
    public SSLSessionInfo getSslSessionInfo() {
        // We're not supporting SSL
        return null;
    }

    @Override
    public void setSslSessionInfo(SSLSessionInfo sessionInfo) {
        throw new UnsupportedOperationException("PortForward connection does not support SSL!");
    }

    @Override
    protected StreamConnection upgradeChannel() {
        throw UndertowMessages.MESSAGES.upgradeNotSupported();
    }

    @Override
    protected StreamSinkConduit getSinkConduit(HttpServerExchange exchange, StreamSinkConduit conduit) {
        // No wrapping: bytes are written to the underlying conduit as-is.
        return conduit;
    }

    @Override
    protected boolean isUpgradeSupported() {
        return false;
    }

    @Override
    protected void exchangeComplete(HttpServerExchange exchange) {
        // We're not supporting HTTP so nothing to do here
    }

    @Override
    public String getTransportProtocol() {
        return "raw";
    }

    @Override
    protected boolean isConnectSupported() {
        return false;
    }

    @Override
    public boolean isContinueResponseSupported() {
        return false;
    }

    @Override
    protected void setConnectListener(HttpUpgradeListener connectListener) {
        // CONNECT is not supported, so the listener is ignored.
    }

    /**
     * Opens the error and data streams on {@code clientConnection} and blocks until forwarding finishes.
     * <p>
     * Returns once the data stream has completed and the error stream has been read, or earlier if
     * opening a stream fails or the calling thread is interrupted (the interrupt status is preserved).
     * Failures are printed to standard error rather than propagated. In every case this connection is
     * closed before returning.
     *
     * @param clientConnection connection on which the port-forward requests are sent
     * @param urlPath          request path used for both streams
     * @param targetPort       value sent in the {@code port} request header
     * @param requestId        value sent in the {@code requestID} request header
     * @throws IOException declared for API compatibility; exceptions are currently caught and printed
     */
    public void startForwarding(final ClientConnection clientConnection, final String urlPath, final int targetPort,
        final int requestId) throws IOException {
        try {
            // initiate the streams
            openErrorStream(clientConnection, urlPath, targetPort, requestId);
            openDataStream(clientConnection, urlPath, targetPort, requestId);
            /*
             * wait for the request to complete. this will trigger when the
             * client is done and the request stream closes.
             */
            requestComplete.await();
            /* wait for the response on the error stream. */
            errorComplete.await();
        } catch (InterruptedException e) {
            // Keep the interrupt visible to the caller instead of swallowing it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            IoUtils.safeClose(this);
        }
    }

    /**
     * Sends the "error" stream request. When its response arrives, the body is read and reported
     * through {@link #setError(String)}, which releases {@link #errorComplete}.
     *
     * @throws IOException if sending the request fails, or the wait for it is interrupted
     */
    private void openErrorStream(final ClientConnection clientConnection, final String urlPath, final int targetPort,
        final int requestId) throws IOException {
        final ClientRequest request = createStreamRequest(urlPath, "error", targetPort, requestId);
        final CountDownLatch latch = new CountDownLatch(1);
        final IOException[] holder = new IOException[1];
        clientConnection.sendRequest(request, new ClientCallback<ClientExchange>() {
            @Override
            public void failed(IOException e) {
                holder[0] = e;
                latch.countDown();
            }

            @Override
            public void completed(final ClientExchange result) {
                latch.countDown();
                result.setResponseListener(new ClientCallback<ClientExchange>() {
                    @Override
                    public void completed(final ClientExchange result) {
                        // read the error, if any
                        new StringReadChannelListener(getByteBufferPool()) {
                            @Override
                            protected void stringDone(String string) {
                                setError(string);
                            }

                            @Override
                            protected void error(IOException e) {
                                setError(e.getMessage());
                            }
                        }.setup(result.getResponseChannel());
                    }

                    @Override
                    public void failed(IOException e) {
                        setError(e.getMessage());
                    }
                });
            }
        });
        awaitRequestSent(latch, holder);
    }

    /**
     * Sends the "data" stream request and wires up the two transfer directions.
     * <p>
     * Bytes read from this connection are written to the request channel, and the response channel
     * is copied back into this connection. A daemon timer sends a SPDY ping every 15 seconds while the
     * stream is open. The timer is cancelled, and {@link #requestComplete} released, when this
     * connection's channel closes.
     *
     * @throws IOException if sending the request fails, or the wait for it is interrupted
     */
    private void openDataStream(final ClientConnection clientConnection, final String urlPath, final int targetPort,
        final int requestId) throws IOException {
        final ClientRequest request = createStreamRequest(urlPath, "data", targetPort, requestId);
        final CountDownLatch latch = new CountDownLatch(1);
        final IOException[] holder = new IOException[1];
        final Timer timer = new Timer("SPDY Keep Alive", true);
        getChannel().getCloseSetter()
            .set(new ChainedChannelListener<CloseableChannel>(
                new CancelTimerChannelListener(timer),
                new LatchReleaseChannelListener(requestComplete)));
        clientConnection.sendRequest(request, new ClientCallback<ClientExchange>() {
            @Override
            public void failed(IOException e) {
                holder[0] = e;
                latch.countDown();
            }

            @Override
            public void completed(final ClientExchange result) {
                latch.countDown();
                result.setResponseListener(new ClientCallback<ClientExchange>() {
                    @Override
                    public void completed(final ClientExchange result) {
                        result.getResponseChannel()
                            .getCloseSetter()
                            .set(new LatchReleaseChannelListener(requestComplete));
                        getIoThread().execute(new Runnable() {
                            @Override
                            public void run() {
                                // read from remote
                                ChannelUtils.initiateTransfer(
                                    Long.MAX_VALUE,
                                    result.getResponseChannel(),
                                    getChannel().getSinkChannel(),
                                    getBufferPool());
                            }
                        });
                    }

                    @Override
                    public void failed(IOException e) {
                        requestComplete.countDown();
                    }
                });
                // write to remote
                ChannelUtils.initiateTransfer(
                    Long.MAX_VALUE,
                    getChannel().getSourceChannel(),
                    result.getRequestChannel(),
                    getBufferPool());
                // keep alive; only SPDY streams can be pinged, so skip it rather than fail with a ClassCastException
                if (result.getRequestChannel() instanceof SpdyStreamStreamSinkChannel) {
                    timer.scheduleAtFixedRate(
                        new PingSpdyStream((SpdyStreamStreamSinkChannel) result.getRequestChannel()),
                        15000, 15000); // OSE times out after 30s
                }
                // need to wait for the client to close the request channel
                // NOTE(review): this blocks whichever thread delivers this callback (presumably the client
                // connection's I/O thread) until forwarding ends -- confirm it cannot stall the transfers.
                try {
                    requestComplete.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    e.printStackTrace();
                }
            }
        });
        awaitRequestSent(latch, holder);
    }

    /**
     * Builds a POST request for one port-forward stream.
     *
     * @param urlPath    request path
     * @param streamType value of the {@code streamType} header ("error" or "data")
     * @param targetPort value of the {@code port} header
     * @param requestId  value of the {@code requestID} header
     * @return the configured request
     */
    private static ClientRequest createStreamRequest(final String urlPath, final String streamType,
        final int targetPort, final int requestId) {
        final ClientRequest request = new ClientRequest()
            .setMethod(Methods.POST)
            .setPath(urlPath);
        request.getRequestHeaders()
            .put(new HttpString("streamType"), streamType)
            .put(new HttpString("port"), targetPort)
            .put(new HttpString("requestID"), requestId);
        return request;
    }

    /**
     * Waits until a sendRequest callback has fired, then rethrows any failure it recorded.
     *
     * @param latch   released by the callback's {@code completed} or {@code failed} method
     * @param failure one-element holder where {@code failed} stores its exception
     * @throws IOException the recorded failure, or {@link InterruptedIOException} if interrupted while
     *                     waiting (the thread's interrupt status is restored first)
     */
    private static void awaitRequestSent(final CountDownLatch latch, final IOException[] failure)
        throws IOException {
        try {
            // wait for the request to be sent
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            final InterruptedIOException interrupted =
                new InterruptedIOException("Interrupted while waiting for port-forward request to be sent");
            interrupted.initCause(e);
            throw interrupted;
        }
        if (failure[0] != null) {
            throw failure[0];
        }
    }

    /**
     * Prints a non-blank error message to standard error and releases {@link #errorComplete}.
     *
     * @param error message read from the error stream, or from a failure; may be {@code null}
     */
    private void setError(String error) {
        if (error != null && !error.trim().isEmpty()) {
            System.err.println("Port forwarding error: " + error);
        }
        errorComplete.countDown();
    }

    /** Close listener that cancels the keep-alive timer. */
    private static final class CancelTimerChannelListener implements ChannelListener<CloseableChannel> {
        private final Timer timer;

        private CancelTimerChannelListener(Timer timer) {
            this.timer = timer;
        }

        @Override
        public void handleEvent(CloseableChannel channel) {
            timer.cancel();
        }
    }

    /** Close listener that counts down the given latch. */
    private static final class LatchReleaseChannelListener implements ChannelListener<CloseableChannel> {
        private final CountDownLatch latch;

        private LatchReleaseChannelListener(CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void handleEvent(CloseableChannel channel) {
            latch.countDown();
        }
    }

    /**
     * Timer task that, on this connection's worker, sends a ping for the given SPDY stream while it
     * is still open. Non-static because it uses the enclosing connection's worker.
     */
    private final class PingSpdyStream extends TimerTask {
        private final SpdyStreamStreamSinkChannel stream;

        private PingSpdyStream(SpdyStreamStreamSinkChannel stream) {
            super();
            this.stream = stream;
        }

        @Override
        public void run() {
            getWorker().execute(new Runnable() {
                @Override
                public void run() {
                    if (stream.isOpen()) {
                        stream.getChannel().sendPing(stream.getStreamId());
                    }
                }
            });
        }
    }
}