/* * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER. * * Copyright (c) 2015 Oracle and/or its affiliates. All rights reserved. * * The contents of this file are subject to the terms of either the GNU * General Public License Version 2 only ("GPL") or the Common Development * and Distribution License("CDDL") (collectively, the "License"). You * may not use this file except in compliance with the License. You can * obtain a copy of the License at * http://glassfish.java.net/public/CDDL+GPL_1_1.html * or packager/legal/LICENSE.txt. See the License for the specific * language governing permissions and limitations under the License. * * When distributing the software, include this License Header Notice in each * file and include the License file at packager/legal/LICENSE.txt. * * GPL Classpath Exception: * Oracle designates this particular file as subject to the "Classpath" * exception as provided by Oracle in the GPL Version 2 section of the License * file that accompanied this code. * * Modifications: * If applicable, add the following below the License Header, with the fields * enclosed by brackets [] replaced by your own identifying information: * "Portions Copyright [year] [name of copyright owner]" * * Contributor(s): * If you wish your version of this file to be governed by only the CDDL or * only the GPL Version 2, indicate your decision by adding "[Contributor] * elects to include this software in this distribution under the [CDDL or GPL * Version 2] license." If you don't indicate a single choice of license, a * recipient has the option to distribute your version of this file under * either the CDDL, the GPL Version 2 or to extend the choice of license to * its licensees as provided above. However, if you add GPL Version 2 code * and therefore, elected the GPL Version 2 license, then the option applies * only if the new code is made subject to such option by the copyright * holder. 
*/
package org.glassfish.jersey.jdk.connector;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritePendingException;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests {@link ChunkedBodyOutputStream} in both asynchronous (write-listener driven) and
 * synchronous (blocking) modes against an in-memory {@code MockTransportFilter} that records
 * every byte handed to the transport.
 *
 * @author Petr Janouch (petr.janouch at oracle.com)
 */
public class AsynchronousBodyOutputStreamTest {

    @Test
    public void testBasicAsyncWrite() throws IOException {
        doTestAsyncWrite(false);
    }

    @Test
    public void testBasicAsyncArrayWrite() throws IOException {
        doTestAsyncWrite(true);
    }

    @Test
    public void testSetListenerAfterOpeningStream() throws IOException {
        TestStream stream = new TestStream(6);
        MockTransportFilter transportFilter = new MockTransportFilter();
        String msg1 = "AAAAAAAAAAAAAAAAAAAA";
        String msg2 = "BBBBBBBBBBBBB";
        String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
        TestWriteListener writeListener = new TestWriteListener(stream, -1);
        writeListener.write(msg1);
        stream.open(transportFilter);
        stream.setWriteListener(writeListener);
        writeListener.write(msg2);
        writeListener.write(msg3);
        stream.close();

        assertNoListenerError(writeListener);
        assertEquals(msg1 + msg2 + msg3, transportFilter.getWrittenData());
    }

    @Test
    public void testTestAsyncWriteWithDelay() throws IOException {
        doTestAsyncWriteWithDelay(false);
    }

    @Test
    public void testTestAsyncWriteArrayWithDelay() throws IOException {
        doTestAsyncWriteWithDelay(true);
    }

    @Test
    public void testAsyncFlush() {
        TestStream stream = new TestStream(6);
        String msg1 = "AAAAAAAAAAAAAAAAAAAA";
        String msg2 = "BBBBBBBBBBBBB";
        String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
        String msg4 = "DDDDDDD";
        TestWriteListener writeListener = new TestWriteListener(stream, -1);
        stream.setWriteListener(writeListener);
        MockTransportFilter transportFilter = new MockTransportFilter();
        writeListener.write(msg1);
        transportFilter.block();
        stream.open(transportFilter);
        writeListener.flush();
        // test someone going crazy with flush
        writeListener.flush();
        transportFilter.unblock();
        transportFilter.block();
        writeListener.write(msg2);
        transportFilter.unblock();
        writeListener.flush();
        transportFilter.block();
        writeListener.write(msg3);
        writeListener.flush();
        writeListener.write(msg4);
        writeListener.flush();
        writeListener.close();
        transportFilter.unblock();

        assertNoListenerError(writeListener);
        assertEquals(msg1 + msg2 + msg3 + msg4, transportFilter.getWrittenData());
    }

