package com.linkedin.databus2.test.container; /* * * Copyright 2013 LinkedIn Corp. All rights reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. * */ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.ByteOrder; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.buffer.DirectChannelBufferFactory; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFactory; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipelineFactory; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.ChildChannelStateEvent; import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.local.DefaultLocalServerChannelFactory; import org.jboss.netty.channel.local.LocalAddress; import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; import org.jboss.netty.channel.socket.oio.OioServerSocketChannelFactory; import org.testng.Assert; /** Simple network server for unit tests. One has to specify the channel pipeline to use. */ public class SimpleTestServerConnection { public enum ServerType { NIO, LOCAL, OIO } private Channel _channel; private Thread _thread; private final ServerBootstrap _srvBootstrap; private final Lock _lock = new ReentrantLock(true); private boolean _shutdownRequested; private boolean _shutdown; private boolean _started; private final Condition _startedCondition = _lock.newCondition(); private final Condition _shutdownReqCondition = _lock.newCondition(); private final Condition _shutdownCondition = _lock.newCondition(); private Channel _lastConnChannel; private final Map<SocketAddress, Channel> _childrenChannels = new ConcurrentHashMap<SocketAddress, Channel>(); private final ServerType _serverType; public SimpleTestServerConnection(ByteOrder bufferByteOrder) { this(bufferByteOrder, ServerType.LOCAL); } public SimpleTestServerConnection(ByteOrder bufferByteOrder, ServerType serverType) { ChannelFactory channelFactory; _serverType = serverType; switch (serverType) { case LOCAL: channelFactory = new DefaultLocalServerChannelFactory(); break; case NIO: { channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()); break; } case OIO: { channelFactory = new OioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()); break; } default: throw new RuntimeException("unsupported server type: " + serverType ); } _srvBootstrap = new ServerBootstrap(channelFactory); _srvBootstrap.setOption("child.bufferFactory", DirectChannelBufferFactory.getInstance(bufferByteOrder)); _srvBootstrap.setParentHandler(new ChildChannelTracker()); } public void setPipelineFactory(ChannelPipelineFactory pipelineFactory) { _srvBootstrap.setPipelineFactory(pipelineFactory); } public void start(final int localAddr) { _shutdownRequested = false; _shutdown = false; _thread = new Thread(new Runnable() { @Override public void run() { SocketAddress serverAddr = (ServerType.LOCAL == _serverType) ? new LocalAddress(localAddr) : new InetSocketAddress(localAddr); //System.err.println("Server running on thread: " + Thread.currentThread()); _channel = _srvBootstrap.bind(serverAddr); _lock.lock(); try { _started = true; _startedCondition.signalAll(); while (!_shutdownRequested) { try { _shutdownReqCondition.await(); } catch (InterruptedException ie) {} } _shutdown = true; _shutdownCondition.signalAll(); } finally { _lock.unlock(); } } }); _thread.setDaemon(true); _thread.start(); } public boolean startSynchronously(final int localAddr, long timeoutMillis) { start(localAddr); try {awaitStarted(timeoutMillis);} catch (InterruptedException ie) {}; return isStarted(); } public boolean isStarted() { _lock.lock(); try { return _started; } finally { _lock.unlock(); } } public void awaitStarted(long timeoutMillis) throws InterruptedException { _lock.lock(); try { if (!_started) _startedCondition.await(timeoutMillis, TimeUnit.MILLISECONDS); } finally { _lock.unlock(); } } public void stop() { _lock.lock(); try { _shutdownRequested = true; _shutdownReqCondition.signalAll(); while (! _shutdown) { try { _shutdownCondition.await(); } catch (InterruptedException ie) {} } } finally { _lock.unlock(); } ChannelFuture closeFuture = _channel.close(); closeFuture.awaitUninterruptibly(); _srvBootstrap.releaseExternalResources(); } public Channel getChannel() { return _channel; } class ChildChannelTracker extends SimpleChannelUpstreamHandler { @Override public synchronized void childChannelOpen(ChannelHandlerContext ctx, ChildChannelStateEvent e) throws Exception { _lastConnChannel = e.getChildChannel(); e.getChildChannel().getPipeline().addFirst("childChannelMapHandler", new ChildChannelMapHandler()); super.childChannelOpen(ctx, e); } class ChildChannelMapHandler extends SimpleChannelUpstreamHandler { @Override public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { SocketAddress remoteAddr = e.getChannel().getRemoteAddress(); _childrenChannels.put(remoteAddr, e.getChannel()); super.channelConnected(ctx, e); } @Override public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { SocketAddress remoteAddr = e.getChannel().getRemoteAddress(); _childrenChannels.remove(remoteAddr); super.channelClosed(ctx, e); } } } public Channel getLastConnChannel() { return _lastConnChannel; } public Channel getChildChannel(SocketAddress clientAddr) { return _childrenChannels.get(clientAddr); } public void sendServerResponse(SocketAddress clientAddr, Object response, long timeoutMillis) { Channel childChannel = getChildChannel(clientAddr); Assert.assertNotEquals(childChannel, null); ChannelFuture writeFuture = childChannel.write(response); if (timeoutMillis > 0) { try { writeFuture.await(timeoutMillis); } catch (InterruptedException e) { //NOOP } Assert.assertTrue(writeFuture.isDone()); Assert.assertTrue(writeFuture.isSuccess()); } } public void sendServerClose(SocketAddress clientAddr, long timeoutMillis) { Channel childChannel = getChildChannel(clientAddr); Assert.assertNotEquals(childChannel, null); ChannelFuture closeFuture = childChannel.close(); if (timeoutMillis > 0) { try { closeFuture.await(timeoutMillis); } catch (InterruptedException e) { //NOOP } Assert.assertTrue(closeFuture.isDone()); Assert.assertTrue(closeFuture.isSuccess()); } } }