package org.infinispan.server.core.transport;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import javax.management.AttributeNotFoundException;
import javax.management.InstanceNotFoundException;
import javax.management.MBeanException;
import javax.management.MBeanServer;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;
import javax.management.ReflectionException;
import org.infinispan.Cache;
import org.infinispan.commons.CacheException;
import org.infinispan.commons.logging.LogFactory;
import org.infinispan.commons.util.Util;
import org.infinispan.configuration.global.GlobalConfiguration;
import org.infinispan.distexec.DefaultExecutorService;
import org.infinispan.distexec.DistributedCallable;
import org.infinispan.distexec.DistributedExecutorService;
import org.infinispan.jmx.JmxUtil;
import org.infinispan.manager.EmbeddedCacheManager;
import org.infinispan.server.core.configuration.ProtocolServerConfiguration;
import org.infinispan.server.core.logging.Log;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Log4J2LoggerFactory;
/**
 * A Netty based transport.
 * <p>
 * Binds one server channel on the configured address. It uses a single-threaded "master" event loop group as the
 * bootstrap's parent group and a "worker" event loop group for child channels. The native epoll implementations are
 * used when requested, and NIO otherwise. It also exposes transport statistics (bytes read/written, local and
 * cluster-wide connection counts).
 *
 * @author Galder Zamarreño
 * @author wburns
 * @since 4.1
 */
public class NettyTransport implements Transport {
   private static final Log log = LogFactory.getLog(NettyTransport.class, Log.class);

   /**
    * Whether {@code org.apache.logging.log4j.Logger} can be loaded from the thread context class loader. When true,
    * {@link #start()} routes Netty's internal logging to log4j2.
    */
   private static final boolean isLog4jAvailable;

   static {
      boolean available;
      try {
         Util.loadClassStrict("org.apache.logging.log4j.Logger", Thread.currentThread().getContextClassLoader());
         available = true;
      } catch (ClassNotFoundException e) {
         available = false;
      }
      isLog4jAvailable = available;
   }

   /** Per-connection pipeline initializer installed on accepted channels; set via {@link #initializeHandler}. */
   private ChannelInitializer<Channel> handler;
   private final InetSocketAddress address;
   private final ProtocolServerConfiguration configuration;
   private final String threadNamePrefix;
   private final EmbeddedCacheManager cacheManager;
   /** Holds the bound server (listening) channel so it can be closed on {@link #stop()}. */
   private final ChannelGroup serverChannels;
   /**
    * Accepted client channels; its size is reported as the local connection count. It is not populated in this
    * class (package-private, presumably filled by the channel handlers — verify against callers).
    */
   final ChannelGroup acceptedChannels;
   private final boolean useNativeEpoll;
   private final EventLoopGroup masterGroup;
   private final EventLoopGroup workerGroup;
   private final AtomicLong totalBytesWritten = new AtomicLong();
   private final AtomicLong totalBytesRead = new AtomicLong();
   /** Snapshot of the cache manager's global JMX statistics flag; byte counters are only updated when true. */
   private final boolean isGlobalStatsEnabled;
   /** Port actually bound by {@link #start()}; empty before start and after {@link #stop()}. */
   private Optional<Integer> nettyPort = Optional.empty();

