/* * JBoss, Home of Professional Open Source. * Copyright 2014 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.websockets.jsr.annotated; import io.undertow.UndertowLogger; import io.undertow.servlet.api.InstanceHandle; import io.undertow.websockets.core.WebSocketLogger; import io.undertow.websockets.jsr.UndertowSession; import javax.websocket.CloseReason; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.MessageHandler; import javax.websocket.SendHandler; import javax.websocket.SendResult; import javax.websocket.Session; import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; /** * @author Stuart Douglas */ public class AnnotatedEndpoint extends Endpoint { private final InstanceHandle<?> instance; private final BoundMethod webSocketOpen; private final BoundMethod webSocketClose; private final BoundMethod webSocketError; private final BoundMethod textMessage; private final BoundMethod binaryMessage; private final BoundMethod pongMessage; private volatile boolean released; AnnotatedEndpoint(final InstanceHandle<?> instance, final BoundMethod webSocketOpen, final BoundMethod webSocketClose, final BoundMethod webSocketError, final BoundMethod textMessage, final BoundMethod binaryMessage, final BoundMethod pongMessage) { this.instance = instance; this.webSocketOpen = webSocketOpen; this.webSocketClose = webSocketClose; this.webSocketError = webSocketError; this.textMessage = textMessage; this.binaryMessage = binaryMessage; this.pongMessage = pongMessage; } @Override public void onOpen(final Session session, final EndpointConfig endpointConfiguration) { this.released = false; final UndertowSession s = (UndertowSession) session; boolean partialText = textMessage == null || (textMessage.hasParameterType(boolean.class) && !textMessage.getMessageType().equals(boolean.class)); boolean partialBinary = binaryMessage == null || (binaryMessage.hasParameterType(boolean.class) && !binaryMessage.getMessageType().equals(boolean.class)); if(textMessage != null) { if(partialText) { addPartialHandler(s, textMessage); } else { if(textMessage.getMaxMessageSize() > 0) { s.setMaxTextMessageBufferSize((int) textMessage.getMaxMessageSize()); } addWholeHandler(s, textMessage); } } if(binaryMessage != null) { if(partialBinary) { addPartialHandler(s, binaryMessage); } else { if(binaryMessage.getMaxMessageSize() > 0) { s.setMaxBinaryMessageBufferSize((int) binaryMessage.getMaxMessageSize()); } addWholeHandler(s, binaryMessage); } } if(pongMessage != null) { addWholeHandler(s, pongMessage); } if (webSocketOpen != null) { final Map<Class<?>, Object> params = new HashMap<>(); params.put(Session.class, session); params.put(EndpointConfig.class, endpointConfiguration); params.put(Map.class, session.getPathParameters()); invokeMethod(params, webSocketOpen, s); } } private void addPartialHandler(final UndertowSession session, final BoundMethod method) { session.addMessageHandler((Class) method.getMessageType(), new MessageHandler.Partial<Object>() { @Override public void onMessage(Object partialMessage, boolean last) { final Map<Class<?>, Object> params = new HashMap<>(); params.put(Session.class, session); params.put(Map.class, session.getPathParameters()); params.put(method.getMessageType(), partialMessage); params.put(boolean.class, last); final Object result; try { result = method.invoke(instance.getInstance(), params); } catch (Throwable e) { AnnotatedEndpoint.this.onError(session, e); return; } sendResult(result, session); } }); } private void addWholeHandler(final UndertowSession session, final BoundMethod method) { session.addMessageHandler((Class) method.getMessageType(), new MessageHandler.Whole<Object>() { @Override public void onMessage(Object partialMessage) { final Map<Class<?>, Object> params = new HashMap<>(); params.put(Session.class, session); params.put(Map.class, session.getPathParameters()); params.put(method.getMessageType(), partialMessage); final Object result; try { result = method.invoke(instance.getInstance(), params); } catch (Exception e) { AnnotatedEndpoint.this.onError(session, e); return; } sendResult(result, session); } }); } private void invokeMethod(final Map<Class<?>, Object> params, final BoundMethod method, final UndertowSession session) { session.getContainer().invokeEndpointMethod(session.getExecutor(), new Runnable() { @Override public void run() { if(!released) { try { method.invoke(instance.getInstance(), params); } catch (Exception e) { onError(session, e); } } } }); } private void sendResult(final Object result, UndertowSession session) { if (result != null) { if (result instanceof String) { session.getAsyncRemote().sendText((String) result, new ErrorReportingSendHandler(session)); } else if (result instanceof byte[]) { session.getAsyncRemote().sendBinary(ByteBuffer.wrap((byte[]) result), new ErrorReportingSendHandler(session)); } else if (result instanceof ByteBuffer) { session.getAsyncRemote().sendBinary((ByteBuffer) result, new ErrorReportingSendHandler(session)); } else { session.getAsyncRemote().sendObject(result, new ErrorReportingSendHandler(session)); } if(session.getAsyncRemote().getBatchingAllowed()) { try { session.getAsyncRemote().flushBatch(); } catch (IOException e) { onError(session, e); } } } } @Override public void onClose(final Session session, final CloseReason closeReason) { if (webSocketClose != null) { final Map<Class<?>, Object> params = new HashMap<>(); params.put(Session.class, session); params.put(Map.class, session.getPathParameters()); params.put(CloseReason.class, closeReason); ((UndertowSession) session).getContainer().invokeEndpointMethod(((UndertowSession)session).getExecutor(), new Runnable() { @Override public void run() { if(!released) { try { webSocketClose.invoke(instance.getInstance(), params); } catch (Exception e) { onError(session, e); } finally { released = true; instance.release(); } } } } ); } } @Override public void onError(final Session session, final Throwable thr) { if (webSocketError != null) { final Map<Class<?>, Object> params = new HashMap<>(); params.put(Session.class, session); params.put(Throwable.class, thr); params.put(Map.class, session.getPathParameters()); ((UndertowSession) session).getContainer().invokeEndpointMethod(((UndertowSession)session).getExecutor(), new Runnable() { @Override public void run() { if(!released) { try { webSocketError.invoke(instance.getInstance(), params); } catch (Exception e) { if (e instanceof RuntimeException) { throw (RuntimeException) e; } throw new RuntimeException(e); //not much we can do here } } } }); } else if (thr instanceof IOException) { UndertowLogger.REQUEST_IO_LOGGER.ioException((IOException) thr); } else { WebSocketLogger.REQUEST_LOGGER.unhandledErrorInAnnotatedEndpoint(instance.getInstance(), thr); } } private final class ErrorReportingSendHandler implements SendHandler { private final Session session; private ErrorReportingSendHandler(Session session) { this.session = session; } @Override public void onResult(final SendResult result) { if (!result.isOK()) { AnnotatedEndpoint.this.onError(session, result.getException()); } } } }