    @Test
    public void testAsyncException() {
        TestStream stream = new TestStream(6);
        String msg1 = "AAAAAAAAAAAAAAAAAAAA";
        TestWriteListener writeListener = new TestWriteListener(stream, -1);
        stream.setWriteListener(writeListener);
        MockTransportFilter transportFilter = new MockTransportFilter();
        stream.open(transportFilter);
        Throwable t = new Throwable();
        transportFilter.setException(t);
        writeListener.write(msg1);

        assertNotNull(writeListener.getError());
        assertSame(t, writeListener.getError());
    }

    @Test
    public void testBasicSyncWrite() throws IOException, InterruptedException, TimeoutException, ExecutionException {
        doTestSyncWrite(false);
    }

    @Test
    public void testBasicSyncArrayWrite() throws IOException, InterruptedException, TimeoutException, ExecutionException {
        doTestSyncWrite(true);
    }

    @Test
    public void testSyncWriteWithDelay() throws IOException, InterruptedException, TimeoutException, ExecutionException {
        doTestSyncWriteWithDelay(false);
    }

    @Test
    public void testSyncArrayWriteWithDelay() throws IOException, InterruptedException, TimeoutException, ExecutionException {
        doTestSyncWriteWithDelay(true);
    }

    @Test
    public void testAsyncWriteWhenNotReady() throws IOException {
        TestStream stream = new TestStream(6);
        TestWriteListener writeListener = new TestWriteListener(stream, -1);
        stream.setWriteListener(writeListener);

        try {
            stream.write((byte) 'a');
            fail();
        } catch (IllegalStateException e) {
            // expected
        }
    }

    @Test
    public void testUnsupportedSync() throws IOException {
        final TestStream stream = new TestStream(10);
        stream.open(new MockTransportFilter());
        // touch this stream to make it synchronous; an IOException here fails the test with its cause
        stream.write((byte) 'a');

        assertUnsupported(() -> {
            stream.isReady();
            return null;
        });

        assertUnsupported(() -> {
            stream.setWriteListener(new TestWriteListener(stream));
            return null;
        });
    }

    @Test
    public void testSyncException() throws IOException {
        TestStream stream = new TestStream(1);
        MockTransportFilter transportFilter = new MockTransportFilter();
        stream.open(transportFilter);
        Throwable t = new Throwable();
        transportFilter.setException(t);

        try {
            stream.write("aaa".getBytes());
            fail();
        } catch (IOException e) {
            assertSame(t, e.getCause());
        }
    }

    /**
     * Writes three messages synchronously. The first is written from a worker thread before the
     * stream is opened and must not complete until {@code open} is called.
     *
     * @param useArray {@code true} to write each message as one array, {@code false} byte by byte
     */
    private void doTestSyncWrite(final boolean useArray)
            throws IOException, InterruptedException, TimeoutException, ExecutionException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            final TestStream stream = new TestStream(6);
            final String msg1 = "AAAAAAAAAAAAAAAAAAAA";
            String msg2 = "BBBBBBBBBBBBB";
            String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
            MockTransportFilter transportFilter = new MockTransportFilter();

            // An IOException thrown by the write surfaces from future.get() as an
            // ExecutionException carrying the original cause.
            Future<Boolean> future = executor.submit(() -> {
                writeToStream(stream, msg1, useArray);
                return true;
            });

            // test that synchronous write really blocks until the stream is opened
            assertFalse(future.isDone());
            stream.open(transportFilter);
            assertTrue(future.get(300, TimeUnit.SECONDS));