   /**
    * Creates the transport and its event loop groups. The server channel is not bound until {@link #start()}.
    *
    * @param address          address to bind the server channel to
    * @param configuration    protocol server configuration (TCP options, buffer sizes, ...)
    * @param threadNamePrefix prefix for event loop thread names and channel group names
    * @param cacheManager     cache manager used for JMX settings, cluster membership and distributed counting
    * @param useNativeEpoll   {@code true} to use Netty's native epoll transport instead of NIO
    */
   public NettyTransport(InetSocketAddress address, ProtocolServerConfiguration configuration, String threadNamePrefix,
                         EmbeddedCacheManager cacheManager, boolean useNativeEpoll) {
      this.address = address;
      this.configuration = configuration;
      this.threadNamePrefix = threadNamePrefix;
      this.cacheManager = cacheManager;
      this.useNativeEpoll = useNativeEpoll;
      // Need to initialize these in constructor since they require configuration
      masterGroup = buildEventLoop(1, new DefaultThreadFactory(threadNamePrefix + "-ServerMaster"));
      // 0 threads lets Netty choose its default worker thread count
      workerGroup = buildEventLoop(0, new DefaultThreadFactory(threadNamePrefix + "-ServerWorker"));
      isGlobalStatsEnabled = cacheManager.getCacheManagerConfiguration().globalJmxStatistics().enabled();
      serverChannels = new DefaultChannelGroup(threadNamePrefix + "-Channels", ImmediateEventExecutor.INSTANCE);
      acceptedChannels = new DefaultChannelGroup(threadNamePrefix + "-Accepted", ImmediateEventExecutor.INSTANCE);
   }

   /**
    * Sets the child channel initializer used for accepted connections. Must be called before {@link #start()}.
    *
    * @param handler the initializer installed on each accepted channel
    */
   public void initializeHandler(ChannelInitializer<Channel> handler) {
      this.handler = handler;
   }

   /**
    * Binds the server channel on the configured address and blocks until the bind completes.
    *
    * @throws CacheException if the calling thread is interrupted while waiting for the bind; the thread's interrupt
    *                        status is restored before throwing
    */
   @Override
   public void start() {
      // Make netty use log4j, otherwise it goes to JDK logging.
      if (isLog4jAvailable) {
         InternalLoggerFactory.setDefaultFactory(Log4J2LoggerFactory.INSTANCE);
      }
      ServerBootstrap bootstrap = new ServerBootstrap();
      bootstrap.group(masterGroup, workerGroup);
      bootstrap.channel(getServerSocketChannel());
      bootstrap.childHandler(handler);
      bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
      bootstrap.childOption(ChannelOption.TCP_NODELAY, configuration.tcpNoDelay()); // Sets server side tcpNoDelay
      if (configuration.sendBufSize() > 0) {
         bootstrap.childOption(ChannelOption.SO_SNDBUF, configuration.sendBufSize()); // Sets server side send buffer
      }
      if (configuration.recvBufSize() > 0) {
         bootstrap.childOption(ChannelOption.SO_RCVBUF, configuration.recvBufSize()); // Sets server side receive buffer
      }
      Channel ch;
      try {
         ch = bootstrap.bind(address).sync().channel();
      } catch (InterruptedException e) {
         // Restore the interrupt flag so code further up the stack can still observe the interruption
         Thread.currentThread().interrupt();
         throw new CacheException(e);
      }
      // Record the port actually bound (relevant when the configured port is 0, i.e. ephemeral)
      nettyPort = Optional.of(((InetSocketAddress) ch.localAddress()).getPort());
      serverChannels.add(ch);
   }

   /**
    * Shuts down both event loop groups (100 ms quiet period, 1000 ms timeout each) and waits for them. It then closes
    * the server and accepted channel groups, logging any channel that is still active if a close did not succeed.
    * Afterwards {@link #getPort()} falls back to the configured port.
    */
   @Override
   public void stop() {
      Future<?> masterTerminationFuture = masterGroup.shutdownGracefully(100, 1000, TimeUnit.MILLISECONDS);
      Future<?> workerTerminationFuture = workerGroup.shutdownGracefully(100, 1000, TimeUnit.MILLISECONDS);
      masterTerminationFuture.awaitUninterruptibly();
      workerTerminationFuture.awaitUninterruptibly();
      // This is probably not necessary, all Netty resources should have been freed already
      ChannelGroupFuture serverChannelsTerminationFuture = serverChannels.close();
      ChannelGroupFuture acceptedChannelsTerminationFuture = acceptedChannels.close();
      ChannelGroupFuture future = serverChannelsTerminationFuture.awaitUninterruptibly();
      if (!future.isSuccess()) {
         log.serverDidNotUnbind();
         future.forEach(fut -> {
            Channel ch = fut.channel();
            if (ch.isActive()) {
               log.channelStillBound(ch, ch.remoteAddress());
            }
         });
      }
      future = acceptedChannelsTerminationFuture.awaitUninterruptibly();
      if (!future.isSuccess()) {
         log.serverDidNotClose();
         future.forEach(fut -> {
            Channel ch = fut.channel();
            if (ch.isActive()) {
               log.channelStillConnected(ch, ch.remoteAddress());
            }
         });
      }
      if (log.isDebugEnabled()) {
         log.debug("Channel group completely closed, external resources released");
      }
      nettyPort = Optional.empty();
   }

