/* * JBoss, Home of Professional Open Source. * Copyright 2013 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.xnio.nio.test; import java.io.IOException; import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import junit.framework.TestCase; import org.jboss.logging.Logger; import org.junit.Ignore; import org.xnio.Buffers; import org.xnio.IoUtils; import org.xnio.Xnio; import org.xnio.OptionMap; import org.xnio.ChannelListener; import org.xnio.Options; import org.xnio.XnioWorker; import org.xnio.channels.MulticastMessageChannel; import org.xnio.channels.SocketAddressBuffer; /** * * Test for UDP connections. * * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a> * */ public final class NioUdpTestCase extends TestCase { private static final int SERVER_PORT = 12345; private static final InetSocketAddress SERVER_SOCKET_ADDRESS; private static final InetSocketAddress CLIENT_SOCKET_ADDRESS; private static final Logger log = Logger.getLogger("TEST"); static { try { SERVER_SOCKET_ADDRESS = new InetSocketAddress(Inet4Address.getByAddress(new byte[] {127, 0, 0, 1}), SERVER_PORT); CLIENT_SOCKET_ADDRESS = new InetSocketAddress(Inet4Address.getByAddress(new byte[] {127, 0, 0, 1}), 0); } catch (UnknownHostException e) { throw new RuntimeException(e); } } private synchronized void doServerSideTest(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body) throws IOException { final Xnio xnio = Xnio.getInstance("nio"); doServerSidePart(multicast, handler, body, xnio.createWorker(OptionMap.EMPTY)); } private void doServerSidePart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final XnioWorker worker) throws IOException { doPart(multicast, handler, body, SERVER_SOCKET_ADDRESS, worker); } private void doClientSidePart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final XnioWorker worker) throws IOException { doPart(multicast, handler, body, CLIENT_SOCKET_ADDRESS, worker); } private synchronized void doPart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final InetSocketAddress bindAddress, final XnioWorker worker) throws IOException { final MulticastMessageChannel server = worker.createUdpServer(bindAddress, handler, OptionMap.create(Options.MULTICAST, Boolean.valueOf(multicast))); try { body.run(); server.close(); } catch (RuntimeException e) { log.errorf(e, "Error running part"); throw e; } catch (IOException e) { log.errorf(e, "Error running part"); throw e; } catch (Error e) { log.errorf(e, "Error running part"); throw e; } finally { IoUtils.safeClose(server); } } private synchronized void doClientServerSide(final boolean clientMulticast, final boolean serverMulticast, final ChannelListener<MulticastMessageChannel> serverHandler, final ChannelListener<MulticastMessageChannel> clientHandler, final Runnable body) throws IOException { final Xnio xnio = Xnio.getInstance("nio"); final XnioWorker worker = xnio.createWorker(OptionMap.EMPTY); try { doServerSidePart(serverMulticast, serverHandler, new Runnable() { public void run() { try { doClientSidePart(clientMulticast, clientHandler, body, worker); } catch (IOException e) { throw new RuntimeException(e); } } }, worker); } finally { worker.shutdown(); try { worker.awaitTermination(1L, TimeUnit.MINUTES); } catch (InterruptedException ignored) { } } } private void doServerCreate(boolean multicast) throws Exception { final CountDownLatch latch = new CountDownLatch(2); final AtomicBoolean openedOk = new AtomicBoolean(false); final AtomicBoolean closedOk = new AtomicBoolean(false); doServerSideTest(multicast, new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { channel.getCloseSetter().set(new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { closedOk.set(true); latch.countDown(); } }); log.infof("In handleEvent for %s", channel); openedOk.set(true); latch.countDown(); } }, new Runnable() { public void run() { } }); assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); assertTrue(openedOk.get()); assertTrue(closedOk.get()); } public void testServerCreate() throws Exception { log.info("Test: testServerCreate"); doServerCreate(false); } public void testServerCreateMulticast() throws Exception { log.info("Test: testServerCreateMulticast"); doServerCreate(true); } @SuppressWarnings("unused") @Ignore /* XXX - depends on each server getting a separate thread */ public void testClientToServerTransmitNioToNio() throws Exception { if (true) return; log.info("Test: testClientToServerTransmitNioToNio"); final AtomicBoolean clientOK = new AtomicBoolean(false); final AtomicBoolean serverOK = new AtomicBoolean(false); final CountDownLatch startLatch = new CountDownLatch(1); final CountDownLatch receivedLatch = new CountDownLatch(1); final CountDownLatch doneLatch = new CountDownLatch(2); final byte[] payload = new byte[] { 10, 5, 15, 10, 100, -128, 30, 0, 0 }; doClientServerSide(true, true, new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { log.infof("In handleEvent for %s", channel); channel.getReadSetter().set(new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { log.infof("In handleReadable for %s", channel); try { final ByteBuffer buffer = ByteBuffer.allocate(50); final SocketAddressBuffer addressBuffer = new SocketAddressBuffer(); final int result = channel.receiveFrom(addressBuffer, buffer); if (result == 0) { log.infof("Whoops, spurious read notification for %s", channel); channel.resumeReads(); return; } try { final byte[] testPayload = new byte[payload.length]; Buffers.flip(buffer).get(testPayload); log.infof("We received the packet on %s", channel); assertTrue(Arrays.equals(testPayload, payload)); assertFalse(buffer.hasRemaining()); assertNotNull(addressBuffer.getSourceAddress()); try { channel.close(); serverOK.set(true); } finally { IoUtils.safeClose(channel); } } finally { receivedLatch.countDown(); doneLatch.countDown(); } } catch (IOException e) { IoUtils.safeClose(channel); throw new RuntimeException(e); } } }); channel.resumeReads(); startLatch.countDown(); } }, new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { log.infof("In handleEvent for %s", channel); channel.getWriteSetter().set(new ChannelListener<MulticastMessageChannel>() { public void handleEvent(final MulticastMessageChannel channel) { log.infof("In handleWritable for %s", channel); try { if (clientOK.get()) { log.infof("Extra writable notification on %s (?!)", channel); } else if (! channel.sendTo(SERVER_SOCKET_ADDRESS, ByteBuffer.wrap(payload))) { log.infof("Whoops, spurious write notification for %s", channel); channel.resumeWrites(); } else { log.infof("We sent the packet on %s", channel); try { assertTrue(receivedLatch.await(500000L, TimeUnit.MILLISECONDS)); channel.close(); } finally { IoUtils.safeClose(channel); } clientOK.set(true); doneLatch.countDown(); } } catch (IOException e) { IoUtils.safeClose(channel); e.printStackTrace(); } catch (InterruptedException e) { throw new RuntimeException(e); } } }); try { // wait until server is ready assertTrue(startLatch.await(500000L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } channel.resumeWrites(); } }, new Runnable() { public void run() { try { assertTrue(doneLatch.await(500000L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }); assertTrue(clientOK.get()); assertTrue(serverOK.get()); } }