/** * Copyright 2016 LinkedIn Corp. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ package com.github.ambry.network; import com.codahale.metrics.MetricRegistry; import com.github.ambry.commons.SSLFactory; import com.github.ambry.commons.TestSSLUtils; import com.github.ambry.config.SSLConfig; import java.io.DataInputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Random; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; public class SSLBlockingChannelTest { private static SSLFactory sslFactory; private static SSLConfig clientSSLConfig; private static SSLSocketFactory sslSocketFactory; private static EchoServer sslEchoServer; private static String hostName = "localhost"; private static int sslPort = 18284; /** * Run only once for all tests */ @BeforeClass public static void initializeTests() throws Exception { File trustStoreFile = File.createTempFile("truststore", ".jks"); SSLConfig sslConfig = new SSLConfig(TestSSLUtils.createSslProps("DC1,DC2,DC3", SSLFactory.Mode.SERVER, trustStoreFile, "server")); clientSSLConfig = new SSLConfig(TestSSLUtils.createSslProps("DC1,DC2,DC3", SSLFactory.Mode.CLIENT, trustStoreFile, "client")); sslFactory = new SSLFactory(sslConfig); sslEchoServer = new EchoServer(sslFactory, sslPort); sslEchoServer.start(); //client sslFactory = new SSLFactory(clientSSLConfig); SSLContext sslContext = sslFactory.getSSLContext(); sslSocketFactory = sslContext.getSocketFactory(); } /** * Run only once for all tests */ @AfterClass public static void finalizeTests() throws Exception { int serverExceptionCount = sslEchoServer.getExceptionCount(); assertEquals(serverExceptionCount, 0); sslEchoServer.close(); } @Before public void setup() throws Exception { } @After public void teardown() throws Exception { } @Test public void testSendAndReceive() throws Exception { BlockingChannel channel = new SSLBlockingChannel(hostName, sslPort, new MetricRegistry(), 10000, 10000, 10000, 2000, sslSocketFactory, clientSSLConfig); sendAndReceive(channel); channel.disconnect(); } @Test public void testRenegotiation() throws Exception { BlockingChannel channel = new SSLBlockingChannel(hostName, sslPort, new MetricRegistry(), 10000, 10000, 10000, 2000, sslSocketFactory, clientSSLConfig); sendAndReceive(channel); sslEchoServer.renegotiate(); sendAndReceive(channel); channel.disconnect(); } @Test public void testWrongPortConnection() throws Exception { BlockingChannel channel = new SSLBlockingChannel(hostName, sslPort + 1, new MetricRegistry(), 10000, 10000, 10000, 2000, sslSocketFactory, clientSSLConfig); try { // send request channel.connect(); fail("should have thrown!"); } catch (IOException e) { } } private void sendAndReceive(BlockingChannel channel) throws Exception { long blobSize = 1028; byte[] bytesToSend = new byte[(int) blobSize]; new Random().nextBytes(bytesToSend); ByteBuffer byteBufferToSend = ByteBuffer.wrap(bytesToSend); byteBufferToSend.putLong(0, blobSize); BoundedByteBufferSend bufferToSend = new BoundedByteBufferSend(byteBufferToSend); // send request channel.connect(); channel.send(bufferToSend); // receive response InputStream streamResponse = channel.receive().getInputStream(); DataInputStream input = new DataInputStream(streamResponse); byte[] bytesReceived = new byte[(int) blobSize - 8]; input.readFully(bytesReceived); for (int i = 0; i < blobSize - 8; i++) { Assert.assertEquals(bytesToSend[8 + i], bytesReceived[i]); } } }