/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.auto;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ServerSocketFactory;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.openwire.OpenWireFormatFactory;
import org.apache.activemq.transport.InactivityIOException;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.TransportServer;
import org.apache.activemq.transport.protocol.AmqpProtocolVerifier;
import org.apache.activemq.transport.protocol.MqttProtocolVerifier;
import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier;
import org.apache.activemq.transport.protocol.ProtocolVerifier;
import org.apache.activemq.transport.protocol.StompProtocolVerifier;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.transport.tcp.TcpTransportServer;
import org.apache.activemq.util.FactoryFinder;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.IntrospectionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.activemq.wireformat.WireFormatFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A TCP based {@link TransportServer} that detects the wire protocol of each accepted
 * connection from the first 8 bytes the client sends. It then creates a transport for
 * the detected protocol (OpenWire, AMQP, STOMP or MQTT, restricted to the enabled
 * protocols when a set of enabled protocols is given).
 */
public class AutoTcpTransportServer extends TcpTransportServer {

    private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class);

    /** Wire format options keyed by scheme; the {@code AutoTransportUtils.ALL} entry applies to every scheme. */
    protected Map<String, Map<String, Object>> wireFormatOptions;
    protected Map<String, Object> autoTransportOptions;
    /** Protocols to detect; {@code null} or empty means all supported protocols. */
    protected Set<String> enabledProtocols;
    protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>();

    protected BrokerService brokerService;

    protected final ThreadPoolExecutor newConnectionExecutor;
    protected final ThreadPoolExecutor protocolDetectionExecutor;
    protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE;
    /** Milliseconds to wait for the 8 detection bytes; a value {@code <= 0} waits indefinitely. */
    protected int protocolDetectionTimeOut = 15000;

    private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/");
    private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>();

    private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/");

    /**
     * Creates a new wire format factory for the given scheme and configures it.
     * Options stored under {@code AutoTransportUtils.ALL} are applied first, then
     * scheme-specific options, which override them. If the result is an
     * {@link OpenWireFormatFactory}, the OpenWire protocol verifier is replaced with one
     * backed by it.
     *
     * @param scheme the wire format scheme to look up
     * @param options per-scheme wire format options, may be {@code null}
     * @return a newly created and configured wire format factory
     * @throws IOException if the factory cannot be found, instantiated or configured
     */
    public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException {
        WireFormatFactory wff = null;
        try {
            wff = (WireFormatFactory) WIREFORMAT_FACTORY_FINDER.newInstance(scheme);
            if (options != null) {
                final Map<String, Object> wfOptions = new HashMap<>();
                if (options.get(AutoTransportUtils.ALL) != null) {
                    wfOptions.putAll(options.get(AutoTransportUtils.ALL));
                }
                if (options.get(scheme) != null) {
                    wfOptions.putAll(options.get(scheme));
                }
                IntrospectionSupport.setProperties(wff, wfOptions);
            }
            if (wff instanceof OpenWireFormatFactory) {
                protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff));
            }
            return wff;
        } catch (Throwable e) {
            throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e);
        }
    }

    /**
     * Looks up (and caches) the transport factory for a protocol scheme. The scheme is
     * suffixed with {@code +nio} and/or {@code +ssl} when the bind URI's scheme contains
     * them; an empty resulting scheme maps to {@code tcp}.
     *
     * @param scheme the detected protocol scheme, or an empty string for plain TCP
     * @param options transport options applied to a newly created factory, may be {@code null}
     * @return the cached or newly created transport factory
     * @throws IOException if no factory is registered for the resulting scheme
     */
    public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException {
        scheme = append(scheme, "nio");
        scheme = append(scheme, "ssl");

        if (scheme.isEmpty()) {
            scheme = "tcp";
        }

        TransportFactory tf = transportFactories.get(scheme);
        if (tf == null) {
            // Try to load it from a META-INF property.
            try {
                tf = (TransportFactory) TRANSPORT_FACTORY_FINDER.newInstance(scheme);
                if (options != null) {
                    IntrospectionSupport.setProperties(tf, options);
                }
                // Another connection may have raced us here; keep the first cached instance
                // so every caller shares the same factory.
                TransportFactory existing = transportFactories.putIfAbsent(scheme, tf);
                if (existing != null) {
                    tf = existing;
                }
            } catch (Throwable e) {
                throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e);
            }
        }

        return tf;
    }

    /**
     * Appends {@code scheme} to {@code currentScheme}, using a {@code +} separator when
     * needed, if the bind location's scheme contains {@code scheme}.
     */
    protected String append(String currentScheme, String scheme) {
        if (this.getBindLocation().getScheme().contains(scheme)) {
            if (!currentScheme.isEmpty()) {
                currentScheme += "+";
            }
            currentScheme += scheme;
        }
        return currentScheme;
    }

    /**
     * @param transportFactory the factory used when no protocol-specific factory applies
     * @param location the bind location
     * @param serverSocketFactory factory for the listening socket
     * @param brokerService broker handed to {@link BrokerServiceAware} transport factories
     * @param enabledProtocols protocols to detect; {@code null} or empty enables all
     * @throws IOException
     * @throws URISyntaxException
     */
    public AutoTcpTransportServer(TcpTransportFactory transportFactory,
            URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService,
            Set<String> enabledProtocols)
            throws IOException, URISyntaxException {
        super(transportFactory, location, serverSocketFactory);

        // Handle new connections on an executor. Its size is bounded by
        // maxConnectionThreadPoolSize (Integer.MAX_VALUE unless changed through
        // setMaxConnectionThreadPoolSize).
        newConnectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
                maxConnectionThreadPoolSize,
                30L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>());
        // Allow the thread pool to shrink when the threads are idle so that
        // it can grow and shrink as load changes.
        newConnectionExecutor.allowCoreThreadTimeOut(true);

        // Executor for the tasks that wait for the bytes used to detect the protocol.
        protocolDetectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
                maxConnectionThreadPoolSize,
                30L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>());
        // Allow the thread pool to shrink when the threads are idle.
        protocolDetectionExecutor.allowCoreThreadTimeOut(true);

        this.brokerService = brokerService;
        this.enabledProtocols = enabledProtocols;
        initProtocolVerifiers();
    }

    public int getMaxConnectionThreadPoolSize() {
        return maxConnectionThreadPoolSize;
    }

    /**
     * Sets the number of threads used for processing connections. Defaults to
     * {@link Integer#MAX_VALUE}. Set this value lower to reduce the number of
     * simultaneous connection attempts. If it is not set, the number of threads will
     * generally be limited by the transport maxConnections setting:
     * {@link TcpTransportServer#setMaximumConnections(int)}.
     * <p>
     * This setter controls two thread pools. Each connection attempt needs one thread to
     * start processing the connection and another thread to read from the socket and
     * detect the protocol. Two threads are needed because some transports block on socket
     * reads, so the first thread must be able to abort the second one on timeout. Each pool
     * is therefore set to the given size, which allows up to twice as many threads as the
     * value set.
     * <p>
     * Both thread pools close idle threads after a period of time, so they can grow and
     * shrink dynamically with load.
     *
     * @see TcpTransportServer#setMaximumConnections(int)
     * @param maxConnectionThreadPoolSize the size of each pool, must be at least 1
     * @throws IllegalArgumentException if {@code maxConnectionThreadPoolSize < 1}
     */
    public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) {
        if (maxConnectionThreadPoolSize < 1) {
            throw new IllegalArgumentException("maxConnectionThreadPoolSize must be >= 1 but was: "
                    + maxConnectionThreadPoolSize);
        }
        resizePool(newConnectionExecutor, maxConnectionThreadPoolSize);
        resizePool(protocolDetectionExecutor, maxConnectionThreadPoolSize);
        this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize;
    }

    /**
     * Sets both the core and maximum size of {@code executor} to {@code size}. The update
     * order depends on the direction of the change, because ThreadPoolExecutor rejects a
     * core size larger than the current maximum (and vice versa).
     */
    private static void resizePool(ThreadPoolExecutor executor, int size) {
        if (size > executor.getMaximumPoolSize()) {
            executor.setMaximumPoolSize(size);
            executor.setCorePoolSize(size);
        } else {
            executor.setCorePoolSize(size);
            executor.setMaximumPoolSize(size);
        }
    }

    /**
     * Sets how long, in milliseconds, to wait for a client to send the bytes needed for
     * protocol detection. A value {@code <= 0} waits indefinitely.
     */
    public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) {
        this.protocolDetectionTimeOut = protocolDetectionTimeOut;
    }

    @Override
    public void setWireFormatFactory(WireFormatFactory factory) {
        super.setWireFormatFactory(factory);
        initOpenWireProtocolVerifier();
    }

    /** Registers a protocol verifier for every enabled protocol. */
    protected void initProtocolVerifiers() {
        initOpenWireProtocolVerifier();

        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) {
            protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier());
        }
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) {
            protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier());
        }
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.MQTT)) {
            protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier());
        }
    }

    /**
     * Registers the OpenWire verifier when OpenWire is enabled. It is backed by the
     * configured wire format factory if that is an {@link OpenWireFormatFactory}, and by a
     * default {@link OpenWireFormatFactory} otherwise.
     */
    protected void initOpenWireProtocolVerifier() {
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) {
            OpenWireProtocolVerifier owpv;
            if (wireFormatFactory instanceof OpenWireFormatFactory) {
                owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory);
            } else {
                owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory());
            }
            protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv);
        }
    }

    /** @return {@code true} when no protocol restriction is configured */
    protected boolean isAllProtocols() {
        return enabledProtocols == null || enabledProtocols.isEmpty();
    }

    @Override
    protected void handleSocket(final Socket socket) {
        final AutoTcpTransportServer server = this;
        // This must run on a separate thread because the socket might still be
        // waiting for the client to send bytes, and doHandleSocket can't complete
        // until the protocol has been detected.
        newConnectionExecutor.submit(new Runnable() {
            @Override
            public void run() {
                server.doHandleSocket(socket);
            }
        });
    }

    /**
     * Reads up to the first 8 bytes from the socket on the protocol detection executor,
     * detects the protocol from them and builds the matching wire format and transport.
     * The bytes consumed here are handed to the transport through an {@link InitBuffer}
     * (presumably so the transport can replay them — see {@link TcpTransport}).
     */
    @Override
    protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception {
        final InputStream is = socket.getInputStream();
        final AtomicInteger readBytes = new AtomicInteger(0);
        final ByteBuffer data = ByteBuffer.allocate(8);

        // We need to peek at the first 8 bytes of the stream to detect the protocol.
        Future<?> future = protocolDetectionExecutor.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    do {
                        // Blocks until a byte is read, or until a timeout
                        // causes the socket to be closed.
                        int read = is.read();
                        if (read == -1) {
                            throw new IOException("Connection failed, stream is closed.");
                        }
                        data.put((byte) read);
                        readBytes.incrementAndGet();
                    } while (readBytes.get() < 8 && !Thread.interrupted());
                } catch (Exception e) {
                    throw new IllegalStateException(e);
                }
            }
        });

        try {
            // If this throws, the caller closes the socket.
            waitForProtocolDetectionFinish(future, readBytes);
        } finally {
            // Cancel in case the task didn't complete.
            future.cancel(true);
        }

        ProtocolInfo protocolInfo = detectProtocol(data.array());

        // The detection task may stop early (e.g. when its thread is interrupted), so copy
        // only the bytes that were actually read. Copying the whole 8-byte array into a
        // smaller buffer would overflow.
        final int bytesRead = readBytes.get();
        InitBuffer initBuffer = new InitBuffer(bytesRead, ByteBuffer.allocate(bytesRead));
        initBuffer.buffer.put(data.array(), 0, bytesRead);

        if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) {
            ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService);
        }

        WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat();
        Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer);

        return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
    }

    /**
     * Waits for the protocol detection task. The wait is bounded by
     * {@link #protocolDetectionTimeOut} when it is positive and unbounded otherwise.
     *
     * @throws InactivityIOException if the timeout elapses before the task completes
     * @throws Exception any failure from {@link Future#get()}, including the
     *         {@code ExecutionException} wrapping a read failure
     */
    protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
        try {
            // Wait for protocolDetectionTimeOut if defined.
            if (protocolDetectionTimeOut > 0) {
                future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS);
            } else {
                future.get();
            }
        } catch (TimeoutException e) {
            throw new InactivityIOException("Client timed out before wire format could be detected. " +
                    " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
        }
    }

    /**
     * Creates the transport for an accepted socket.
     *
     * @param socket the accepted socket
     * @param format the wire format created for the detected protocol
     * @param detectedTransportFactory the factory for the detected protocol (unused here)
     * @param initBuffer the bytes already consumed during protocol detection
     * @return a new {@link TcpTransport}
     */
    protected TcpTransport createTransport(Socket socket, WireFormat format,
            TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException {
        return new TcpTransport(format, socket, initBuffer);
    }

    public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) {
        this.wireFormatOptions = wireFormatOptions;
    }

    /**
     * Replaces the set of enabled protocols. This does not re-register the protocol
     * verifiers, which are set up by the constructor.
     */
    public void setEnabledProtocols(Set<String> enabledProtocols) {
        this.enabledProtocols = enabledProtocols;
    }

    /**
     * Stores the {@code auto.*} options. If a {@code protocols} option is present, it also
     * replaces the enabled protocols.
     *
     * @param autoTransportOptions the options map, may be {@code null}
     */
    public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) {
        this.autoTransportOptions = autoTransportOptions;
        if (autoTransportOptions != null && autoTransportOptions.get("protocols") != null) {
            this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols"));
        }
    }

    /**
     * Shuts down both executors, waiting up to 3 seconds for each, and then stops the
     * underlying TCP server. If the wait is interrupted, the shutdown still completes and
     * the thread's interrupt status is restored afterwards.
     */
    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        boolean interrupted = false;
        try {
            interrupted |= shutdownExecutor(newConnectionExecutor, "newConnectionExecutor");
            interrupted |= shutdownExecutor(protocolDetectionExecutor, "protocolDetectionExecutor");
            super.doStop(stopper);
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Calls {@code shutdownNow} on the executor and waits up to 3 seconds for it to
     * terminate, logging a warning if it does not.
     *
     * @return {@code true} if the wait was interrupted
     */
    private static boolean shutdownExecutor(ThreadPoolExecutor executor, String name) {
        if (executor == null) {
            return false;
        }
        executor.shutdownNow();
        try {
            if (!executor.awaitTermination(3, TimeUnit.SECONDS)) {
                LOG.warn("Auto Transport {} didn't shutdown cleanly", name);
            }
            return false;
        } catch (InterruptedException e) {
            // The caller restores the interrupt flag once the remaining shutdown has run.
            return true;
        }
    }

    /**
     * Finds the first registered verifier that accepts {@code buffer}. It then resolves the
     * wire format factory and transport factory for that scheme; the scheme
     * {@code "default"} maps to the plain transport.
     *
     * @param buffer the bytes read from the client
     * @return the detected transport and wire format factories
     * @throws IllegalStateException if no verifier accepts the buffer
     * @throws IOException if a factory for the detected scheme cannot be created
     */
    protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException {
        for (Map.Entry<String, ProtocolVerifier> entry : protocolVerifiers.entrySet()) {
            if (entry.getValue().isProtocol(buffer)) {
                String scheme = entry.getKey();
                LOG.debug("Detected protocol {}", scheme);
                WireFormatFactory detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions);
                String transportScheme = "default".equals(scheme) ? "" : scheme;
                TcpTransportFactory detectedTransportFactory =
                        (TcpTransportFactory) findTransportFactory(transportScheme, transportOptions);
                return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory);
            }
        }
        throw new IllegalStateException("Could not detect the wire format");
    }

    /** The transport factory and wire format factory selected for a detected protocol. */
    protected class ProtocolInfo {
        public final TcpTransportFactory detectedTransportFactory;
        public final WireFormatFactory detectedWireFormatFactory;

        public ProtocolInfo(TcpTransportFactory detectedTransportFactory,
                WireFormatFactory detectedWireFormatFactory) {
            super();
            this.detectedTransportFactory = detectedTransportFactory;
            this.detectedWireFormatFactory = detectedWireFormatFactory;
        }
    }
}