   /** @return total bytes written, as a decimal string (only accumulated when global JMX statistics are enabled) */
   @Override
   public String getTotalBytesWritten() {
      return totalBytesWritten.toString();
   }

   /** @return total bytes read, as a decimal string (only accumulated when global JMX statistics are enabled) */
   @Override
   public String getTotalBytesRead() {
      return totalBytesRead.toString();
   }

   /** @return the host name of the configured bind address */
   @Override
   public String getHostName() {
      return address.getHostName();
   }

   /** @return the port bound by {@link #start()} if running, otherwise the configured port */
   @Override
   public Integer getPort() {
      return nettyPort.orElse(address.getPort());
   }

   /** @return the configured number of worker threads, as a string */
   @Override
   public String getNumberWorkerThreads() {
      return Integer.toString(configuration.workerThreads());
   }

   /** @return the configured idle timeout, as a string */
   @Override
   public String getIdleTimeout() {
      return Integer.toString(configuration.idleTimeout());
   }

   /** @return the configured TCP_NODELAY flag, as a string */
   @Override
   public String getTcpNoDelay() {
      return Boolean.toString(configuration.tcpNoDelay());
   }

   /** @return the configured send buffer size, as a string */
   @Override
   public String getSendBufferSize() {
      return Integer.toString(configuration.sendBufSize());
   }

   /** @return the configured receive buffer size, as a string */
   @Override
   public String getReceiveBufferSize() {
      return Integer.toString(configuration.recvBufSize());
   }

   /** @return the number of channels currently in the accepted-channels group on this node */
   @Override
   public Integer getNumberOfLocalConnections() {
      return Integer.valueOf(acceptedChannels.size());
   }

   /**
    * Returns the cluster-wide connection count. When the cache manager's transport has more than one member, every
    * node's local count is summed via a distributed task; otherwise this is the local count.
    *
    * @return number of connections across the cluster
    */
   @Override
   public Integer getNumberOfGlobalConnections() {
      if (needDistributedCalculation()) {
         return calculateGlobalConnections();
      } else {
         return getNumberOfLocalConnections();
      }
   }

   /**
    * Adds {@code bytes} to the written-bytes counter if global JMX statistics are enabled.
    *
    * @param bytes number of bytes written
    */
   public void updateTotalBytesWritten(int bytes) {
      if (isGlobalStatsEnabled) {
         incrementTotalBytes(totalBytesWritten, bytes);
      }
   }

   /**
    * Adds {@code bytes} to the read-bytes counter if global JMX statistics are enabled.
    *
    * @param bytes number of bytes read
    */
   public void updateTotalBytesRead(int bytes) {
      if (isGlobalStatsEnabled) {
         incrementTotalBytes(totalBytesRead, bytes);
      }
   }

   /** Atomically adds {@code bytes} to {@code base}; callers are responsible for the statistics-enabled check. */
   private void incrementTotalBytes(AtomicLong base, int bytes) {
      base.addAndGet(bytes);
   }

   /** @return {@code true} if the cache manager has a transport with more than one cluster member */
   private boolean needDistributedCalculation() {
      org.infinispan.remoting.transport.Transport transport = cacheManager.getTransport();
      return transport != null && transport.getMembers().size() > 1;
   }

