/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.activemq.util; import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocketFactory; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.net.SocketTimeoutException; import java.net.URI; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class SocketProxy { private static final transient Logger LOG = LoggerFactory.getLogger(SocketProxy.class); public static final int ACCEPT_TIMEOUT_MILLIS = 100; private URI proxyUrl; private URI target; private Acceptor acceptor; private ServerSocket serverSocket; private CountDownLatch closed = new CountDownLatch(1); public final List<Bridge> connections = new LinkedList<>(); private int listenPort = 0; private int receiveBufferSize = -1; private boolean pauseAtStart = false; private int acceptBacklog = 50; public SocketProxy() throws Exception { } public SocketProxy(URI uri) throws Exception { this(0, uri); } public SocketProxy(int port, URI uri) throws Exception { listenPort = port; target = uri; open(); } public void setReceiveBufferSize(int receiveBufferSize) { this.receiveBufferSize = receiveBufferSize; } public void setTarget(URI tcpBrokerUri) { target = tcpBrokerUri; } public void open() throws Exception { serverSocket = createServerSocket(target); serverSocket.setReuseAddress(true); if (receiveBufferSize > 0) { serverSocket.setReceiveBufferSize(receiveBufferSize); } if (proxyUrl == null) { serverSocket.bind(new InetSocketAddress(listenPort), acceptBacklog); proxyUrl = urlFromSocket(target, serverSocket); } else { serverSocket.bind(new InetSocketAddress(proxyUrl.getPort())); } acceptor = new Acceptor(serverSocket, target); if (pauseAtStart) { acceptor.pause(); } new Thread(null, acceptor, "SocketProxy-Acceptor-" + serverSocket.getLocalPort()).start(); closed = new CountDownLatch(1); } private boolean isSsl(URI target) { return "ssl".equals(target.getScheme()); } private ServerSocket createServerSocket(URI target) throws Exception { if (isSsl(target)) { return SSLServerSocketFactory.getDefault().createServerSocket(); } return new ServerSocket(); } private Socket createSocket(URI target) throws Exception { if (isSsl(target)) { return SSLSocketFactory.getDefault().createSocket(); } return new Socket(); } public URI getUrl() { return proxyUrl; } /* * close all proxy connections and acceptor */ public void close() { List<Bridge> connections; synchronized (this.connections) { connections = new ArrayList<>(this.connections); } LOG.info("close, numConnections=" + connections.size()); for (Bridge con : connections) { closeConnection(con); } acceptor.close(); closed.countDown(); } /* * close all proxy receive connections, leaving acceptor * open */ public void halfClose() { List<Bridge> connections; synchronized (this.connections) { connections = new ArrayList<>(this.connections); } LOG.info("halfClose, numConnections=" + connections.size()); for (Bridge con : connections) { halfCloseConnection(con); } } public boolean waitUntilClosed(long timeoutSeconds) throws InterruptedException { return closed.await(timeoutSeconds, TimeUnit.SECONDS); } /* * called after a close to restart the acceptor on the same port */ public void reopen() { LOG.info("reopen"); try { open(); } catch (Exception e) { LOG.debug("exception on reopen url:" + getUrl(), e); } } /* * pause accepting new connections and data transfer through existing proxy * connections. All sockets remain open */ public void pause() { synchronized (connections) { LOG.info("pause, numConnections=" + connections.size()); acceptor.pause(); for (Bridge con : connections) { con.pause(); } } } /* * continue after pause */ public void goOn() { synchronized (connections) { LOG.info("goOn, numConnections=" + connections.size()); for (Bridge con : connections) { con.goOn(); } } acceptor.goOn(); } private void closeConnection(Bridge c) { try { c.close(); } catch (Exception e) { LOG.debug("exception on close of: " + c, e); } } private void halfCloseConnection(Bridge c) { try { c.halfClose(); } catch (Exception e) { LOG.debug("exception on half close of: " + c, e); } } public boolean isPauseAtStart() { return pauseAtStart; } public void setPauseAtStart(boolean pauseAtStart) { this.pauseAtStart = pauseAtStart; } public int getAcceptBacklog() { return acceptBacklog; } public void setAcceptBacklog(int acceptBacklog) { this.acceptBacklog = acceptBacklog; } private URI urlFromSocket(URI uri, ServerSocket serverSocket) throws Exception { int listenPort = serverSocket.getLocalPort(); return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), listenPort, uri.getPath(), uri.getQuery(), uri.getFragment()); } public class Bridge { private Socket receiveSocket; private Socket sendSocket; private Pump requestThread; private Pump responseThread; public Bridge(Socket socket, URI target) throws Exception { receiveSocket = socket; sendSocket = createSocket(target); if (receiveBufferSize > 0) { sendSocket.setReceiveBufferSize(receiveBufferSize); } sendSocket.connect(new InetSocketAddress(target.getHost(), target.getPort())); linkWithThreads(receiveSocket, sendSocket); LOG.info("proxy connection " + sendSocket + ", receiveBufferSize=" + sendSocket.getReceiveBufferSize()); } public void goOn() { responseThread.goOn(); requestThread.goOn(); } public void pause() { requestThread.pause(); responseThread.pause(); } public void close() throws Exception { synchronized (connections) { connections.remove(this); } receiveSocket.close(); sendSocket.close(); } public void halfClose() throws Exception { receiveSocket.close(); } private void linkWithThreads(Socket source, Socket dest) { requestThread = new Pump(source, dest); requestThread.start(); responseThread = new Pump(dest, source); responseThread.start(); } public class Pump extends Thread { protected Socket src; private Socket destination; private AtomicReference<CountDownLatch> pause = new AtomicReference<>(); public Pump(Socket source, Socket dest) { super("SocketProxy-DataTransfer-" + source.getPort() + ":" + dest.getPort()); src = source; destination = dest; pause.set(new CountDownLatch(0)); } public void pause() { pause.set(new CountDownLatch(1)); } public void goOn() { pause.get().countDown(); } @Override public void run() { byte[] buf = new byte[1024]; try { InputStream in = src.getInputStream(); OutputStream out = destination.getOutputStream(); while (true) { int len = in.read(buf); if (len == -1) { LOG.debug("read eof from:" + src); break; } pause.get().await(); out.write(buf, 0, len); } } catch (Exception e) { LOG.debug("read/write failed, reason: " + e.getLocalizedMessage()); try { if (!receiveSocket.isClosed()) { // for halfClose, on read/write failure if we close the // remote end will see a close at the same time. close(); } } catch (Exception ignore) { } } } } } public class Acceptor implements Runnable { private ServerSocket socket; private URI target; private AtomicReference<CountDownLatch> pause = new AtomicReference<>(); public Acceptor(ServerSocket serverSocket, URI uri) { socket = serverSocket; target = uri; pause.set(new CountDownLatch(0)); try { socket.setSoTimeout(ACCEPT_TIMEOUT_MILLIS); } catch (SocketException e) { e.printStackTrace(); } } public void pause() { pause.set(new CountDownLatch(1)); } public void goOn() { pause.get().countDown(); } @Override public void run() { try { while (!socket.isClosed()) { pause.get().await(); try { Socket source = socket.accept(); pause.get().await(); if (receiveBufferSize > 0) { source.setReceiveBufferSize(receiveBufferSize); } LOG.info("accepted " + source + ", receiveBufferSize:" + source.getReceiveBufferSize()); synchronized (connections) { connections.add(new Bridge(source, target)); } } catch (SocketTimeoutException expected) { } } } catch (Exception e) { LOG.debug("acceptor: finished for reason: " + e.getLocalizedMessage()); } } public void close() { try { socket.close(); closed.countDown(); goOn(); } catch (IOException ignored) { } } } }