/* * JBoss, Home of Professional Open Source * * Copyright 2014 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.xnio.nio.test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.xnio.ChannelListener; import org.xnio.channels.ConnectedChannel; import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSourceChannel; /** * Superclass for all ssl tcp test cases. * * @author <a href="mailto:frainone@redhat.com">Flavia Rainone</a> * */ public abstract class AbstractNioSslTcpTest<T extends ConnectedChannel, R extends StreamSourceChannel, W extends StreamSinkChannel> extends AbstractNioTcpTest<T, R, W> { protected abstract void shutdownReads(T channel) throws IOException; protected abstract void shutdownWrites(T channel) throws IOException; @Override public void connect() throws Exception { clientClose(); // with SSL, START_TLS false, we don't support the normal close, because handshake is started immediately // so the only way to test connection is to have either the server or the client shutdown writes and wait for a read to return -1 } @Override public void clientClose() throws Exception { log.info("Test: clientClose"); final CountDownLatch latch = new CountDownLatch(4); final AtomicBoolean clientOK = new AtomicBoolean(false); final AtomicBoolean serverOK = new AtomicBoolean(false); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<T>() { public void handleEvent(final T channel) { log.info("In client open"); try { channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel channel) { log.info("In client close"); latch.countDown(); } }); shutdownWrites(channel); setReadListener(channel, new ChannelListener<R>() { @Override public void handleEvent(R sourceChannel) { int c; try { c = sourceChannel.read(ByteBuffer.allocate(100)); if (c == -1) { channel.close(); clientOK.set(true); latch.countDown(); } } catch (Throwable t) { log.error("In client", t); latch.countDown(); throw new RuntimeException(t); } } }); resumeReads(channel); } catch (Throwable t) { log.error("In client", t); latch.countDown(); throw new RuntimeException(t); } } }, new ChannelListener<T>() { public void handleEvent(final T channel) { log.info("In server opened"); channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel channel) { log.info("In server close"); latch.countDown(); } }); setReadListener(channel, new ChannelListener<R>() { public void handleEvent(final R sourceChannel) { log.info("In server readable"); try { int c = sourceChannel.read(ByteBuffer.allocate(100)); if (c == -1) { serverOK.set(true); channel.close(); latch.countDown(); } } catch (IOException t) { t.printStackTrace(); throw new RuntimeException(t); } } }); resumeReads(channel); } }); assertTrue(serverOK.get()); assertTrue(clientOK.get()); } @Test public void serverClose() throws Exception { log.info("Test: serverClose"); final CountDownLatch latch = new CountDownLatch(2); final AtomicBoolean clientOK = new AtomicBoolean(false); final AtomicBoolean serverOK = new AtomicBoolean(false); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<T>() { public void handleEvent(final T channel) { try { channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel channel) { latch.countDown(); } }); setReadListener(channel, new ChannelListener<R>() { public void handleEvent(final R sourceChannel) { try { final int c = sourceChannel.read(ByteBuffer.allocate(100)); if (c == -1) { log.info("client closing connection"); clientOK.set(true); channel.close(); return; } return; } catch (IOException e) { throw new RuntimeException(e); } } }); resumeReads(channel); } catch (Throwable t) { try { channel.close(); } catch (Throwable t2) { log.errorf(t2, "Failed to close channel (propagating as RT exception)"); latch.countDown(); throw new RuntimeException(t); } throw new RuntimeException(t); } } }, new ChannelListener<T>() { public void handleEvent(final T channel) { try { channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel channel) { log.info("In server close"); latch.countDown(); } }); shutdownWrites(channel); setReadListener(channel, new ChannelListener<R>() { @Override public void handleEvent(R sourceChannel) { int c; try { c = sourceChannel.read(ByteBuffer.allocate(100)); if (c == -1) { log.info("server closing connection"); channel.close(); serverOK.set(true); } } catch (Throwable t) { log.error("In server", t); latch.countDown(); throw new RuntimeException(t); } } }); resumeReads(channel); } catch (Throwable t) { log.errorf(t, "Failed to close channel (propagating as RT exception)"); latch.countDown(); throw new RuntimeException(t); } } }); assertTrue(serverOK.get()); assertTrue(clientOK.get()); } @Override public void twoWayTransfer() throws Exception { log.info("Test: twoWayTransfer"); final CountDownLatch latch = new CountDownLatch(2); final AtomicInteger clientSent = new AtomicInteger(0); final AtomicInteger clientReceived = new AtomicInteger(0); final AtomicInteger serverSent = new AtomicInteger(0); final AtomicInteger serverReceived = new AtomicInteger(0); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<T>() { public void handleEvent(final T channel) { channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel connection) { latch.countDown(); } }); setReadListener(channel, new ChannelListener<R>() { public void handleEvent(final R sourceChannel) { try { log.info("client handle readable"); int c; while ((c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) { clientReceived.addAndGet(c); log.info("client received: " + clientReceived.get()); } if (c == -1) { log.info("client shutdown reads"); sourceChannel.shutdownReads(); } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { log.info("client handle writable"); try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c; try { while ((c = sinkChannel.write(buffer)) > 0) { log.info("client sent: " + (clientSent.get() + c)); if (clientSent.addAndGet(c) > 1000) { setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { try { if (sinkChannel.flush()) { setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { // really lame, but due to the way SSL shuts down... if (serverReceived.get() == clientSent.get() && serverSent.get() == clientReceived.get() && serverSent.get() > 1000) { try { log.info("client shutdown writes"); sinkChannel.shutdownWrites(); log.info("client write handler closing connection"); channel.close(); } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } } }); return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); return; } buffer.rewind(); } } catch (ClosedChannelException e) { sinkChannel.shutdownWrites(); throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); resumeReads(channel); resumeWrites(channel); } }, new ChannelListener<T>() { public void handleEvent(final T channel) { channel.getCloseSetter().set(new ChannelListener<ConnectedChannel>() { public void handleEvent(final ConnectedChannel channel) { latch.countDown(); } }); setReadListener(channel, new ChannelListener<R>() { public void handleEvent(final R sourceChannel) { log.info("server handle readable"); try { int c; while ((c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) { serverReceived.addAndGet(c); log.info("server received: " + serverReceived.get()); } if (c == -1) { log.info("server shutting down reads"); sourceChannel.shutdownReads(); } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { log.info("server handle writable"); try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c; try { while ((c = sinkChannel.write(buffer)) > 0) { log.info("server sent: " + (serverSent.get() + c)); if (serverSent.addAndGet(c) > 1000) { setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { try { if (sinkChannel.flush()) { setWriteListener(channel, new ChannelListener<W>() { public void handleEvent(final W sinkChannel) { // really lame, but due to the way SSL shuts down... if (clientReceived.get() == serverSent.get() && serverReceived.get() == clientSent.get() && clientSent.get() > 1000) { try { log.info("server shutting down writes"); sinkChannel.shutdownWrites(); log.info("server write handler closing connection"); channel.close(); } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } } }); return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); return; } buffer.rewind(); } } catch (ClosedChannelException e) { sinkChannel.shutdownWrites(); throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); resumeReads(channel); resumeWrites(channel); } }); assertEquals(serverSent.get(), clientReceived.get()); assertEquals(clientSent.get(), serverReceived.get()); } }