/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.jooby.internal.netty; import static io.netty.channel.ChannelFutureListener.CLOSE; import static java.util.Objects.requireNonNull; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.function.BiConsumer; import java.util.function.Consumer; import org.jooby.WebSocket; import org.jooby.WebSocket.OnError; import org.jooby.WebSocket.SuccessCallback; import org.jooby.spi.NativeWebSocket; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.util.Attribute; import io.netty.util.AttributeKey; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; public class NettyWebSocket implements NativeWebSocket { public static final AttributeKey<NettyWebSocket> KEY = AttributeKey .newInstance(NettyWebSocket.class.getName()); /** The logging system. */ private final Logger log = LoggerFactory.getLogger(getClass()); private ChannelHandlerContext ctx; private Consumer<NettyWebSocket> handshake; private Runnable onConnectCallback; private WebSocketServerHandshaker handshaker; private Consumer<String> onTextCallback; private Consumer<ByteBuffer> onBinaryCallback; private BiConsumer<Integer, Optional<String>> onCloseCallback; private Consumer<Throwable> onErrorCallback; private final CountDownLatch ready = new CountDownLatch(1); public NettyWebSocket(final ChannelHandlerContext ctx, final WebSocketServerHandshaker handshaker, final Consumer<NettyWebSocket> handshake) { this.ctx = ctx; this.handshaker = handshaker; this.handshake = handshake; } @Override public void close(final int status, final String reason) { handshaker.close(ctx.channel(), new CloseWebSocketFrame(status, reason)) .addListener(CLOSE); Attribute<NettyWebSocket> ws = ctx.channel().attr(KEY); if (ws != null) { ws.set(null); } } @Override public void resume() { ChannelConfig config = ctx.channel().config(); if (!config.isAutoRead()) { config.setAutoRead(true); } } @Override public void onConnect(final Runnable callback) { this.onConnectCallback = requireNonNull(callback, "A callback is required."); } @Override public void onTextMessage(final Consumer<String> callback) { this.onTextCallback = requireNonNull(callback, "A callback is required."); } @Override public void onBinaryMessage(final Consumer<ByteBuffer> callback) { this.onBinaryCallback = requireNonNull(callback, "A callback is required."); } @Override public void onCloseMessage(final BiConsumer<Integer, Optional<String>> callback) { this.onCloseCallback = requireNonNull(callback, "A callback is required."); } @Override public void onErrorMessage(final Consumer<Throwable> callback) { this.onErrorCallback = requireNonNull(callback, "A callback is required."); } @Override public void pause() { ChannelConfig config = ctx.channel().config(); if (config.isAutoRead()) { config.setAutoRead(false); } } @Override public void terminate() throws IOException { this.onCloseCallback.accept(1006, Optional.of("Harsh disconnect")); ctx.disconnect().addListener(CLOSE); } @Override public void sendBytes(final ByteBuffer data, final SuccessCallback success, final OnError err) { sendBytes(Unpooled.wrappedBuffer(data), success, err); } @Override public void sendBytes(final byte[] data, final SuccessCallback success, final OnError err) { sendBytes(Unpooled.wrappedBuffer(data), success, err); } @Override public void sendText(final String data, final SuccessCallback success, final OnError err) { ctx.channel().writeAndFlush(new TextWebSocketFrame(data)) .addListener(listener(success, err)); } @Override public void sendText(final ByteBuffer data, final SuccessCallback success, final OnError err) { ByteBuf buffer = Unpooled.wrappedBuffer(data); ctx.channel().writeAndFlush(new TextWebSocketFrame(buffer)) .addListener(listener(success, err)); } @Override public void sendText(final byte[] data, final SuccessCallback success, final OnError err) { ByteBuf buffer = Unpooled.wrappedBuffer(data); ctx.channel().writeAndFlush(new TextWebSocketFrame(buffer)) .addListener(listener(success, err)); } @Override public boolean isOpen() { return ctx.channel().isOpen(); } public void connect() { onConnectCallback.run(); ready.countDown(); } public void hankshake() { handshake.accept(this); } public void handle(final Object msg) { ready(); if (msg instanceof TextWebSocketFrame) { onTextCallback.accept(((TextWebSocketFrame) msg).text()); } else if (msg instanceof BinaryWebSocketFrame) { onBinaryCallback.accept(((BinaryWebSocketFrame) msg).content().nioBuffer()); } else if (msg instanceof CloseWebSocketFrame) { CloseWebSocketFrame closeFrame = ((CloseWebSocketFrame) msg).retain(); int statusCode = closeFrame.statusCode(); onCloseCallback.accept(statusCode == -1 ? WebSocket.NORMAL.code() : statusCode, Optional.ofNullable(closeFrame.reasonText())); handshaker.close(ctx.channel(), closeFrame).addListener(CLOSE); } else if (msg instanceof Throwable) { onErrorCallback.accept((Throwable) msg); } } /** * Make sure hankshake/connect is set. */ private void ready() { try { ready.await(); } catch (InterruptedException ex) { log.error("Connect call was inturrupted", ex); Thread.currentThread().interrupt(); } } private void sendBytes(final ByteBuf buffer, final SuccessCallback success, final OnError err) { ctx.channel().writeAndFlush(new BinaryWebSocketFrame(buffer)) .addListener(listener(success, err)); } private GenericFutureListener<? extends Future<? super Void>> listener( final SuccessCallback success, final OnError err) { return f -> { if (f.isSuccess()) { success.invoke(); } else { err.onError(f.cause()); } }; } }