/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.handler.proxy;

import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.PendingWriteQueue;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.SocketAddress;
import java.nio.channels.ConnectionPendingException;
import java.util.concurrent.TimeUnit;

/**
 * Base class for handlers that reach a destination address through a proxy server.
 * <p>
 * How it works:
 * <ul>
 *   <li>An outbound {@code connect(destination)} is redirected to {@link #proxyAddress()}. The requested
 *       address is remembered as {@link #destinationAddress()}.</li>
 *   <li>Once the channel to the proxy is active, the handler sends the message returned by
 *       {@link #newInitialMessage(ChannelHandlerContext)} (if any). It feeds every inbound message to
 *       {@link #handleResponse(ChannelHandlerContext, Object)} until that method reports completion.</li>
 *   <li>Until the handshake finishes, user writes are queued in a {@link PendingWriteQueue} and user
 *       flushes are deferred.</li>
 *   <li>On success: the proxy codecs are removed, a {@link ProxyConnectionEvent} is fired, queued writes
 *       are replayed, and {@link #connectFuture()} succeeds.</li>
 *   <li>On failure (error, timeout, write failure, or disconnect): queued writes are failed,
 *       {@link #connectFuture()} fails, the cause is fired through {@code exceptionCaught}, and the
 *       channel is closed.</li>
 * </ul>
 * Subclasses supply the protocol specifics via {@link #protocol()}, {@link #authScheme()},
 * {@link #addCodec(ChannelHandlerContext)}, {@link #removeEncoder(ChannelHandlerContext)},
 * {@link #removeDecoder(ChannelHandlerContext)}, {@link #newInitialMessage(ChannelHandlerContext)} and
 * {@link #handleResponse(ChannelHandlerContext, Object)}.
 * <p>
 * NOTE(review): only {@code destinationAddress}, {@code connectTimeoutMillis} and {@code ctx} are volatile.
 * The remaining mutable state is presumably only touched from the channel's event loop — confirm
 * before calling the non-final hooks from other threads.
 */