            writeToStream(stream, msg2, useArray);
            writeToStream(stream, msg3, useArray);
            stream.close();
            assertEquals(msg1 + msg2 + msg3, transportFilter.getWrittenData());
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Writes four messages through a {@link TestWriteListener}, blocking the transport around
     * each step so that write completions are delivered later via {@code unblock()}.
     *
     * @param useArray {@code true} to write 10-byte arrays, {@code false} to write single bytes
     */
    private void doTestAsyncWriteWithDelay(boolean useArray) throws IOException {
        int arraySize = useArray ? 10 : -1;

        TestStream stream = new TestStream(6);
        String msg1 = "AAAAAAAAAAAAAAAAAAAA";
        String msg2 = "BBBBBBBBBBBBB";
        String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
        String msg4 = "DDDDDDD";
        TestWriteListener writeListener = new TestWriteListener(stream, arraySize);
        stream.setWriteListener(writeListener);
        MockTransportFilter transportFilter = new MockTransportFilter();
        writeListener.write(msg1);
        transportFilter.block();
        stream.open(transportFilter);
        transportFilter.unblock();
        transportFilter.block();
        writeListener.write(msg2);
        transportFilter.unblock();
        transportFilter.block();
        writeListener.write(msg3);
        writeListener.write(msg4);
        writeListener.close();
        transportFilter.unblock();

        assertNoListenerError(writeListener);
        assertEquals(msg1 + msg2 + msg3 + msg4, transportFilter.getWrittenData());
    }

    /**
     * Writes {@code msg}'s bytes (platform default charset) to {@code stream}, either as a single
     * array or one byte at a time.
     */
    private void writeToStream(TestStream stream, String msg, boolean useArray) throws IOException {
        byte[] bytes = msg.getBytes();
        if (useArray) {
            stream.write(bytes);
        } else {
            for (byte b : bytes) {
                stream.write(b);
            }
        }
    }

    /**
     * Writes four messages synchronously from a worker thread while the transport is blocked. For
     * each step it checks that the worker is still blocked once the transport has parked a write,
     * and that it finishes after {@code unblock()}.
     *
     * @param useArray {@code true} to write each message as one array, {@code false} byte by byte
     */
    private void doTestSyncWriteWithDelay(final boolean useArray)
            throws IOException, InterruptedException, TimeoutException, ExecutionException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            final TestStream stream = new TestStream(6);
            final String msg1 = "AAAAAAAAAAAAAAAAAAAA";
            final String msg2 = "BBBBBBBBBBBBB";
            final String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
            final String msg4 = "DDDDDDD";
            MockTransportFilter transportFilter = new MockTransportFilter();
            stream.open(transportFilter);

            final CountDownLatch blockLatch1 = new CountDownLatch(1);
            transportFilter.block(blockLatch1);
            Future<Boolean> future = executor.submit(() -> {
                writeToStream(stream, msg1, useArray);
                return true;
            });

            assertTrue(blockLatch1.await(5, TimeUnit.SECONDS));
            assertFalse(future.isDone());
            transportFilter.unblock();
            assertTrue(future.get(5, TimeUnit.SECONDS));

            final CountDownLatch blockLatch2 = new CountDownLatch(1);
            transportFilter.block(blockLatch2);
            future = executor.submit(() -> {
                writeToStream(stream, msg2, useArray);
                writeToStream(stream, msg3, useArray);
                return true;
            });

            assertTrue(blockLatch2.await(5, TimeUnit.SECONDS));
            assertFalse(future.isDone());
            transportFilter.unblock();
            assertTrue(future.get(5, TimeUnit.SECONDS));

            final CountDownLatch blockLatch3 = new CountDownLatch(1);
            transportFilter.block(blockLatch3);
            future = executor.submit(() -> {
                writeToStream(stream, msg4, useArray);
                stream.close();
                return true;
            });

            assertTrue(blockLatch3.await(5, TimeUnit.SECONDS));
            assertFalse(future.isDone());
            transportFilter.unblock();
            assertTrue(future.get(5, TimeUnit.SECONDS));

            assertEquals(msg1 + msg2 + msg3 + msg4, transportFilter.getWrittenData());
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Writes three messages through a {@link TestWriteListener}. The first is queued before the
     * stream is opened.
     *
     * @param useArray {@code true} to write 10-byte arrays, {@code false} to write single bytes
     */
    private void doTestAsyncWrite(boolean useArray) throws IOException {
        int arraySize = useArray ? 10 : -1;

        TestStream stream = new TestStream(6);
        String msg1 = "AAAAAAAAAAAAAAAAAAAA";
        String msg2 = "BBBBBBBBBBBBB";
        String msg3 = "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC";
        TestWriteListener writeListener = new TestWriteListener(stream, arraySize);
        stream.setWriteListener(writeListener);
        MockTransportFilter transportFilter = new MockTransportFilter();
        writeListener.write(msg1);
        stream.open(transportFilter);
        writeListener.write(msg2);
        writeListener.write(msg3);
        stream.close();

        assertNoListenerError(writeListener);
        assertEquals(msg1 + msg2 + msg3, transportFilter.getWrittenData());
    }

