/* * JBoss, Home of Professional Open Source. * Copyright 2013 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.xnio.nio.test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; import org.xnio.ChannelListener; import org.xnio.IoUtils; import org.xnio.OptionMap; import org.xnio.Options; import org.xnio.channels.ConnectedChannel; import org.xnio.conduits.ConduitStreamSinkChannel; import org.xnio.conduits.ConduitStreamSourceChannel; import org.xnio.ssl.SslConnection; /** * Test for {@code XnioSsl} connections with the start TLS option enabled. * * @author <a href="mailto:frainone@redhat.com">Flavia Rainone</a> * */ public class NioStartTLSTcpConnectionTestCase extends NioSslTcpConnectionTestCase { @Before public void setStartTLSOption() { final OptionMap optionMap = OptionMap.create(Options.SSL_STARTTLS, true); super.setServerOptionMap(optionMap); super.setClientOptionMap(optionMap); } @Test public void oneWayTransfer3() throws Exception { log.info("Test: oneWayTransfer"); final CountDownLatch latch = new CountDownLatch(2); final AtomicInteger clientSent = new AtomicInteger(0); final AtomicInteger serverReceived = new AtomicInteger(0); final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false); final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection channel) { latch.countDown(); } }); connection.getSinkChannel().setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { private boolean continueWriting() throws IOException { if (clientSent.get() > 100) { if (!clientHandshakeStarted.get()) { if (serverReceived.get() == clientSent.get()) { connection.startHandshake(); log.info("client starting handshake"); clientHandshakeStarted.set(true); return true; } return false; } if (serverHandshakeStarted.get()) { return true; } return false; } return true; } public void handleEvent(final ConduitStreamSinkChannel channel) { try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c; try { while (continueWriting() && (c = channel.write(buffer)) > 0) { log.info("client wrote " + (c + clientSent.get())); if (clientSent.addAndGet(c) > 1000) { final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel channel) { try { if (channel.flush()) { final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel channel) { // really lame, but due to the way SSL shuts down... if (serverReceived.get() == clientSent.get()) { try { log.info("client shutting down writes"); channel.shutdownWrites(); if (connection.isWriteShutdown()) { log.info("client write handler closing connection"); connection.close(); } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } } }; channel.getWriteSetter().set(listener); listener.handleEvent(channel); return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }; channel.setWriteListener(listener); listener.handleEvent(channel); return; } } buffer.rewind(); } catch (ClosedChannelException e) { try { channel.shutdownWrites(); } catch (Exception exception) {} throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); connection.getSinkChannel().resumeWrites(); } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection channel) { latch.countDown(); } }); connection.getSourceChannel().setReadListener(new ChannelListener<ConduitStreamSourceChannel>() { public void handleEvent(final ConduitStreamSourceChannel channel) { try { int c; while ((c = channel.read(ByteBuffer.allocate(100))) > 0) { log.info("server received " + (c + serverReceived.get())); if (serverReceived.addAndGet(c) > 100 && !serverHandshakeStarted.get() ) { connection.startHandshake(); serverHandshakeStarted.set(true); } } if (c == -1) { log.info("server shutting down reads"); channel.shutdownReads(); if(connection.isReadShutdown()) { log.info("server read handler closing connection"); connection.close(); } } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); connection.getSourceChannel().resumeReads(); } }); assertEquals(clientSent.get(), serverReceived.get()); } public void oneWayTransfer4() throws Exception { log.info("Test: oneWayTransfer4"); final CountDownLatch latch = new CountDownLatch(2); final AtomicInteger clientReceived = new AtomicInteger(0); final AtomicInteger serverSent = new AtomicInteger(0); final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false); final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { latch.countDown(); } }); connection.getSourceChannel().setReadListener(new ChannelListener<ConduitStreamSourceChannel>() { public void handleEvent(final ConduitStreamSourceChannel channel) { try { int c; while ((c = channel.read(ByteBuffer.allocate(100))) > 0) { log.info("client received " + (c + clientReceived.get())); if (clientReceived.addAndGet(c) > 100 && !clientHandshakeStarted.get()) { connection.startHandshake(); clientHandshakeStarted.set(true); } } if (c == -1) { channel.shutdownReads(); if (connection.isReadShutdown()) connection.close(); } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); connection.getSourceChannel().resumeReads(); } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { latch.countDown(); } }); connection.getSinkChannel().setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { private boolean continueWriting() throws IOException { if (serverSent.get() > 100) { if (!serverHandshakeStarted.get()) { if (clientReceived.get() == serverSent.get()) { connection.startHandshake(); log.info("server starting handshake"); serverHandshakeStarted.set(true); return true; } return false; } if (clientHandshakeStarted.get()) { return true; } return false; } return true; } public void handleEvent(final ConduitStreamSinkChannel channel) { try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c; try { while (continueWriting() && (c = channel.write(buffer)) > 0) { log.info("server wrote " + (c + serverSent.get())); if (serverSent.addAndGet(c) > 100) { connection.startHandshake(); if (serverSent.get() > 1000) { final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel channel) { try { if (channel.flush()) { final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel channel) { // really lame, but due to the way SSL shuts down... if (clientReceived.get() == serverSent.get()) { try { channel.shutdownWrites(); if (connection.isWriteShutdown()) connection.close(); } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } } }; channel.setWriteListener(listener); listener.handleEvent(channel); return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }; channel.getWriteSetter().set(listener); listener.handleEvent(channel); return; } } buffer.rewind(); } } catch (ClosedChannelException e) { channel.shutdownWrites(); throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); connection.getSinkChannel().resumeWrites(); } }); assertEquals(serverSent.get(), clientReceived.get()); } @Test public void twoWayTransferWithHandshake() throws Exception { log.info("Test: twoWayTransferWithHandshake"); final CountDownLatch latch = new CountDownLatch(2); final AtomicInteger clientSent = new AtomicInteger(0); final AtomicInteger clientReceived = new AtomicInteger(0); final AtomicInteger serverSent = new AtomicInteger(0); final AtomicInteger serverReceived = new AtomicInteger(0); final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false); final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false); doConnectionTest(new Runnable() { public void run() { try { assertTrue(latch.await(500L, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { latch.countDown(); } }); final ConduitStreamSourceChannel sourceChannel = connection.getSourceChannel(); sourceChannel.setReadListener(new ChannelListener<ConduitStreamSourceChannel>() { private boolean continueReading() throws IOException { return clientHandshakeStarted.get() || clientReceived.get() < 101; } public void handleEvent(final ConduitStreamSourceChannel sourceChannel) { log.info("client handle read events"); try { int c = 0; while (continueReading() && (c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) { log.info("client received: "+ (clientReceived.get() + c)); clientReceived.addAndGet(c); } if (c == -1) { log.info("client shutdown reads"); connection.close(); } } catch (Throwable t) { t.printStackTrace(); } } }); final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel(); sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { private boolean continueWriting(ConduitStreamSinkChannel sinkChannel) throws IOException { if (clientSent.get() > 100) { if (!clientHandshakeStarted.get()) { if (serverReceived.get() == clientSent.get() && serverSent.get() > 100 && clientReceived.get() == serverSent.get() ) { connection.startHandshake(); log.info("client starting handshake"); clientHandshakeStarted.set(true); return true; } return false; } if (clientHandshakeStarted.get()) { return true; } return false; } return true; } public void handleEvent(final ConduitStreamSinkChannel sinkChannel) { try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c = 0; try { while (continueWriting(sinkChannel) && (clientSent.get() > 1000 || (c = sinkChannel.write(buffer)) > 0)) { log.info("clientSent: " + (clientSent.get() + c)); if (clientSent.addAndGet(c) > 1000) { sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel sinkChannel) { try { if (sinkChannel.flush()) { try { log.info("client closing channel"); sinkChannel.shutdownWrites(); } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); return; } buffer.rewind(); } } catch (ClosedChannelException e) { try { sinkChannel.shutdownWrites(); } catch (Exception cce) {/* do nothing */} throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); sourceChannel.resumeReads(); sinkChannel.resumeWrites(); } }, new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { connection.getCloseSetter().set(new ChannelListener<SslConnection>() { public void handleEvent(final SslConnection connection) { latch.countDown(); } }); final ConduitStreamSourceChannel sourceChannel = connection.getSourceChannel(); sourceChannel.setReadListener(new ChannelListener<ConduitStreamSourceChannel>() { private boolean continueReading() throws IOException { return serverHandshakeStarted.get() || serverReceived.get() < 101; } public void handleEvent(final ConduitStreamSourceChannel sourceChannel) { try { int c = 0; while (continueReading() && (c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) { log.info("server received: "+ (serverReceived.get() + c)); serverReceived.addAndGet(c); } if (c == -1) { log.info("server shutdown reads"); connection.close(); } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel(); sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { private boolean continueWriting(ConduitStreamSinkChannel sinkChannel) throws IOException { if (serverSent.get() > 100) { if (!serverHandshakeStarted.get()) { if (clientReceived.get() == serverSent.get() && clientSent.get() > 100 && serverReceived.get() == clientSent.get() ) { connection.startHandshake(); log.info("server starting handshake"); serverHandshakeStarted.set(true); return true; } return false; } if (clientHandshakeStarted.get()) { return true; } return false; } return true; } public void handleEvent(final ConduitStreamSinkChannel sinkChannel) { try { final ByteBuffer buffer = ByteBuffer.allocate(100); buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip(); int c; try { while (continueWriting(sinkChannel) && (c = sinkChannel.write(buffer)) > 0) { log.info("server sent: "+ (serverSent.get() + c)); if (serverSent.addAndGet(c) > 1000) { sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() { public void handleEvent(final ConduitStreamSinkChannel sinkChannel) { try { if (sinkChannel.flush()) { try { log.info("server closing channel"); sinkChannel.shutdownWrites(); } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } return; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); return; } } buffer.rewind(); } catch (ClosedChannelException e) { sinkChannel.shutdownWrites(); throw e; } } catch (Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } } }); sourceChannel.resumeReads(); sinkChannel.resumeWrites(); } }); assertEquals(serverSent.get(), clientReceived.get()); assertEquals(clientSent.get(), serverReceived.get()); } }