public abstract class ProxyHandler extends ChannelDuplexHandler {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);

    /**
     * The default connect timeout: 10 seconds.
     */
    private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;

    /**
     * A string that signifies 'no authentication' or 'anonymous'.
     */
    static final String AUTH_NONE = "none";

    // The address the channel actually connects to; fixed at construction.
    private final SocketAddress proxyAddress;
    // Set by the first connect() call. A non-null value makes any later connect() fail.
    private volatile SocketAddress destinationAddress;
    // Handshake timeout; 0 means no timeout is scheduled (see sendInitialMessage).
    private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;

    // This handler's context; assigned in handlerAdded().
    private volatile ChannelHandlerContext ctx;
    // User writes buffered until the handshake finishes; created lazily, reset to null once drained or failed.
    private PendingWriteQueue pendingWrites;
    // True once setConnectSuccess() or setConnectFailure() has run; afterwards events pass straight through.
    private boolean finished;
    // Set while handshake responses are being consumed so their channelReadComplete is not forwarded.
    private boolean suppressChannelReadComplete;
    // Records that flush() was requested before the handshake finished, so it can be replayed on success.
    private boolean flushedPrematurely;
    // Completed on handshake success/failure; its executor is resolved lazily from ctx (see LazyChannelPromise).
    private final LazyChannelPromise connectPromise = new LazyChannelPromise();
    // Pending timeout task, or null if none is scheduled (or it has been cancelled).
    private ScheduledFuture<?> connectTimeoutFuture;
    // Attached to every write made by sendToProxyServer(): a failed write fails the whole connect attempt.
    private final ChannelFutureListener writeListener = new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (!future.isSuccess()) {
                setConnectFailure(future.cause());
            }
        }
    };

    /**
     * Creates a handler that connects through the given proxy server.
     *
     * @param proxyAddress the address of the proxy server; must not be {@code null}
     * @throws NullPointerException if {@code proxyAddress} is {@code null}
     */
    protected ProxyHandler(SocketAddress proxyAddress) {
        if (proxyAddress == null) {
            throw new NullPointerException("proxyAddress");
        }
        this.proxyAddress = proxyAddress;
    }

    /**
     * Returns the name of the proxy protocol in use.
     */
    public abstract String protocol();

    /**
     * Returns the name of the authentication scheme in use.
     */
    public abstract String authScheme();

    /**
     * Returns the address of the proxy server.
     */
    @SuppressWarnings("unchecked")
    public final <T extends SocketAddress> T proxyAddress() {
        return (T) proxyAddress;
    }

    /**
     * Returns the address of the destination to connect to via the proxy server.
     * This is {@code null} until {@code connect()} has been called on the channel.
     */
    @SuppressWarnings("unchecked")
    public final <T extends SocketAddress> T destinationAddress() {
        return (T) destinationAddress;
    }

    /**
     * Returns {@code true} if and only if the connection to the destination has been established successfully.
     */
    public final boolean isConnected() {
        return connectPromise.isSuccess();
    }

    /**
     * Returns a {@link Future} that is notified when the connection to the destination has been established
     * or the connection attempt has failed.
     */
    public final Future<Channel> connectFuture() {
        return connectPromise;
    }

    /**
     * Returns the connect timeout in millis. If the connection attempt to the destination does not finish within
     * the timeout, the connection attempt will be failed.
     */
    public final long connectTimeoutMillis() {
        return connectTimeoutMillis;
    }

    /**
     * Sets the connect timeout in millis. If the connection attempt to the destination does not finish within
     * the timeout, the connection attempt will be failed.
     * <p>
     * Zero or negative values are stored as {@code 0}, which disables the timeout. The value is read when the
     * initial message is sent, so changing it afterwards does not affect an attempt already in progress.
     */
    public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
        if (connectTimeoutMillis <= 0) {
            connectTimeoutMillis = 0;
        }

        this.connectTimeoutMillis = connectTimeoutMillis;
    }

    /**
     * Stores the context and installs the proxy codecs. If the channel is already active, the handshake
     * starts immediately, because {@link #channelActive(ChannelHandlerContext)} will not be called for it.
     */
    @Override
    public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
        addCodec(ctx);

        if (ctx.channel().isActive()) {
            // channelActive() event has been fired already, which means this.channelActive() will
            // not be invoked. We have to initialize here instead.
            sendInitialMessage(ctx);
        } else {
            // channelActive() event has not been fired yet. this.channelActive() will be invoked
            // and initialization will occur there.
        }
    }

    /**
     * Adds the codec handlers required to communicate with the proxy server.
     */
    protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;

    /**
     * Removes the encoders added in {@link #addCodec(ChannelHandlerContext)}.
     */
    protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;

    /**
     * Removes the decoders added in {@link #addCodec(ChannelHandlerContext)}.
     */
    protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;

    /**
     * Records {@code remoteAddress} as the destination and connects to the proxy server instead.
     * Only one connect is accepted per handler instance: any later call fails its promise with a
     * {@link ConnectionPendingException}, because {@code destinationAddress} is never reset.
     */
    @Override
    public final void connect(
            ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
            ChannelPromise promise) throws Exception {

        if (destinationAddress != null) {
            promise.setFailure(new ConnectionPendingException());
            return;
        }

        destinationAddress = remoteAddress;
        ctx.connect(proxyAddress, localAddress, promise);
    }

    /**
     * Starts the proxy handshake, then propagates the active event.
     */
    @Override
    public final void channelActive(ChannelHandlerContext ctx) throws Exception {
        sendInitialMessage(ctx);
        ctx.fireChannelActive();
    }

    /**
     * Sends the initial message to be sent to the proxy server. This method also starts a timeout task which marks
     * the {@link #connectPromise} as failure if the connection attempt does not succeed within the timeout.
     */
    private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
        // Snapshot the volatile field so the check and the scheduling use the same value.
        final long connectTimeoutMillis = this.connectTimeoutMillis;
        if (connectTimeoutMillis > 0) {
            connectTimeoutFuture = ctx.executor().schedule(new Runnable() {
                @Override
                public void run() {
                    if (!connectPromise.isDone()) {
                        setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
                    }
                }
            }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
        }

        final Object initialMessage = newInitialMessage(ctx);
        if (initialMessage != null) {
            sendToProxyServer(initialMessage);
        }

        // With auto-read disabled, request a read explicitly so the proxy's response is received.
        readIfNeeded(ctx);
    }

    /**
     * Returns a new message that is sent at first time when the connection to the proxy server has been established.
     *
     * @return the initial message, or {@code null} if the proxy server is expected to send the first message instead
     */
    protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;

    /**
     * Sends the specified message to the proxy server. Use this method to send a response to the proxy server in
     * {@link #handleResponse(ChannelHandlerContext, Object)}.
     * <p>
     * The message is written and flushed through this handler's context. It therefore starts at the handler in
     * front of this one and is not queued by {@link #write(ChannelHandlerContext, Object, ChannelPromise)}.
     * A failed write fails the connection attempt.
     */
    protected final void sendToProxyServer(Object msg) {
        ctx.writeAndFlush(msg).addListener(writeListener);
    }

    /**
     * Forwards the inactive event once the attempt has finished. If the channel closes mid-handshake, the
     * event is not forwarded; the attempt is failed with a "disconnected" {@link ProxyConnectException}.
     */
    @Override
    public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (finished) {
            ctx.fireChannelInactive();
        } else {
            // Disconnected before connected to the destination.
            setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
        }
    }

    /**
     * Forwards exceptions once the attempt has finished. Before that, the exception fails the connect attempt.
     * setConnectFailure() then fires the (possibly wrapped) cause downstream itself.
     */
    @Override
    public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (finished) {
            ctx.fireExceptionCaught(cause);
        } else {
            // Exception was raised before the connection attempt is finished.
            setConnectFailure(cause);
        }
    }

    /**
     * Passes messages through after the attempt has finished. Before that, each message is handed to
     * {@link #handleResponse(ChannelHandlerContext, Object)} and always released afterwards. An implementation
     * that keeps a reference to the message past the call must retain it itself.
     */
    @Override
    public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (finished) {
            // Received a message after the connection has been established; pass through.
            suppressChannelReadComplete = false;
            ctx.fireChannelRead(msg);
        } else {
            suppressChannelReadComplete = true;
            Throwable cause = null;
            try {
                boolean done = handleResponse(ctx, msg);
                if (done) {
                    setConnectSuccess();
                }
            } catch (Throwable t) {
                cause = t;
            } finally {
                // Release first, then report any failure (which may close the channel).
                ReferenceCountUtil.release(msg);
                if (cause != null) {
                    setConnectFailure(cause);
                }
            }
        }
    }

    /**
     * Handles the message received from the proxy server.
     *
     * @return {@code true} if the connection to the destination has been established,
     *         {@code false} if the connection to the destination has not been established and more messages are
     *         expected from the proxy server
     */
    protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;

    /**
     * Completes the handshake successfully. Removes the proxy codecs and fires a {@link ProxyConnectionEvent}.
     * Then it replays buffered writes (and a deferred flush) and succeeds {@link #connectPromise}. If a codec
     * cannot be removed, the attempt is failed instead.
     */
    private void setConnectSuccess() {
        finished = true;
        cancelConnectTimeoutFuture();

        if (!connectPromise.isDone()) {
            boolean removedCodec = true;

            // NOTE(review): the encoder is removed before the event is fired and the decoder only after it.
            // The reason is not visible here; keep this order unless it is confirmed to be irrelevant.
            removedCodec &= safeRemoveEncoder();

            ctx.fireUserEventTriggered(
                    new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));

            removedCodec &= safeRemoveDecoder();

            if (removedCodec) {
                writePendingWrites();

                if (flushedPrematurely) {
                    ctx.flush();
                }
                connectPromise.trySuccess(ctx.channel());
            } else {
                // We are at inconsistent state because we failed to remove all codec handlers.
                Exception cause = new ProxyConnectException(
                        "failed to remove all codec handlers added by the proxy handler; bug?");
                failPendingWritesAndClose(cause);
            }
        }
    }

    /**
     * Calls {@link #removeDecoder(ChannelHandlerContext)}, logging instead of propagating any exception.
     *
     * @return {@code true} if removal completed without throwing
     */
    private boolean safeRemoveDecoder() {
        try {
            removeDecoder(ctx);
            return true;
        } catch (Exception e) {
            logger.warn("Failed to remove proxy decoders:", e);
        }

        return false;
    }

    /**
     * Calls {@link #removeEncoder(ChannelHandlerContext)}, logging instead of propagating any exception.
     *
     * @return {@code true} if removal completed without throwing
     */
    private boolean safeRemoveEncoder() {
        try {
            removeEncoder(ctx);
            return true;
        } catch (Exception e) {
            logger.warn("Failed to remove proxy encoders:", e);
        }

        return false;
    }

    /**
     * Fails the connection attempt, if it is not already complete. A cause that is not already a
     * {@link ProxyConnectException} is wrapped in one, with the proxy/destination details in the message.
     * Codecs are then removed on a best-effort basis, and pending writes, the promise and the channel are
     * failed/closed.
     */
    private void setConnectFailure(Throwable cause) {
        finished = true;
        cancelConnectTimeoutFuture();

        if (!connectPromise.isDone()) {

            if (!(cause instanceof ProxyConnectException)) {
                cause = new ProxyConnectException(
                        exceptionMessage(cause.toString()), cause);
            }

            // Results are ignored: the channel is about to be closed anyway.
            safeRemoveDecoder();
            safeRemoveEncoder();
            failPendingWritesAndClose(cause);
        }
    }

    /**
     * Fails buffered writes and {@link #connectPromise} with {@code cause}, then fires it through the
     * pipeline and closes the channel.
     */
    private void failPendingWritesAndClose(Throwable cause) {
        failPendingWrites(cause);
        connectPromise.tryFailure(cause);
        ctx.fireExceptionCaught(cause);
        ctx.close();
    }

    /**
     * Cancels the scheduled timeout task, if any, without interrupting it.
     */
    private void cancelConnectTimeoutFuture() {
        if (connectTimeoutFuture != null) {
            connectTimeoutFuture.cancel(false);
            connectTimeoutFuture = null;
        }
    }

    /**
     * Decorates the specified exception message with the common information such as the current protocol,
     * authentication scheme, proxy address, and destination address.
     * <p>
     * The result has the form {@code "<protocol>, <authScheme>, <proxyAddress> => <destinationAddress>"},
     * followed by {@code ", <msg>"} when {@code msg} is non-empty. A {@code null} {@code msg} is treated
     * as empty.
     */
    protected final String exceptionMessage(String msg) {
        if (msg == null) {
            msg = "";
        }

        StringBuilder buf = new StringBuilder(128 + msg.length())
            .append(protocol())
            .append(", ")
            .append(authScheme())
            .append(", ")
            .append(proxyAddress)
            .append(" => ")
            .append(destinationAddress);
        if (!msg.isEmpty()) {
            buf.append(", ").append(msg);
        }

        return buf.toString();
    }

    /**
     * Swallows the read-complete event that follows handshake responses; when auto-read is off, it requests
     * another read instead. Otherwise the event is forwarded.
     */
    @Override
    public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        if (suppressChannelReadComplete) {
            suppressChannelReadComplete = false;

            readIfNeeded(ctx);
        } else {
            ctx.fireChannelReadComplete();
        }
    }

    /**
     * Queues writes until the attempt has finished. Afterwards, any still-queued writes are written first,
     * followed by this one.
     */
    @Override
    public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        if (finished) {
            writePendingWrites();
            ctx.write(msg, promise);
        } else {
            addPendingWrite(ctx, msg, promise);
        }
    }

    /**
     * Defers flushes until the attempt has finished; setConnectSuccess() replays a deferred flush.
     */
    @Override
    public final void flush(ChannelHandlerContext ctx) throws Exception {
        if (finished) {
            writePendingWrites();
            ctx.flush();
        } else {
            flushedPrematurely = true;
        }
    }

    /**
     * Requests a read from the channel when auto-read is disabled.
     */
    private static void readIfNeeded(ChannelHandlerContext ctx) {
        if (!ctx.channel().config().isAutoRead()) {
            ctx.read();
        }
    }

    /**
     * Writes out all buffered writes (without flushing) and drops the queue.
     */
    private void writePendingWrites() {
        if (pendingWrites != null) {
            pendingWrites.removeAndWriteAll();
            pendingWrites = null;
        }
    }

    /**
     * Fails all buffered writes with {@code cause} and drops the queue.
     */
    private void failPendingWrites(Throwable cause) {
        if (pendingWrites != null) {
            pendingWrites.removeAndFailAll(cause);
            pendingWrites = null;
        }
    }

    /**
     * Buffers a write, creating the queue on first use.
     */
    private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
        PendingWriteQueue pendingWrites = this.pendingWrites;
        if (pendingWrites == null) {
            this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
        }
        pendingWrites.add(msg, promise);
    }

    /**
     * A promise whose executor is looked up from {@code ctx} each time it is needed. This lets the promise be
     * created before the handler is added to a pipeline. Using it before {@code handlerAdded} throws
     * {@link IllegalStateException}.
     */
    private final class LazyChannelPromise extends DefaultPromise<Channel> {
        @Override
        protected EventExecutor executor() {
            if (ctx == null) {
                throw new IllegalStateException();
            }
            return ctx.executor();
        }
    }
}