    /**
     * Fails the test if {@code writeListener} has recorded an error. The recorded error is
     * attached as the cause so that its stack trace appears in the test report.
     */
    private static void assertNoListenerError(TestWriteListener writeListener) {
        // Read the volatile field once so the check and the reported cause are the same object.
        Throwable error = writeListener.getError();
        if (error != null) {
            throw new AssertionError("Write listener reported an error", error);
        }
    }

    /**
     * Asserts that invoking {@code unsupported} throws {@link UnsupportedOperationException}. Any
     * other exception fails the test with that exception as the cause.
     */
    private static void assertUnsupported(Callable<?> unsupported) {
        try {
            unsupported.call();
        } catch (UnsupportedOperationException e) {
            return; // expected
        } catch (Exception e) {
            throw new AssertionError("Expected UnsupportedOperationException", e);
        }
        fail("Expected UnsupportedOperationException");
    }

    /**
     * A write listener that drives a {@link ChunkedBodyOutputStream} in asynchronous mode.
     *
     * <p>Messages are queued. Data is written only while {@code isReady()} returns {@code true};
     * the CLOSE and FLUSH markers are processed regardless of readiness. When the stream stops
     * being ready, the remaining queue is resumed from {@link #onWritePossible()}. An
     * {@code IOException} from the stream, or an error passed to {@link #onError(Throwable)}, is
     * stored in {@link #getError()}; the last one recorded wins.
     */
    private static class TestWriteListener implements WriteListener {

        // Sentinel queue entries, compared by identity.
        private static final ByteBuffer CLOSE = ByteBuffer.allocate(0);
        private static final ByteBuffer FLUSH = ByteBuffer.allocate(0);

        private final ChunkedBodyOutputStream outputStream;
        private final Queue<ByteBuffer> message = new LinkedList<>();
        // -1 means "write one byte per call"; otherwise the maximum array length per write call.
        private final int outputArraySize;
        private volatile boolean listenerCallExpected = true;
        private volatile Throwable error;

        TestWriteListener(ChunkedBodyOutputStream outputStream) {
            this(outputStream, -1);
        }

        TestWriteListener(ChunkedBodyOutputStream outputStream, int outputArraySize) {
            this.outputStream = outputStream;
            this.outputArraySize = outputArraySize;
        }

        /** Queues {@code message}'s bytes (platform default charset) and writes as much as possible. */
        void write(String message) {
            byte[] bytes = message.getBytes();
            this.message.add(ByteBuffer.wrap(bytes));
            doWrite();
        }

        /** Queues a close of the stream after all previously queued data. */
        void close() {
            message.add(CLOSE);
            doWrite();
        }

        /** Queues a flush of the stream after all previously queued data. */
        void flush() {
            message.add(FLUSH);
            doWrite();
        }

        @Override
        public void onWritePossible() {
            if (!listenerCallExpected) {
                fail();
            }

            listenerCallExpected = false;
            doWrite();
        }

        /** Drains the queue for as long as the stream accepts data or the head of the queue is a marker. */
        private void doWrite() {
            while (message.peek() != null
                    && (outputStream.isReady() || message.peek() == CLOSE || message.peek() == FLUSH)) {
                ByteBuffer headBuffer = message.peek();

                if (headBuffer == CLOSE || headBuffer == FLUSH) {
                    // Dequeue the marker before acting on it. If close()/flush() throws, the marker
                    // must not stay at the head of the queue: the loop condition accepts markers
                    // regardless of readiness, so it would spin forever.
                    message.poll();
                    try {
                        if (headBuffer == CLOSE) {
                            outputStream.close();
                        } else {
                            outputStream.flush();
                        }
                    } catch (IOException e) {
                        error = e;
                    }
                    continue;
                }

                try {
                    writeChunk(headBuffer);
                } catch (IOException e) {
                    error = e;
                }

                // Dequeue a drained buffer even if the write failed, otherwise the next iteration
                // would read from an empty buffer. The identity check guards against a re-entrant
                // drain having already removed it.
                if (!headBuffer.hasRemaining() && message.peek() == headBuffer) {
                    message.poll();
                }
            }

            if (!outputStream.isReady()) {
                listenerCallExpected = true;
            }
        }

