/* * Copyright 2014 The Netty Project * * The Netty Project licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package io.netty.testsuite.transport.socket; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.JdkSslClientContext; import io.netty.handler.ssl.JdkSslServerContext; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslClientContext; import io.netty.handler.ssl.OpenSslServerContext; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import java.io.File; import java.io.IOException; import java.security.cert.CertificateException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; @RunWith(Parameterized.class) public class SocketSslGreetingTest extends AbstractSocketTest { private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class); private static final LogLevel LOG_LEVEL = LogLevel.TRACE; private static final File CERT_FILE; private static final File KEY_FILE; private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a')); static { SelfSignedCertificate ssc; try { ssc = new SelfSignedCertificate(); } catch (CertificateException e) { throw new Error(e); } CERT_FILE = ssc.certificate(); KEY_FILE = ssc.privateKey(); } @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}") public static Collection<Object[]> data() throws Exception { List<SslContext> serverContexts = new ArrayList<SslContext>(); serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE)); List<SslContext> clientContexts = new ArrayList<SslContext>(); clientContexts.add(new JdkSslClientContext(CERT_FILE)); boolean hasOpenSsl = OpenSsl.isAvailable(); if (hasOpenSsl) { serverContexts.add(new OpenSslServerContext(CERT_FILE, KEY_FILE)); clientContexts.add(new OpenSslClientContext(CERT_FILE)); } else { logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause()); } List<Object[]> params = new ArrayList<Object[]>(); for (SslContext sc: serverContexts) { for (SslContext cc: clientContexts) { params.add(new Object[] { sc, cc }); } } return params; } private final SslContext serverCtx; private final SslContext clientCtx; public SocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { this.serverCtx = serverCtx; this.clientCtx = clientCtx; } // Test for https://github.com/netty/netty/pull/2437 @Test(timeout = 30000) public void testSslGreeting() throws Throwable { run(); } public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable { final ServerHandler sh = new ServerHandler(); final ClientHandler ch = new ClientHandler(); sb.childHandler(new ChannelInitializer<Channel>() { @Override public void initChannel(Channel sch) throws Exception { ChannelPipeline p = sch.pipeline(); p.addLast(serverCtx.newHandler(sch.alloc())); p.addLast(new LoggingHandler(LOG_LEVEL)); p.addLast(sh); } }); cb.handler(new ChannelInitializer<Channel>() { @Override public void initChannel(Channel sch) throws Exception { ChannelPipeline p = sch.pipeline(); p.addLast(clientCtx.newHandler(sch.alloc())); p.addLast(new LoggingHandler(LOG_LEVEL)); p.addLast(ch); } }); Channel sc = sb.bind().sync().channel(); Channel cc = cb.connect().sync().channel(); ch.latch.await(); sh.channel.close().awaitUninterruptibly(); cc.close().awaitUninterruptibly(); sc.close().awaitUninterruptibly(); if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { throw sh.exception.get(); } if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { throw ch.exception.get(); } if (sh.exception.get() != null) { throw sh.exception.get(); } if (ch.exception.get() != null) { throw ch.exception.get(); } } private class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> { final AtomicReference<Throwable> exception = new AtomicReference<Throwable>(); final CountDownLatch latch = new CountDownLatch(1); @Override public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception { assertEquals(greeting, buf); latch.countDown(); ctx.close(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (logger.isWarnEnabled()) { logger.warn("Unexpected exception from the client side", cause); } exception.compareAndSet(null, cause); ctx.close(); } } private class ServerHandler extends SimpleChannelInboundHandler<String> { volatile Channel channel; final AtomicReference<Throwable> exception = new AtomicReference<Throwable>(); @Override protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception { // discard } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { channel = ctx.channel(); channel.writeAndFlush(greeting.duplicate().retain()); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (logger.isWarnEnabled()) { logger.warn("Unexpected exception from the server side", cause); } exception.compareAndSet(null, cause); ctx.close(); } } }