/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.zookeeper.server; import static org.jboss.netty.buffer.ChannelBuffers.dynamicBuffer; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.HashMap; import java.util.HashSet; import java.util.Set; import java.util.concurrent.Executors; import org.apache.zookeeper.Login; import org.apache.zookeeper.server.auth.SaslServerCallbackHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandler.Sharable; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelHandler; import org.jboss.netty.channel.WriteCompletionEvent; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; import 
javax.security.auth.login.Configuration;
import javax.security.auth.login.LoginException;

/**
 * A {@link ServerCnxnFactory} that accepts client connections through a Netty
 * {@link ServerBootstrap} backed by an {@link NioServerSocketChannelFactory}.
 *
 * <p>Every accepted channel is wrapped in a {@link NettyServerCnxn}, stored as
 * the channel handler context's attachment, and recorded in {@link #cnxns} and
 * {@link #ipMap}. Accesses to {@link #cnxns} in this class synchronize on the
 * set itself.
 */
public class NettyServerCnxnFactory extends ServerCnxnFactory {
    static final Logger LOG = LoggerFactory.getLogger(NettyServerCnxnFactory.class);

    ServerBootstrap bootstrap;

    /** Listening channel returned by {@code bootstrap.bind}; null until {@link #start()}. */
    Channel parentChannel;

    /** Open client channels, closed together in {@link #shutdown()}. */
    ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns");

    /** Live connections; guard every access with {@code synchronized (cnxns)}. */
    HashSet<ServerCnxn> cnxns = new HashSet<ServerCnxn>();

    /**
     * Connections grouped by remote address. Entries are added in
     * {@link #addCnxn}; nothing in this class removes them
     * (NOTE(review): presumably done when a connection closes - verify).
     */
    HashMap<InetAddress, Set<NettyServerCnxn>> ipMap =
        new HashMap<InetAddress, Set<NettyServerCnxn>>();

    InetSocketAddress localAddress;

    /**
     * Per-host connection limit. NOTE(review): stored by {@link #configure} and
     * the setter, but not enforced anywhere in this class.
     */
    int maxClientCnxns = 60;

    /**
     * This is an inner class since we need to extend SimpleChannelHandler, but
     * NettyServerCnxnFactory already extends ServerCnxnFactory. By making it
     * inner this class gets access to the member variables and methods.
     */
    @Sharable
    class CnxnChannelHandler extends SimpleChannelHandler {

        @Override
        public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception
        {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel closed " + e);
            }
            allChannels.remove(ctx.getChannel());
        }