        /** Writes the next piece of {@code buffer}: one byte, or up to {@code outputArraySize} bytes as an array. */
        private void writeChunk(ByteBuffer buffer) throws IOException {
            if (outputArraySize == -1) {
                outputStream.write(buffer.get());
            } else {
                byte[] outputArray = new byte[Math.min(outputArraySize, buffer.remaining())];
                buffer.get(outputArray);
                outputStream.write(outputArray);
            }
        }

        @Override
        public void onError(Throwable t) {
            error = t;
        }

        public Throwable getError() {
            return error;
        }
    }

    /** A {@link ChunkedBodyOutputStream} whose HTTP encoding is the identity, so the transport sees raw payload bytes. */
    private static class TestStream extends ChunkedBodyOutputStream {

        TestStream(int bufferSize) {
            super(bufferSize);
        }

        @Override
        protected ByteBuffer encodeToHttp(ByteBuffer byteBuffer) {
            return byteBuffer;
        }
    }

    /**
     * An in-memory transport that records every byte written to it.
     *
     * <p>While blocked (see {@link #block()}), a write is recorded but its completion is parked
     * until {@link #unblock()}. When not blocked, a write completes immediately, or fails with the
     * exception set via {@link #setException(Throwable)}. A write issued while another is still
     * pending is rejected with {@link WritePendingException}.
     */
    private static class MockTransportFilter extends Filter<ByteBuffer, Void, Void, Void> {

        private final ByteArrayOutputStream writtenData = new ByteArrayOutputStream();
        private volatile boolean pendingWrite = false;
        private volatile boolean block = false;
        private volatile CountDownLatch blockLatch;
        private volatile CompletionHandler<ByteBuffer> completionHandler;
        private volatile Throwable exception;

        MockTransportFilter() {
            super(null);
        }

        @Override
        void write(ByteBuffer data, CompletionHandler<ByteBuffer> completionHandler) {
            if (pendingWrite) {
                // Reject the overlapping write and stop: its data must not be recorded, and the
                // parked handler must not be replaced.
                completionHandler.failed(new WritePendingException());
                return;
            }

            pendingWrite = true;
            while (data.hasRemaining()) {
                writtenData.write(data.get());
            }

            if (block) {
                // Park the completion until unblock(); signal tests waiting for the parked write.
                this.completionHandler = completionHandler;
                if (blockLatch != null) {
                    blockLatch.countDown();
                }
                return;
            }

            pendingWrite = false;
            if (exception == null) {
                completionHandler.completed(data);
            } else {
                completionHandler.failed(exception);
            }
        }

        /** Returns everything written so far, decoded with the platform default charset. */
        String getWrittenData() {
            return new String(writtenData.toByteArray());
        }

        /** Starts parking write completions and counts down {@code blockLatch} once a write is parked. */
        void block(CountDownLatch blockLatch) {
            this.blockLatch = blockLatch;
            block = true;
        }

        /** Starts parking write completions. */
        void block() {
            block = true;
        }

        /**
         * Stops parking and completes the parked write with a {@code null} result.
         *
         * @throws IllegalStateException if no write is currently parked
         */
        void unblock() {
            block = false;
            pendingWrite = false;
            // Clear the field before the callback, which may synchronously issue further writes.
            CompletionHandler<ByteBuffer> handler = completionHandler;
            completionHandler = null;
            if (handler == null) {
                throw new IllegalStateException("unblock() called but no write is parked");
            }
            handler.completed(null);
        }

        /** Makes subsequent non-blocked writes fail with {@code exception}. */
        public void setException(Throwable exception) {
            this.exception = exception;
        }
    }
}