/* * Copyright 2013 Stanley Shyiko * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.shyiko.mysql.binlog; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.Collections; import java.util.IdentityHashMap; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; /** * @author <a href="mailto:stanley.shyiko@gmail.com">Stanley Shyiko</a> */ public class TCPReverseProxy { private final Logger logger = Logger.getLogger(getClass().getSimpleName()); private final int port; private final String targetHost; private final int targetPort; @SuppressWarnings("unchecked") private Set<Closeable> closeable = Collections.synchronizedSet(Collections.newSetFromMap(new IdentityHashMap())); private volatile Selector selector; private volatile CountDownLatch latch; public TCPReverseProxy(int port, int targetPort) { this(port, "localhost", targetPort); } public TCPReverseProxy(int port, String targetHost, int targetPort) { this.port = port; this.targetHost = targetHost; this.targetPort = targetPort; resetInternalState(); } public int getPort() { return port; } public String getTargetHost() { return targetHost; } public int getTargetPort() { return targetPort; } public void bind() throws IOException { if (!closeable.isEmpty()) { throw new IllegalStateException(); } ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.socket().bind(new InetSocketAddress(port)); closeable.add(serverSocketChannel); selector = Selector.open(); serverSocketChannel.configureBlocking(false).register(selector, SelectionKey.OP_ACCEPT); ByteBuffer buffer = ByteBuffer.allocate(1024); if (logger.isLoggable(Level.FINEST)) { logger.finest("Listening on port " + port); } latch.countDown(); selector.select(); for (; selector.isOpen(); selector.select()) { Set<SelectionKey> keys = selector.selectedKeys(); for (SelectionKey key : keys) { if (!key.isValid()) { continue; } if (key.isAcceptable()) { ServerSocketChannel channel = (ServerSocketChannel) key.channel(); SocketChannel clientSocketChannel = channel.accept(); closeable.add(clientSocketChannel); InetSocketAddress remoteAddress = new InetSocketAddress(targetHost, targetPort); SocketChannel remoteSocketChannel = SocketChannel.open(remoteAddress); closeable.add(remoteSocketChannel); if (logger.isLoggable(Level.FINEST)) { logger.finest("Established new connection " + System.identityHashCode(remoteSocketChannel) + " to " + targetHost + ":" + targetPort); } clientSocketChannel.configureBlocking(false). register(selector, SelectionKey.OP_READ, remoteSocketChannel); remoteSocketChannel.configureBlocking(false). register(selector, SelectionKey.OP_READ, clientSocketChannel); } else if (key.isReadable()) { SocketChannel channel = (SocketChannel) key.channel(); SocketChannel targetChannel = (SocketChannel) key.attachment(); try { int x = channel.read(buffer); if (x == -1) { if (logger.isLoggable(Level.FINEST)) { logger.finest("Closed connection " + System.identityHashCode(targetChannel)); } closeQuietly(targetChannel, channel); continue; } buffer.flip(); targetChannel.write(buffer); buffer.rewind(); } catch (IOException e) { closeQuietly(targetChannel, channel); } } } keys.clear(); } } public void await(long timeout, TimeUnit unit) throws InterruptedException { latch.await(timeout, unit); } public void unbind() throws IOException { try { selector.close(); } catch (IOException e) { e.printStackTrace(); } closeQuietly(closeable.toArray(new Closeable[closeable.size()])); if (logger.isLoggable(Level.FINEST)) { logger.finest("Released port " + port); } resetInternalState(); } private void closeQuietly(Closeable... arrayOfCloseable) { for (Closeable closeable : arrayOfCloseable) { try { closeable.close(); } catch (IOException e) { e.printStackTrace(); } this.closeable.remove(closeable); } } private void resetInternalState() { latch = new CountDownLatch(1); } }