/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.jooby.internal;
import static java.util.Objects.requireNonNull;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.jooby.Err;
import org.jooby.MediaType;
import org.jooby.Mutant;
import org.jooby.Renderer;
import org.jooby.Request;
import org.jooby.WebSocket;
import org.jooby.internal.parser.ParserExecutor;
import org.jooby.spi.NativeWebSocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.ImmutableList;
import com.google.inject.Injector;
import com.google.inject.Key;
import javaslang.control.Try;
import javaslang.control.Try.CheckedRunnable;
/**
 * {@link WebSocket} implementation that adapts a server-provided {@link NativeWebSocket} to the
 * application-level callbacks ({@link OnOpen}, {@link OnMessage}, {@link OnClose} and
 * {@link OnError}).
 *
 * <p>Every connected instance is tracked in a registry shared by all instances (a static queue),
 * which {@link #broadcast(Object, SuccessCallback, OnError)} iterates. Callbacks fired by the
 * native socket run while holding this instance's monitor, the same monitor that guards
 * {@link #send(Object, SuccessCallback, OnError)}.
 */
@SuppressWarnings("unchecked")
public class WebSocketImpl implements WebSocket {

  /** Message callback that ignores its argument; default, and re-installed after an error. */
  @SuppressWarnings({"rawtypes" })
  private static final OnMessage NOOP = arg -> {
  };

  /** Close callback that ignores its argument; default until the application sets one. */
  private static final OnClose CLOSE_NOOP = arg -> {
  };

  /** The logging system. */
  private static final Logger log = LoggerFactory.getLogger(WebSocket.class);

  /** All connected websockets, shared by every instance; iterated by broadcast(). */
  private static final Queue<WebSocket> sessions = new ConcurrentLinkedQueue<>();

  private Locale locale;

  private String path;

  private String pattern;

  private Map<Object, String> vars;

  private MediaType consumes;

  private MediaType produces;

  /** Invoked once from connect(); released on error cleanup. */
  private OnOpen handler;

  private OnMessage<Mutant> messageCallback = NOOP;

  /** Set to null once it has fired (or after an error) so it runs at most once. */
  private OnClose closeCallback = CLOSE_NOOP;

  private OnError exceptionCallback = cause -> {
    log.error("execution of WS" + path() + " resulted in exception", cause);
  };

  /** Native socket; null before connect() and after an error cleanup. */
  private NativeWebSocket ws;

  private Injector injector;

  private boolean suspended;

  private List<Renderer> renderers;

  private volatile boolean open;

  public WebSocketImpl(final OnOpen handler, final String path,
      final String pattern, final Map<Object, String> vars,
      final MediaType consumes, final MediaType produces) {
    this.handler = handler;
    this.path = path;
    this.pattern = pattern;
    this.vars = vars;
    this.consumes = consumes;
    this.produces = produces;
  }

  /**
   * Unregisters this socket, marks it closed and asks the native socket to close with the given
   * status. Does nothing natively when there is no native socket (not connected yet, or already
   * cleaned up after an error).
   *
   * @param status Close code and reason sent to the peer.
   */
  @Override
  public void close(final CloseStatus status) {
    sessions.remove(this);
    synchronized (this) {
      open = false;
      if (ws != null) {
        ws.close(status.code(), status.reason());
      }
    }
  }

  /**
   * Resumes a previously paused socket and re-registers it for broadcasts. Calling it on a socket
   * that is not paused has no effect.
   */
  @Override
  public void resume() {
    synchronized (this) {
      if (suspended) {
        if (ws != null) {
          ws.resume();
        }
        suspended = false;
        // Re-register only on a real paused -> resumed transition: the queue allows duplicates,
        // so unconditional adds would make broadcast() deliver the same message several times.
        // A socket that was closed while paused must not rejoin the registry either.
        if (open) {
          sessions.add(this);
        }
      }
    }
  }

