/* * Copyright (c) 2011-2013 The original author or authors * ------------------------------------------------------ * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * and Apache License v2.0 which accompanies this distribution. * * The Eclipse Public License is available at * http://www.eclipse.org/legal/epl-v10.html * * The Apache License v2.0 is available at * http://www.opensource.org/licenses/apache2.0.php * * You may elect to redistribute this code under either of these licenses. */ package io.vertx.core.http.impl; import io.netty.buffer.ByteBuf; import io.vertx.core.Handler; import io.vertx.core.buffer.Buffer; import io.vertx.core.eventbus.Message; import io.vertx.core.eventbus.MessageConsumer; import io.vertx.core.http.WebSocketBase; import io.vertx.core.http.WebSocketFrame; import io.vertx.core.http.impl.ws.WebSocketFrameImpl; import io.vertx.core.http.impl.ws.WebSocketFrameInternal; import io.vertx.core.impl.VertxInternal; import io.vertx.core.net.SocketAddress; import io.vertx.core.net.impl.ConnectionBase; import javax.net.ssl.SSLPeerUnverifiedException; import javax.security.cert.X509Certificate; import java.util.UUID; /** * This class is optimised for performance when used on the same event loop. However it can be used safely from other threads. * <p> * The internal state is protected using the synchronized keyword. If always used on the same event loop, then * we benefit from biased locking which makes the overhead of synchronized near zero. * * @author <a href="http://tfox.org">Tim Fox</a> * @param <S> self return type */ public abstract class WebSocketImplBase<S extends WebSocketBase> implements WebSocketBase { private final boolean supportsContinuation; private final String textHandlerID; private final String binaryHandlerID; private final int maxWebSocketFrameSize; private final int maxWebSocketMessageSize; private final MessageConsumer binaryHandlerRegistration; private final MessageConsumer textHandlerRegistration; private Object metric; private Handler<WebSocketFrameInternal> frameHandler; private Handler<Buffer> dataHandler; private Handler<Void> drainHandler; private Handler<Throwable> exceptionHandler; private Handler<Void> closeHandler; private Handler<Void> endHandler; protected final ConnectionBase conn; protected boolean closed; WebSocketImplBase(VertxInternal vertx, ConnectionBase conn, boolean supportsContinuation, int maxWebSocketFrameSize, int maxWebSocketMessageSize) { this.supportsContinuation = supportsContinuation; this.textHandlerID = UUID.randomUUID().toString(); this.binaryHandlerID = UUID.randomUUID().toString(); this.conn = conn; Handler<Message<Buffer>> binaryHandler = msg -> writeBinaryFrameInternal(msg.body()); binaryHandlerRegistration = vertx.eventBus().<Buffer>localConsumer(binaryHandlerID).handler(binaryHandler); Handler<Message<String>> textHandler = msg -> writeTextFrameInternal(msg.body()); textHandlerRegistration = vertx.eventBus().<String>localConsumer(textHandlerID).handler(textHandler); this.maxWebSocketFrameSize = maxWebSocketFrameSize; this.maxWebSocketMessageSize = maxWebSocketMessageSize; } public String binaryHandlerID() { return binaryHandlerID; } public String textHandlerID() { return textHandlerID; } public boolean writeQueueFull() { synchronized (conn) { checkClosed(); return conn.isNotWritable(); } } public void close() { synchronized (conn) { checkClosed(); conn.close(); cleanupHandlers(); } } @Override public boolean isSsl() { return conn.isSsl(); } @Override public X509Certificate[] peerCertificateChain() throws SSLPeerUnverifiedException { return conn.peerCertificateChain(); } @Override public SocketAddress localAddress() { return conn.localAddress(); } @Override public SocketAddress remoteAddress() { return conn.remoteAddress(); } @Override public S writeFinalTextFrame(String text) { return (S) writeFrame(WebSocketFrame.textFrame(text, true)); } @Override public S writeFinalBinaryFrame(Buffer data) { return (S) writeFrame(WebSocketFrame.binaryFrame(data, true)); } @Override public S writeBinaryMessage(Buffer data) { synchronized (conn) { checkClosed(); writeMessageInternal(data); return (S) this; } } @Override public S writeTextMessage(String text) { synchronized (conn) { checkClosed(); writeTextMessageInternal(text); return (S) this; } } @Override public S write(Buffer data) { synchronized (conn) { checkClosed(); writeFrame(WebSocketFrame.binaryFrame(data, true)); return (S) this; } } private void writeMessageInternal(Buffer data) { checkClosed(); writePartialMessage(FrameType.BINARY, data, 0); } private void writeTextMessageInternal(String text) { checkClosed(); Buffer data = Buffer.buffer(text); writePartialMessage(FrameType.TEXT, data, 0); } /** * Splits the provided buffer into multiple frames (which do not exceed the maximum web socket frame size) * and writes them in order to the socket. */ private void writePartialMessage(FrameType frameType, Buffer data, int offset) { int end = offset + maxWebSocketFrameSize; boolean isFinal; if (end >= data.length()) { end = data.length(); isFinal = true; } else { isFinal = false; } Buffer slice = data.slice(offset, end); WebSocketFrame frame; if (offset == 0 || !supportsContinuation) { frame = new WebSocketFrameImpl(frameType, slice.getByteBuf(), isFinal); } else { frame = WebSocketFrame.continuationFrame(slice, isFinal); } writeFrame(frame); int newOffset = offset + maxWebSocketFrameSize; if (!isFinal) { writePartialMessage(frameType, data, newOffset); } } private void writeBinaryFrameInternal(Buffer data) { ByteBuf buf = data.getByteBuf(); WebSocketFrame frame = new WebSocketFrameImpl(FrameType.BINARY, buf); writeFrame(frame); } private void writeTextFrameInternal(String str) { WebSocketFrame frame = new WebSocketFrameImpl(str); writeFrame(frame); } @Override public S writeFrame(WebSocketFrame frame) { synchronized (conn) { checkClosed(); conn.reportBytesWritten(frame.binaryData().length()); conn.writeToChannel(frame); } return (S) this; } void checkClosed() { if (closed) { throw new IllegalStateException("WebSocket is closed"); } } void handleFrame(WebSocketFrameInternal frame) { synchronized (conn) { conn.reportBytesRead(frame.binaryData().length()); if (dataHandler != null) { Buffer buff = Buffer.buffer(frame.getBinaryData()); dataHandler.handle(buff); } if (frameHandler != null) { frameHandler.handle(frame); } } } private class FrameAggregator implements Handler<WebSocketFrameInternal> { private Handler<String> textMessageHandler; private Handler<Buffer> binaryMessageHandler; private Buffer textMessageBuffer; private Buffer binaryMessageBuffer; @Override public void handle(WebSocketFrameInternal frame) { switch (frame.type()) { case TEXT: handleTextFrame(frame); break; case BINARY: handleBinaryFrame(frame); break; case CONTINUATION: if (textMessageBuffer != null && textMessageBuffer.length() > 0) { handleTextFrame(frame); } else if (binaryMessageBuffer != null && binaryMessageBuffer.length() > 0) { handleBinaryFrame(frame); } break; } } private void handleTextFrame(WebSocketFrameInternal frame) { Buffer frameBuffer = Buffer.buffer(frame.getBinaryData()); if (textMessageBuffer == null) { textMessageBuffer = frameBuffer; } else { textMessageBuffer.appendBuffer(frameBuffer); } if (textMessageBuffer.length() > maxWebSocketMessageSize) { int len = textMessageBuffer.length() - frameBuffer.length(); textMessageBuffer = null; String msg = "Cannot process text frame of size " + frameBuffer.length() + ", it would cause message buffer (size " + len + ") to overflow max message size of " + maxWebSocketMessageSize; handleException(new IllegalStateException(msg)); return; } if (frame.isFinal()) { String fullMessage = textMessageBuffer.toString(); textMessageBuffer = null; if (textMessageHandler != null) { textMessageHandler.handle(fullMessage); } } } private void handleBinaryFrame(WebSocketFrameInternal frame) { Buffer frameBuffer = Buffer.buffer(frame.getBinaryData()); if (binaryMessageBuffer == null) { binaryMessageBuffer = frameBuffer; } else { binaryMessageBuffer.appendBuffer(frameBuffer); } if (binaryMessageBuffer.length() > maxWebSocketMessageSize) { int len = binaryMessageBuffer.length() - frameBuffer.length(); binaryMessageBuffer = null; String msg = "Cannot process binary frame of size " + frameBuffer.length() + ", it would cause message buffer (size " + len + ") to overflow max message size of " + maxWebSocketMessageSize; handleException(new IllegalStateException(msg)); return; } if (frame.isFinal()) { Buffer fullMessage = binaryMessageBuffer.copy(); binaryMessageBuffer = null; if (binaryMessageHandler != null) { binaryMessageHandler.handle(fullMessage); } } } } @Override public S frameHandler(Handler<WebSocketFrame> handler) { synchronized (conn) { checkClosed(); this.frameHandler = (Handler)handler; return (S) this; } } @Override public WebSocketBase textMessageHandler(Handler<String> handler) { synchronized (conn) { checkClosed(); if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) { frameHandler = new FrameAggregator(); } ((FrameAggregator) frameHandler).textMessageHandler = handler; return this; } } @Override public S binaryMessageHandler(Handler<Buffer> handler) { synchronized (conn) { checkClosed(); if (frameHandler == null || frameHandler.getClass() != FrameAggregator.class) { frameHandler = new FrameAggregator(); } ((FrameAggregator) frameHandler).binaryMessageHandler = handler; return (S) this; } } void writable() { if (drainHandler != null) { Handler<Void> dh = drainHandler; drainHandler = null; dh.handle(null); } } void handleException(Throwable t) { synchronized (conn) { if (exceptionHandler != null) { exceptionHandler.handle(t); } } } void handleClosed() { synchronized (conn) { cleanupHandlers(); if (endHandler != null) { conn.getContext().runOnContext(endHandler); } if (closeHandler != null) { conn.getContext().runOnContext(closeHandler); } } } private void cleanupHandlers() { if (!closed) { binaryHandlerRegistration.unregister(); textHandlerRegistration.unregister(); closed = true; } } synchronized void setMetric(Object metric) { this.metric = metric; } synchronized Object getMetric() { return metric; } @Override public S handler(Handler<Buffer> handler) { synchronized (conn) { if (handler != null) { checkClosed(); } this.dataHandler = handler; return (S) this; } } @Override public S endHandler(Handler<Void> handler) { synchronized (conn) { if (handler != null) { checkClosed(); } this.endHandler = handler; return (S) this; } } @Override public S exceptionHandler(Handler<Throwable> handler) { synchronized (conn) { if (handler != null) { checkClosed(); } this.exceptionHandler = handler; return (S) this; } } @Override public S closeHandler(Handler<Void> handler) { synchronized (conn) { checkClosed(); this.closeHandler = handler; return (S) this; } } @Override public S drainHandler(Handler<Void> handler) { synchronized (conn) { checkClosed(); this.drainHandler = handler; return (S) this; } } @Override public S pause() { synchronized (conn) { checkClosed(); conn.doPause(); return (S) this; } } @Override public S resume() { synchronized (conn) { checkClosed(); conn.doResume(); return (S) this; } } @Override public S setWriteQueueMaxSize(int maxSize) { synchronized (conn) { checkClosed(); conn.doSetWriteQueueMaxSize(maxSize); return (S) this; } } @Override public void end() { close(); } }