/* * Copyright (c) 2002-2017 "Neo Technology," * Network Engine for Objects in Lund AB [http://neotechnology.com] * * This file is part of Neo4j. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.driver.internal.net; import org.junit.Test; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.util.Arrays; import java.util.Random; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; public class BufferingChunkedInputFuzzTest { @Test public void shouldHandleAllMessageBoundaries() throws IOException { byte[] expected = new byte[256]; for ( int i = 0; i < 256; i++ ) { expected[i] = (byte) (Byte.MIN_VALUE + i); } for ( int i = 0; i < 256; i++ ) { BufferingChunkedInput input = new BufferingChunkedInput( splitChannel( expected, i ) ); byte[] dst = new byte[256]; input.readBytes( dst, 0, dst.length ); assertThat( dst, equalTo( expected ) ); } } @Test public void messageSizeFuzzTest() throws IOException { int maxSize = 1 << 16; // 0x10000 Random random = new Random(); for ( int i = 0; i < 1000; i++) { int size = random.nextInt( maxSize - 1 ) + 1; //[0, 0xFFFF - 1] + 1 = [1, 0xFFFF] byte[] expected = new byte[size]; Arrays.fill(expected, (byte)42); BufferingChunkedInput input = new BufferingChunkedInput( channel( expected, 0, size ) ); byte[] dst = new byte[size]; input.readBytes( dst, 0, size); assertThat( dst, equalTo( expected ) ); } } ReadableByteChannel splitChannel( byte[] bytes, int split ) { assert split >= 0 && split < bytes.length; assert split <= Short.MAX_VALUE; assert bytes.length <= Short.MAX_VALUE; return packets( channel( bytes, 0, split ), channel( bytes, split, bytes.length ) ); } ReadableByteChannel channel( byte[] bytes, int from, int to ) { int size = to - from; ByteBuffer packet = ByteBuffer.allocate( 4 + size ); packet.put( (byte) ((size >> 8) & 0xFF) ); packet.put( (byte) (size & 0xFF) ); for ( int i = from; i < to; i++ ) { packet.put( bytes[i] ); } packet.put( (byte) 0 ); packet.put( (byte) 0 ); packet.flip(); return asChannel( packet ); } private ReadableByteChannel packets( final ReadableByteChannel... channels ) { return new ReadableByteChannel() { private int index = 0; @Override public int read( ByteBuffer dst ) throws IOException { return channels[index++].read( dst ); } @Override public boolean isOpen() { return false; } @Override public void close() throws IOException { } }; } private ReadableByteChannel asChannel( final ByteBuffer buffer ) { return new ReadableByteChannel() { @Override public int read( ByteBuffer dst ) throws IOException { int len = Math.min( dst.remaining(), buffer.remaining() ); for ( int i = 0; i < len; i++ ) { dst.put( buffer.get() ); } return len; } @Override public boolean isOpen() { return true; } @Override public void close() throws IOException { } }; } }