  /** Pauses the native socket (if not already paused) and excludes it from broadcasts. */
  @Override
  public void pause() {
    sessions.remove(this);
    synchronized (this) {
      if (!suspended) {
        if (ws != null) {
          ws.pause();
        }
        suspended = true;
      }
    }
  }

  /**
   * Unregisters this socket, marks it closed and terminates the native socket, if any.
   *
   * @throws Exception Whatever the native terminate throws.
   */
  @Override
  public void terminate() throws Exception {
    sessions.remove(this);
    synchronized (this) {
      open = false;
      if (ws != null) {
        ws.terminate();
      }
    }
  }

  /**
   * @return True if this socket has not been closed locally and the native socket reports itself
   *     open.
   */
  @Override
  public boolean isOpen() {
    // Read the field once: an error cleanup may null it concurrently.
    NativeWebSocket lws = ws;
    return open && lws != null && lws.isOpen();
  }

  /**
   * Sends {@code data} to every registered socket. A failure while sending to one socket is
   * reported to {@code err} and does not stop delivery to the others.
   *
   * @param data Message to send.
   * @param success Callback passed to each individual send.
   * @param err Receives per-socket send failures.
   * @throws Exception Whatever {@code err} throws.
   */
  @Override
  public void broadcast(final Object data, final SuccessCallback success, final OnError err)
      throws Exception {
    for (WebSocket session : sessions) {
      try {
        session.send(data, success, err);
      } catch (Exception ex) {
        err.onError(ex);
      }
    }
  }

  /**
   * Renders {@code data} with the injector-provided renderers and writes it to the native socket.
   *
   * @param data Message to send.
   * @param success Invoked by the render context on success.
   * @param err Invoked by the render context on failure.
   * @throws Err With {@link WebSocket#NORMAL} status when the socket is not open.
   * @throws Exception Rendering failures.
   */
  @Override
  public void send(final Object data, final SuccessCallback success, final OnError err)
      throws Exception {
    requireNonNull(data, "Message required.");
    requireNonNull(success, "Success callback required.");
    requireNonNull(err, "Error callback required.");

    synchronized (this) {
      if (isOpen()) {
        new WebSocketRendererContext(
            renderers,
            ws,
            produces,
            StandardCharsets.UTF_8,
            locale,
            success,
            err).render(data);
      } else {
        throw new Err(WebSocket.NORMAL, "WebSocket is closed.");
      }
    }
  }

  @Override
  public void onMessage(final OnMessage<Mutant> callback) throws Exception {
    this.messageCallback = requireNonNull(callback, "Message callback required.");
  }

  /**
   * Binds this instance to a native socket: wires the native message/close/error events to the
   * application callbacks, registers the socket for broadcasts and invokes the open handler.
   * A failure in the open handler is routed through the error handling (which closes the socket).
   *
   * @param injector Used to look up renderers, the parser executor and {@link #require(Key)}.
   * @param req Upgrade request; its locale is used for rendering.
   * @param ws Native socket.
   */
  public void connect(final Injector injector, final Request req, final NativeWebSocket ws) {
    this.open = true;
    this.injector = requireNonNull(injector, "Injector required.");
    this.ws = requireNonNull(ws, "WebSocket is required.");
    this.locale = req.locale();

    renderers = ImmutableList.copyOf(injector.getInstance(Renderer.KEY));

    // Bind callbacks. Each one runs under this instance's monitor (see sync()).
    ws.onBinaryMessage(buffer -> Try
        .run(sync(() -> messageCallback.onMessage(new WsBinaryMessage(buffer))))
        .onFailure(this::handleErr));

    ws.onTextMessage(message -> Try
        .run(sync(() -> messageCallback.onMessage(
            new MutantImpl(injector.getInstance(ParserExecutor.class), consumes,
                new StrParamReferenceImpl("body", "message", ImmutableList.of(message))))))
        .onFailure(this::handleErr));

    ws.onCloseMessage((code, reason) -> {
      sessions.remove(this);

      Try.run(sync(() -> {
        this.open = false;
        // The callback is cleared after it runs, so it fires at most once.
        if (closeCallback != null) {
          closeCallback.onClose(reason.map(r -> WebSocket.CloseStatus.of(code, r))
              .orElse(WebSocket.CloseStatus.of(code)));
        }
        closeCallback = null;
      })).onFailure(this::handleErr);
    });

    ws.onErrorMessage(this::handleErr);

    // connect now
    try {
      sessions.add(this);
      handler.onOpen(req, this);
    } catch (Throwable ex) {
      handleErr(ex);
    }
  }