        /** Wraps the newly connected channel in a NettyServerCnxn and registers it. */
        @Override
        public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception
        {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel connected " + e);
            }
            allChannels.add(ctx.getChannel());
            NettyServerCnxn cnxn = new NettyServerCnxn(ctx.getChannel(),
                    zkServer, NettyServerCnxnFactory.this);
            // Later events on this channel look the connection up via the attachment.
            ctx.setAttachment(cnxn);
            addCnxn(cnxn);
        }

        @Override
        public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception
        {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel disconnected " + e);
            }
            NettyServerCnxn cnxn = (NettyServerCnxn) ctx.getAttachment();
            if (cnxn != null) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Channel disconnect caused close " + e);
                }
                cnxn.close();
            }
        }

        /** Logs the failure and closes the affected connection, if one is attached. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
            throws Exception
        {
            LOG.warn("Exception caught " + e, e.getCause());
            NettyServerCnxn cnxn = (NettyServerCnxn) ctx.getAttachment();
            if (cnxn != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Closing " + cnxn);
                }
                // Close unconditionally; previously this call sat inside the
                // debug guard, so the connection was only closed when debug
                // logging happened to be enabled.
                cnxn.close();
            }
        }

        @Override
        public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception
        {
            if (LOG.isTraceEnabled()) {
                LOG.trace("message received called " + e.getMessage());
            }
            try {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("New message " + e.toString()
                            + " from " + ctx.getChannel());
                }
                NettyServerCnxn cnxn = (NettyServerCnxn) ctx.getAttachment();
                // Serialize processing per connection: processMessage reads and
                // writes cnxn.queuedBuffer.
                synchronized (cnxn) {
                    processMessage(e, cnxn);
                }
            } catch (Exception ex) {
                LOG.error("Unexpected exception in receive", ex);
                throw ex;
            }
        }

        /**
         * Handles one inbound event for {@code cnxn}.
         *
         * <p>A {@link NettyServerCnxn.ResumeMessageEvent} feeds any queued bytes
         * to the connection and re-enables reads on its channel. Any other event
         * carries a {@link ChannelBuffer}: while the connection is throttled the
         * bytes are appended to {@code cnxn.queuedBuffer}; otherwise they are
         * passed to {@code cnxn.receiveMessage}, and bytes it leaves unread are
         * kept in {@code cnxn.queuedBuffer} for the next event.
         */
        private void processMessage(MessageEvent e, NettyServerCnxn cnxn) {
            if (LOG.isDebugEnabled()) {
                LOG.debug(Long.toHexString(cnxn.sessionId) + " queuedBuffer: "
                        + cnxn.queuedBuffer);
            }

            if (e instanceof NettyServerCnxn.ResumeMessageEvent) {
                LOG.debug("Received ResumeMessageEvent");
                if (cnxn.queuedBuffer != null) {
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("processing queue "
                                + Long.toHexString(cnxn.sessionId)
                                + " queuedBuffer 0x"
                                + ChannelBuffers.hexDump(cnxn.queuedBuffer));
                    }
                    drainQueuedBuffer(cnxn);
                } else {
                    LOG.debug("queue empty");
                }
                cnxn.channel.setReadable(true);
                return;
            }

            ChannelBuffer buf = (ChannelBuffer) e.getMessage();
            if (LOG.isTraceEnabled()) {
                LOG.trace(Long.toHexString(cnxn.sessionId)
                        + " buf 0x"
                        + ChannelBuffers.hexDump(buf));
            }

            if (cnxn.throttled) {
                LOG.debug("Received message while throttled");
                // we are throttled, so we need to queue
                if (cnxn.queuedBuffer == null) {
                    LOG.debug("allocating queue");
                    cnxn.queuedBuffer = dynamicBuffer(buf.readableBytes());
                }
                cnxn.queuedBuffer.writeBytes(buf);
                // Guarded so the hex dump of the whole queue is only built when
                // it will actually be logged.
                if (LOG.isDebugEnabled()) {
                    LOG.debug(Long.toHexString(cnxn.sessionId)
                            + " queuedBuffer 0x"
                            + ChannelBuffers.hexDump(cnxn.queuedBuffer));
                }
                return;
            }

            LOG.debug("not throttled");
            if (cnxn.queuedBuffer != null) {
                // Append behind earlier leftovers so previously queued bytes
                // are consumed before the new ones.
                if (LOG.isTraceEnabled()) {
                    LOG.trace(Long.toHexString(cnxn.sessionId)
                            + " queuedBuffer 0x"
                            + ChannelBuffers.hexDump(cnxn.queuedBuffer));
                }
                cnxn.queuedBuffer.writeBytes(buf);
                if (LOG.isTraceEnabled()) {
                    LOG.trace(Long.toHexString(cnxn.sessionId)
                            + " queuedBuffer 0x"
                            + ChannelBuffers.hexDump(cnxn.queuedBuffer));
                }
                drainQueuedBuffer(cnxn);
            } else {
                cnxn.receiveMessage(buf);
                if (buf.readable()) {
                    // receiveMessage left bytes unread; copy them into
                    // queuedBuffer so the next event can pick them up.
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Before copy " + buf);
                    }
                    cnxn.queuedBuffer = dynamicBuffer(buf.readableBytes());
                    cnxn.queuedBuffer.writeBytes(buf);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Copy is " + cnxn.queuedBuffer);
                        LOG.trace(Long.toHexString(cnxn.sessionId)
                                + " queuedBuffer 0x"
                                + ChannelBuffers.hexDump(cnxn.queuedBuffer));
                    }
                }
            }
        }

        /**
         * Passes {@code cnxn.queuedBuffer} to {@code cnxn.receiveMessage} and
         * drops the buffer once nothing readable remains in it.
         */
        private void drainQueuedBuffer(NettyServerCnxn cnxn) {
            cnxn.receiveMessage(cnxn.queuedBuffer);
            if (!cnxn.queuedBuffer.readable()) {
                LOG.debug("Processed queue - no bytes remaining");
                cnxn.queuedBuffer = null;
            } else {
                LOG.debug("Processed queue - bytes remaining");
            }
        }

        @Override
        public void writeComplete(ChannelHandlerContext ctx, WriteCompletionEvent e)
            throws Exception
        {
            if (LOG.isTraceEnabled()) {
                LOG.trace("write complete " + e);
            }
        }
    }

    CnxnChannelHandler channelHandler = new CnxnChannelHandler();

    /** Creates the bootstrap and installs the shared channel handler. */
    NettyServerCnxnFactory() {
        bootstrap = new ServerBootstrap(
                new NioServerSocketChannelFactory(
                        Executors.newCachedThreadPool(),
                        Executors.newCachedThreadPool()));
        // parent channel
        bootstrap.setOption("reuseAddress", true);
        // child channels
        bootstrap.setOption("child.tcpNoDelay", true);
        bootstrap.setOption("child.soLinger", 2);

        bootstrap.getPipeline().addLast("servercnxnfactory", channelHandler);
    }

    /** Closes every tracked connection, logging (not propagating) per-connection failures. */
    @Override
    public void closeAll() {
        if (LOG.isDebugEnabled()) {
            LOG.debug("closeAll()");
        }
        synchronized (cnxns) {
            // got to clear all the connections that we have in the selector
            // (iterate over a copy: closing a connection may modify cnxns)
            for (NettyServerCnxn cnxn : cnxns.toArray(new NettyServerCnxn[cnxns.size()])) {
                try {
                    cnxn.close();
                } catch (Exception e) {
                    LOG.warn("Ignoring exception closing cnxn sessionid 0x"
                            + Long.toHexString(cnxn.getSessionId()), e);
                }
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("allChannels size:" + allChannels.size() + " cnxns size:"
                    + cnxns.size());
        }
    }

    /** Closes the first tracked connection whose session id equals {@code sessionId}. */
    @Override
    public void closeSession(long sessionId) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("closeSession sessionid:0x" + Long.toHexString(sessionId));
        }
        synchronized (cnxns) {
            for (NettyServerCnxn cnxn : cnxns.toArray(new NettyServerCnxn[cnxns.size()])) {
                if (cnxn.getSessionId() == sessionId) {
                    try {
                        cnxn.close();
                    } catch (Exception e) {
                        LOG.warn("exception during session close", e);
                    }
                    break;
                }
            }
        }
    }

    /**
     * Records the listen address and per-host limit, and sets up the SASL
     * server login when {@code java.security.auth.login.config} is set.
     *
     * @throws IOException if the SASL login fails; the {@link LoginException}
     *         is attached as the cause
     */
    @Override
    public void configure(InetSocketAddress addr, int maxClientCnxns) throws IOException
    {
        if (System.getProperty("java.security.auth.login.config") != null) {
            try {
                saslServerCallbackHandler =
                    new SaslServerCallbackHandler(Configuration.getConfiguration());
                login = new Login("Server", saslServerCallbackHandler);
                login.startThreadIfNeeded();
            } catch (LoginException e) {
                throw new IOException("Could not configure server because SASL configuration did not allow the "
                        + " Zookeeper server to authenticate itself properly: " + e, e);
            }
        }
        localAddress = addr;
        this.maxClientCnxns = maxClientCnxns;
    }

    /** {@inheritDoc} */
    public int getMaxClientCnxnsPerHost() {
        return maxClientCnxns;
    }

    /** {@inheritDoc} */
    public void setMaxClientCnxnsPerHost(int max) {
        maxClientCnxns = max;
    }

    /** Returns the port of the configured listen address. */
    @Override
    public int getLocalPort() {
        return localAddress.getPort();
    }

    /** Set to true by {@link #shutdown()}; guarded by {@code this}. */
    boolean killed;

    /** Blocks until {@link #shutdown()} has completed. */
    @Override
    public void join() throws InterruptedException {
        synchronized (this) {
            while (!killed) {
                wait();
            }
        }
    }

    /**
     * Stops the SASL login thread, closes the listening channel and all client
     * connections, shuts the ZooKeeper server down and wakes {@link #join()}.
     */
    @Override
    public void shutdown() {
        LOG.info("shutdown called " + localAddress);
        if (login != null) {
            login.shutdown();
        }
        // null if factory never started
        if (parentChannel != null) {
            parentChannel.close().awaitUninterruptibly();
            closeAll();
            allChannels.close().awaitUninterruptibly();
            bootstrap.releaseExternalResources();
        }

        if (zkServer != null) {
            zkServer.shutdown();
        }
        synchronized (this) {
            killed = true;
            notifyAll();
        }
    }

    /** Binds the listening channel to the configured address. */
    @Override
    public void start() {
        LOG.info("binding to port " + localAddress);
        parentChannel = bootstrap.bind(localAddress);
    }

    /**
     * Starts listening and then brings up {@code zks}.
     *
     * <p>The server is registered immediately after binding: connections
     * accepted from that point on capture {@code zkServer} in
     * {@code channelConnected}, so registering it only after
     * {@code startdata()}/{@code startup()} would hand them a null server.
     */
    @Override
    public void startup(ZooKeeperServer zks) throws IOException,
            InterruptedException {
        start();
        setZooKeeperServer(zks);
        zks.startdata();
        zks.startup();
    }

    /**
     * Returns a snapshot of the current connections, taken under the
     * {@code cnxns} lock so callers can iterate while connections come and go.
     */
    @Override
    public Iterable<ServerCnxn> getConnections() {
        synchronized (cnxns) {
            return new HashSet<ServerCnxn>(cnxns);
        }
    }

    @Override
    public InetSocketAddress getLocalAddress() {
        return localAddress;
    }

    /** Registers {@code cnxn} in {@link #cnxns} and under its remote address in {@link #ipMap}. */
    private void addCnxn(NettyServerCnxn cnxn) {
        synchronized (cnxns) {
            cnxns.add(cnxn);
            synchronized (ipMap) {
                InetAddress addr =
                    ((InetSocketAddress) cnxn.channel.getRemoteAddress()).getAddress();
                Set<NettyServerCnxn> s = ipMap.get(addr);
                if (s == null) {
                    s = new HashSet<NettyServerCnxn>();
                    ipMap.put(addr, s);
                }
                s.add(cnxn);
            }
        }
    }
}