   /** Selects the server channel class matching the configured event loop implementation (epoll or NIO). */
   private Class<? extends ServerChannel> getServerSocketChannel() {
      Class<? extends ServerChannel> channel = useNativeEpoll ? EpollServerSocketChannel.class : NioServerSocketChannel.class;
      log.createdSocketChannel(channel.getName(), configuration.toString());
      return channel;
   }

   /**
    * Creates an epoll or NIO event loop group, depending on {@code useNativeEpoll}.
    *
    * @param nThreads      number of threads (0 lets Netty choose its default)
    * @param threadFactory factory naming the group's threads
    * @return the new event loop group
    */
   private EventLoopGroup buildEventLoop(int nThreads, DefaultThreadFactory threadFactory) {
      EventLoopGroup eventLoop = useNativeEpoll ? new EpollEventLoopGroup(nThreads, threadFactory) :
            new NioEventLoopGroup(nThreads, threadFactory);
      log.createdNettyEventLoop(eventLoop.getClass().getName(), configuration.toString());
      return eventLoop;
   }

   /**
    * Submits a {@link ConnectionAdderTask} to every node through a distributed executor over the cache manager's
    * default cache. Waits up to 30 seconds for each node's result and returns their sum.
    *
    * @return the summed connection count
    * @throws CacheException if a node's task fails, times out, or the wait is interrupted (in which case the thread's
    *                        interrupt status is restored)
    */
   private int calculateGlobalConnections() {
      Cache<Object, Object> cache = cacheManager.getCache();
      DistributedExecutorService exec = new DefaultExecutorService(cache);
      try {
         // Submit calculation task
         List<CompletableFuture<Integer>> results = exec.submitEverywhere(
               new ConnectionAdderTask(threadNamePrefix));
         // Take all results and add them up
         return results.stream().mapToInt(f -> {
            try {
               return f.get(30, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
               // Preserve the interruption for the caller before converting to an unchecked exception
               Thread.currentThread().interrupt();
               throw new CacheException(e);
            } catch (ExecutionException | TimeoutException e) {
               throw new CacheException(e);
            }
         }).sum();
      } finally {
         exec.shutdown();
      }
   }

   /**
    * Distributed task that reads the {@code NumberOfLocalConnections} attribute of the
    * {@code <jmxDomain>:type=Server,component=Transport,name=<serverName>} MBean on the node where it runs.
    */
   static class ConnectionAdderTask implements Serializable, DistributedCallable<Object, Object, Integer> {
      private final String serverName;
      /** Injected on the executing node via {@link #setEnvironment}; Cache is not serializable, so never shipped. */
      transient Cache<Object, Object> cache;

      ConnectionAdderTask(String serverName) {
         this.serverName = serverName;
      }

      @Override
      public void setEnvironment(Cache<Object, Object> cache, Set<Object> inputKeys) {
         this.cache = cache;
      }

      /**
       * Looks up this node's MBean server from the cache manager's global configuration and reads the transport
       * MBean's local connection count.
       *
       * @return the local connection count reported by the MBean
       * @throws RuntimeException wrapping any JMX lookup failure
       */
      @Override
      public Integer call() throws Exception {
         GlobalConfiguration globalCfg = cache.getCacheManager().getCacheManagerConfiguration();
         String jmxDomain = globalCfg.globalJmxStatistics().domain();
         MBeanServer mbeanServer = JmxUtil.lookupMBeanServer(globalCfg);
         try {
            ObjectName transportMBeanName = new ObjectName(
                  jmxDomain + ":type=Server,component=Transport,name=" + serverName);
            return (Integer) mbeanServer.getAttribute(transportMBeanName, "NumberOfLocalConnections");
         } catch (MBeanException | AttributeNotFoundException | InstanceNotFoundException | ReflectionException |
               MalformedObjectNameException e) {
            throw new RuntimeException(e);
         }
      }
   }
}