/* * Copyright (c) 2002-2017 "Neo Technology," * Network Engine for Objects in Lund AB [http://neotechnology.com] * * This file is part of Neo4j. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.driver.v1.integration; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.io.IOException; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.security.KeyStore; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import static java.util.concurrent.Executors.newSingleThreadExecutor; import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon; /** * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 * for the following variables: * <p> * - Network frame size * - Bolt message size * - Read buffer size * <p> * It tests every possible combination, and it does this currently only for the read path, expanding * to the write path as well would be useful. For each size, it sets up a TLS server and tests the * handshake, transferring the data, and verifying the data is correct after decryption. */ public abstract class TLSSocketChannelFragmentation { SSLContext sslCtx; ServerSocket serverSocket; volatile byte[] blobOfData; private ExecutorService serverExecutor; private Future<?> serverTask; @Before public void setUp() throws Throwable { sslCtx = createSSLContext(); serverSocket = createServerSocket( sslCtx ); serverExecutor = createServerExecutor(); serverTask = launchServer( serverExecutor, createServerRunnable( sslCtx ) ); } @After public void tearDown() throws Exception { serverSocket.close(); serverExecutor.shutdownNow(); assertTrue( "Unable to terminate server socket", serverExecutor.awaitTermination( 30, SECONDS ) ); assertNull( serverTask.get( 30, SECONDS ) ); } @Test public void shouldHandleFuzziness() throws Throwable { // Given int networkFrameSize, userBufferSize, blobOfDataSize; for ( int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude += 2 ) { blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude ); blobOfData = blobOfData( blobOfDataSize ); for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude += 2 ) { networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude ); for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude += 2 ) { userBufferSize = (int) Math.pow( 2, userBufferMagnitude ); testForBufferSizes( blobOfData, networkFrameSize, userBufferSize ); } } } } protected abstract void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) throws Exception; protected abstract Runnable createServerRunnable( SSLContext sslContext ) throws IOException; private static SSLContext createSSLContext() throws Exception { KeyStore ks = KeyStore.getInstance( "JKS" ); char[] password = "password".toCharArray(); ks.load( TLSSocketChannelFragmentation.class.getResourceAsStream( "/keystore.jks" ), password ); KeyManagerFactory kmf = KeyManagerFactory.getInstance( "SunX509" ); kmf.init( ks, password ); SSLContext sslCtx = SSLContext.getInstance( "TLS" ); sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() { @Override public void checkClientTrusted( X509Certificate[] chain, String authType ) throws CertificateException { } @Override public void checkServerTrusted( X509Certificate[] chain, String authType ) throws CertificateException { } @Override public X509Certificate[] getAcceptedIssuers() { return null; } }}, null ); return sslCtx; } private static ServerSocket createServerSocket( SSLContext sslContext ) throws IOException { SSLServerSocketFactory ssf = sslContext.getServerSocketFactory(); return ssf.createServerSocket( 0 ); } private ExecutorService createServerExecutor() { return newSingleThreadExecutor( daemon( getClass().getSimpleName() + "-Server-" ) ); } private Future<?> launchServer( ExecutorService executor, Runnable runnable ) { return executor.submit( runnable ); } static byte[] blobOfData( int dataBlobSize ) { byte[] blobOfData = new byte[dataBlobSize]; // If the blob is all zeros, we'd miss data corruption problems in assertions, so // fill the data blob with different values. for ( int i = 0; i < blobOfData.length; i++ ) { blobOfData[i] = (byte) (i % 128); } return blobOfData; } static Socket accept( ServerSocket serverSocket ) throws IOException { try { return serverSocket.accept(); } catch ( SocketException e ) { String message = e.getMessage(); if ( "Socket closed".equalsIgnoreCase( message ) ) { return null; } throw e; } } /** * Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate * different network frame sizes in this test. */ protected static class LittleAtATimeChannel implements ByteChannel { private final ByteChannel delegate; private final int maxFrameSize; LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize ) { this.delegate = delegate; this.maxFrameSize = maxFrameSize; } @Override public boolean isOpen() { return delegate.isOpen(); } @Override public void close() throws IOException { delegate.close(); } @Override public int write( ByteBuffer src ) throws IOException { int originalLimit = src.limit(); try { src.limit( Math.min( src.limit(), src.position() + maxFrameSize ) ); return delegate.write( src ); } finally { src.limit( originalLimit ); } } @Override public int read( ByteBuffer dst ) throws IOException { int originalLimit = dst.limit(); try { dst.limit( Math.min( dst.limit(), dst.position() + maxFrameSize ) ); return delegate.read( dst ); } finally { dst.limit( originalLimit ); } } } }