  @Override
  public String path() {
    return path;
  }

  @Override
  public String pattern() {
    return pattern;
  }

  @Override
  public Map<Object, String> vars() {
    return vars;
  }

  @Override
  public MediaType consumes() {
    return consumes;
  }

  @Override
  public MediaType produces() {
    return produces;
  }

  @Override
  public <T> T require(final Key<T> key) {
    return injector.getInstance(key);
  }

  @Override
  public String toString() {
    StringBuilder buffer = new StringBuilder();
    buffer.append("WS ").append(path()).append("\n");
    buffer.append("  pattern: ").append(pattern()).append("\n");
    buffer.append("  vars: ").append(vars()).append("\n");
    buffer.append("  consumes: ").append(consumes()).append("\n");
    buffer.append("  produces: ").append(produces()).append("\n");
    return buffer.toString();
  }

  @Override
  public void onError(final WebSocket.OnError callback) {
    this.exceptionCallback = requireNonNull(callback, "A callback is required.");
  }

  @Override
  public void onClose(final WebSocket.OnClose callback) throws Exception {
    this.closeCallback = requireNonNull(callback, "A callback is required.");
  }

  /**
   * Reports {@code cause} and then tears the socket down (see cleanup()).
   *
   * <p>Connection-reset and closed-channel errors are only logged at debug level. The same applies
   * once the socket has already been cleaned up, because the error callback has been released by
   * then.
   */
  private void handleErr(final Throwable cause) {
    try {
      boolean silent = ConnectionResetByPeer.test(cause) || cause instanceof ClosedChannelException;
      OnError callback = exceptionCallback;
      if (silent || callback == null) {
        log.debug("execution of WS" + path() + " resulted in exception", cause);
      } else {
        callback.onError(cause);
      }
    } finally {
      cleanup(cause);
    }
  }

  /**
   * Unregisters the socket, releases its references and closes the native socket (if still open).
   * The close status is chosen from {@code cause}: bad-data failures (illegal argument, missing
   * element, or an {@link Err} with status 400) map to {@link WebSocket#BAD_DATA}; anything else
   * maps to {@link WebSocket#SERVER_ERROR}.
   */
  private void cleanup(final Throwable cause) {
    // Without this, a failed socket would stay in the static registry forever and keep
    // receiving (and failing) broadcasts.
    sessions.remove(this);

    NativeWebSocket lws;
    // Same monitor as send(): prevents send() from observing isOpen() == true and then
    // rendering to a socket reference that was nulled concurrently.
    synchronized (this) {
      open = false;
      lws = ws;
      this.ws = null;
      this.injector = null;
      this.handler = null;
      this.closeCallback = null;
      this.exceptionCallback = null;
      // NOOP instead of null: native messages may still arrive after cleanup, and dispatching to
      // a null callback would raise a NullPointerException on the native thread.
      this.messageCallback = NOOP;
    }

    if (lws != null && lws.isOpen()) {
      boolean badData = cause instanceof IllegalArgumentException
          || cause instanceof NoSuchElementException
          || (cause instanceof Err && ((Err) cause).statusCode() == 400);
      WebSocket.CloseStatus closeStatus = badData ? WebSocket.BAD_DATA : WebSocket.SERVER_ERROR;
      lws.close(closeStatus.code(), closeStatus.reason());
    }
  }

  /** Wraps {@code task} so it runs while holding this instance's monitor. */
  private CheckedRunnable sync(final CheckedRunnable task) {
    return () -> {
      synchronized (this) {
        task.run();
      }
    };
  }
}