/* * Copyright (C) 2014 Square, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package okhttp3.internal.ws; import java.io.EOFException; import java.io.IOException; import java.util.Random; import okhttp3.RequestBody; import okhttp3.internal.Util; import okio.Buffer; import okio.BufferedSink; import okio.ByteString; import okio.Okio; import okio.Sink; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; import static okhttp3.TestUtil.repeat; import static okhttp3.internal.ws.WebSocketProtocol.OPCODE_BINARY; import static okhttp3.internal.ws.WebSocketProtocol.OPCODE_TEXT; import static okhttp3.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX; import static okhttp3.internal.ws.WebSocketProtocol.PAYLOAD_SHORT_MAX; import static okhttp3.internal.ws.WebSocketProtocol.toggleMask; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public final class WebSocketWriterTest { private final Buffer data = new Buffer(); private final Random random = new Random(0); /** * Check all data as verified inside of the test. We do this in a rule instead of @After so that * exceptions thrown from the test do not cause this check to fail. */ @Rule public final TestRule noDataLeftBehind = new TestRule() { @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { base.evaluate(); assertEquals("Data not empty", "", data.readByteString().hex()); } }; } }; // Mutually exclusive. Use the one corresponding to the peer whose behavior you wish to test. private final WebSocketWriter serverWriter = new WebSocketWriter(false, data, random); private final WebSocketWriter clientWriter = new WebSocketWriter(true, data, random); @Test public void serverTextMessage() throws IOException { BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT, -1)); sink.writeUtf8("Hel").flush(); assertData("010348656c"); sink.writeUtf8("lo").flush(); assertData("00026c6f"); sink.close(); assertData("8000"); } @Test public void serverSmallBufferedPayloadWrittenAsOneFrame() throws IOException { int length = 5; byte[] bytes = binaryData(length); RequestBody body = RequestBody.create(null, bytes); BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT, length)); body.writeTo(sink); sink.close(); assertData("8105"); assertData(bytes); assertTrue(data.exhausted()); } @Test public void serverLargeBufferedPayloadWrittenAsOneFrame() throws IOException { int length = 12345; byte[] bytes = binaryData(length); RequestBody body = RequestBody.create(null, bytes); BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT, length)); body.writeTo(sink); sink.close(); assertData("817e"); assertData(Util.format("%04x", length)); assertData(bytes); assertTrue(data.exhausted()); } @Test public void serverLargeNonBufferedPayloadWrittenAsMultipleFrames() throws IOException { int length = 100_000; Buffer bytes = new Buffer().write(binaryData(length)); BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT, length)); Buffer body = bytes.clone(); sink.write(body.readByteString(20_000)); sink.write(body.readByteString(20_000)); sink.write(body.readByteString(20_000)); sink.write(body.readByteString(20_000)); sink.write(body.readByteString(20_000)); sink.close(); assertData("017e4000"); assertData(bytes.readByteArray(16_384)); assertData("007e4000"); assertData(bytes.readByteArray(16_384)); assertData("007e6000"); assertData(bytes.readByteArray(24_576)); assertData("007e4000"); assertData(bytes.readByteArray(16_384)); assertData("007e6000"); assertData(bytes.readByteArray(24_576)); assertData("807e06a0"); assertData(bytes.readByteArray(1_696)); assertTrue(data.exhausted()); } @Test public void closeFlushes() throws IOException { BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT, -1)); sink.writeUtf8("Hel").flush(); assertData("010348656c"); sink.writeUtf8("lo").close(); assertData("80026c6f"); } @Test public void noWritesAfterClose() throws IOException { Sink sink = serverWriter.newMessageSink(OPCODE_TEXT, -1); sink.close(); assertData("8100"); Buffer payload = new Buffer().writeUtf8("Hello"); try { // Write to the unbuffered sink as BufferedSink keeps its own closed state. sink.write(payload, payload.size()); fail(); } catch (IOException e) { assertEquals("closed", e.getMessage()); } } @Test public void clientTextMessage() throws IOException { BufferedSink sink = Okio.buffer(clientWriter.newMessageSink(OPCODE_TEXT, -1)); sink.writeUtf8("Hel").flush(); assertData("018360b420bb28d14c"); sink.writeUtf8("lo").flush(); assertData("00823851d9d4543e"); sink.close(); assertData("80807acb933d"); } @Test public void serverBinaryMessage() throws IOException { BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_BINARY, -1)); sink.write(binaryData(50)).flush(); assertData("0232"); assertData(binaryData(50)); sink.write(binaryData(50)).flush(); assertData("0032"); assertData(binaryData(50)); sink.close(); assertData("8000"); } @Test public void serverMessageLengthShort() throws IOException { Sink sink = serverWriter.newMessageSink(OPCODE_BINARY, -1); // Create a payload which will overflow the normal payload byte size. Buffer payload = new Buffer(); while (payload.completeSegmentByteCount() <= PAYLOAD_BYTE_MAX) { payload.writeByte('0'); } long byteCount = payload.completeSegmentByteCount(); // Write directly to the unbuffered sink. This ensures it will become single frame. sink.write(payload.clone(), byteCount); assertData("027e"); // 'e' == 4-byte follow-up length. assertData(Util.format("%04X", payload.completeSegmentByteCount())); assertData(payload.readByteArray()); sink.close(); assertData("8000"); } @Test public void serverMessageLengthLong() throws IOException { Sink sink = serverWriter.newMessageSink(OPCODE_BINARY, -1); // Create a payload which will overflow the normal and short payload byte size. Buffer payload = new Buffer(); while (payload.completeSegmentByteCount() <= PAYLOAD_SHORT_MAX) { payload.writeByte('0'); } long byteCount = payload.completeSegmentByteCount(); // Write directly to the unbuffered sink. This ensures it will become single frame. sink.write(payload.clone(), byteCount); assertData("027f"); // 'f' == 16-byte follow-up length. assertData(Util.format("%016X", byteCount)); assertData(payload.readByteArray(byteCount)); sink.close(); assertData("8000"); } @Test public void clientBinary() throws IOException { byte[] maskKey1 = new byte[4]; random.nextBytes(maskKey1); byte[] maskKey2 = new byte[4]; random.nextBytes(maskKey2); random.setSeed(0); // Reset the seed so real data matches. BufferedSink sink = Okio.buffer(clientWriter.newMessageSink(OPCODE_BINARY, -1)); byte[] part1 = binaryData(50); sink.write(part1).flush(); toggleMask(part1, 50, maskKey1, 0); assertData("02b2"); assertData(maskKey1); assertData(part1); byte[] part2 = binaryData(50); sink.write(part2).close(); toggleMask(part2, 50, maskKey2, 0); assertData("80b2"); assertData(maskKey2); assertData(part2); } @Test public void serverEmptyClose() throws IOException { serverWriter.writeClose(0, null); assertData("8800"); } @Test public void serverCloseWithCode() throws IOException { serverWriter.writeClose(1001, null); assertData("880203e9"); } @Test public void serverCloseWithCodeAndReason() throws IOException { serverWriter.writeClose(1001, ByteString.encodeUtf8("Hello")); assertData("880703e948656c6c6f"); } @Test public void clientEmptyClose() throws IOException { clientWriter.writeClose(0, null); assertData("888060b420bb"); } @Test public void clientCloseWithCode() throws IOException { clientWriter.writeClose(1001, null); assertData("888260b420bb635d"); } @Test public void clientCloseWithCodeAndReason() throws IOException { clientWriter.writeClose(1001, ByteString.encodeUtf8("Hello")); assertData("888760b420bb635d68de0cd84f"); } @Test public void closeWithOnlyReasonThrows() throws IOException { clientWriter.writeClose(0, ByteString.encodeUtf8("Hello")); assertData("888760b420bb60b468de0cd84f"); } @Test public void closeCodeOutOfRangeThrows() throws IOException { try { clientWriter.writeClose(98724976, ByteString.encodeUtf8("Hello")); fail(); } catch (IllegalArgumentException e) { assertEquals("Code must be in range [1000,5000): 98724976", e.getMessage()); } } @Test public void closeReservedThrows() throws IOException { try { clientWriter.writeClose(1005, ByteString.encodeUtf8("Hello")); fail(); } catch (IllegalArgumentException e) { assertEquals("Code 1005 is reserved and may not be used.", e.getMessage()); } } @Test public void serverEmptyPing() throws IOException { serverWriter.writePing(ByteString.EMPTY); assertData("8900"); } @Test public void clientEmptyPing() throws IOException { clientWriter.writePing(ByteString.EMPTY); assertData("898060b420bb"); } @Test public void serverPingWithPayload() throws IOException { serverWriter.writePing(ByteString.encodeUtf8("Hello")); assertData("890548656c6c6f"); } @Test public void clientPingWithPayload() throws IOException { clientWriter.writePing(ByteString.encodeUtf8("Hello")); assertData("898560b420bb28d14cd70f"); } @Test public void serverEmptyPong() throws IOException { serverWriter.writePong(ByteString.EMPTY); assertData("8a00"); } @Test public void clientEmptyPong() throws IOException { clientWriter.writePong(ByteString.EMPTY); assertData("8a8060b420bb"); } @Test public void serverPongWithPayload() throws IOException { serverWriter.writePong(ByteString.encodeUtf8("Hello")); assertData("8a0548656c6c6f"); } @Test public void clientPongWithPayload() throws IOException { clientWriter.writePong(ByteString.encodeUtf8("Hello")); assertData("8a8560b420bb28d14cd70f"); } @Test public void pingTooLongThrows() throws IOException { try { serverWriter.writePing(ByteString.of(binaryData(1000))); fail(); } catch (IllegalArgumentException e) { assertEquals("Payload size must be less than or equal to 125", e.getMessage()); } } @Test public void pongTooLongThrows() throws IOException { try { serverWriter.writePong(ByteString.of(binaryData(1000))); fail(); } catch (IllegalArgumentException e) { assertEquals("Payload size must be less than or equal to 125", e.getMessage()); } } @Test public void closeTooLongThrows() throws IOException { try { ByteString longReason = ByteString.encodeUtf8(repeat('X', 124)); serverWriter.writeClose(1000, longReason); fail(); } catch (IllegalArgumentException e) { assertEquals("Payload size must be less than or equal to 125", e.getMessage()); } } @Test public void twoMessageSinksThrows() { clientWriter.newMessageSink(OPCODE_TEXT, -1); try { clientWriter.newMessageSink(OPCODE_TEXT, -1); fail(); } catch (IllegalStateException e) { assertEquals("Another message writer is active. Did you call close()?", e.getMessage()); } } private void assertData(String hex) throws EOFException { ByteString expected = ByteString.decodeHex(hex); ByteString actual = data.readByteString(expected.size()); assertEquals(expected, actual); } private void assertData(byte[] data) throws IOException { int byteCount = 16; for (int i = 0; i < data.length; i += byteCount) { int count = Math.min(byteCount, data.length - i); Buffer expectedChunk = new Buffer(); expectedChunk.write(data, i, count); assertEquals("At " + i, expectedChunk.readByteString(), this.data.readByteString(count)); } } private static byte[] binaryData(int length) { byte[] junk = new byte[length]; new Random(0).nextBytes(junk); return junk; } }