package ameba.websocket.internal; import ameba.i18n.Messages; import ameba.util.ClassUtils; import ameba.websocket.CloseReasons; import com.google.common.collect.Sets; import com.google.common.primitives.Primitives; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.websocket.*; import java.lang.invoke.MethodHandle; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.Set; /** * <p>Abstract EndpointMeta class.</p> * * @author icode * */ public abstract class EndpointMeta { private static final Logger logger = LoggerFactory.getLogger(EndpointMeta.class); protected final Set<MessageHandlerFactory> messageHandlerFactories = Sets.newLinkedHashSet(); private Class endpointClass; /** * <p>Constructor for EndpointMeta.</p> * * @param endpointClass a {@link java.lang.Class} object. */ public EndpointMeta(Class endpointClass) { this.endpointClass = endpointClass; } static Class<?> getHandlerType(MessageHandler handler) { if (handler instanceof TypeMessageHandler) { return ((TypeMessageHandler) handler).getType(); } Class<?> result = ClassUtils.getGenericClass(handler.getClass()); return result == null ? Object.class : result; } /** * <p>checkMessageSize.</p> * * @param message a {@link java.lang.Object} object. * @param maxMessageSize a long. */ protected static void checkMessageSize(Object message, long maxMessageSize) { if (maxMessageSize != -1) { final long messageSize = (message instanceof String ? ((String) message).getBytes(Charset.defaultCharset()).length : ((ByteBuffer) message).remaining()); if (messageSize > maxMessageSize) { throw new MessageTooBigException( Messages.get("web.socket.error.message.too.long", maxMessageSize, messageSize) ); } } } /** * <p>Getter for the field <code>endpointClass</code>.</p> * * @return a {@link java.lang.Class} object. */ public Class getEndpointClass() { return endpointClass; } /** * <p>getEndpoint.</p> * * @return a {@link java.lang.Object} object. */ public abstract Object getEndpoint(); /** * <p>getOnCloseHandle.</p> * * @return a {@link java.lang.invoke.MethodHandle} object. */ public abstract MethodHandle getOnCloseHandle(); /** * <p>getOnErrorHandle.</p> * * @return a {@link java.lang.invoke.MethodHandle} object. */ public abstract MethodHandle getOnErrorHandle(); /** * <p>getOnOpenHandle.</p> * * @return a {@link java.lang.invoke.MethodHandle} object. */ public abstract MethodHandle getOnOpenHandle(); /** * <p>getOnOpenParameters.</p> * * @return an array of {@link ameba.websocket.internal.EndpointMeta.ParameterExtractor} objects. */ public abstract ParameterExtractor[] getOnOpenParameters(); /** * <p>getOnCloseParameters.</p> * * @return an array of {@link ameba.websocket.internal.EndpointMeta.ParameterExtractor} objects. */ public abstract ParameterExtractor[] getOnCloseParameters(); /** * <p>getOnErrorParameters.</p> * * @return an array of {@link ameba.websocket.internal.EndpointMeta.ParameterExtractor} objects. */ public abstract ParameterExtractor[] getOnErrorParameters(); /** * <p>callMethod.</p> * * @param method a {@link java.lang.invoke.MethodHandle} object. * @param extractors an array of {@link ameba.websocket.internal.EndpointMeta.ParameterExtractor} objects. * @param session a {@link javax.websocket.Session} object. * @param callOnError a boolean. * @param params a {@link java.lang.Object} object. * @return a {@link java.lang.Object} object. */ protected Object callMethod(MethodHandle method, ParameterExtractor[] extractors, Session session, boolean callOnError, Object... params) { Object[] paramValues = new Object[extractors.length + 1]; try { // TYRUS-325: Server do not close session properly if non-instantiable endpoint class is provided if (callOnError && getEndpoint() == null) { try { session.close(CloseReasons.UNEXPECTED_CONDITION.getCloseReason()); } catch (Exception e) { logger.error(e.getMessage(), e); } return null; } paramValues[0] = getEndpoint(); for (int i = 0; i < extractors.length; i++) { paramValues[i + 1] = extractors[i].value(session, params); } return method.invokeWithArguments(paramValues); } catch (Throwable e) { if (callOnError) { onError(session, (e instanceof InvocationTargetException ? e.getCause() : e)); } else { logger.error(Messages.get("web.socket.error.endpoint"), e); } } return null; } /** * <p>onOpen.</p> * * @param session a {@link javax.websocket.Session} object. * @param configuration a {@link javax.websocket.EndpointConfig} object. */ @SuppressWarnings("unchecked") public void onOpen(Session session, EndpointConfig configuration) { for (MessageHandlerFactory f : messageHandlerFactories) { MessageHandler handler = f.create(session); final Class<?> handlerClass = getHandlerType(handler); if (handler instanceof MessageHandler.Whole) { //WHOLE MESSAGE HANDLER session.addMessageHandler(handlerClass, (MessageHandler.Whole) handler); } else if (handler instanceof MessageHandler.Partial) { // PARTIAL MESSAGE HANDLER session.addMessageHandler(handlerClass, (MessageHandler.Partial) handler); } } if (getOnOpenHandle() != null) { callMethod(getOnOpenHandle(), getOnOpenParameters(), session, true); } } /** * <p>onClose.</p> * * @param session a {@link javax.websocket.Session} object. * @param closeReason a {@link javax.websocket.CloseReason} object. */ public void onClose(Session session, CloseReason closeReason) { if (getOnCloseHandle() != null) { callMethod(getOnCloseHandle(), getOnCloseParameters(), session, true, closeReason); } } /** * <p>onError.</p> * * @param session a {@link javax.websocket.Session} object. * @param thr a {@link java.lang.Throwable} object. */ public void onError(Session session, Throwable thr) { if (getOnErrorHandle() != null) { callMethod(getOnErrorHandle(), getOnErrorParameters(), session, false, thr); } else { logger.error(Messages.get("web.socket.error"), thr); } } private Class getMessageType(Class type) { return type == ameba.websocket.PongMessage.class ? PongMessage.class : type; } protected interface ParameterExtractor { Object value(Session session, Object... paramValues) throws DecodeException; } protected static class ParamValue implements ParameterExtractor { private final int index; public ParamValue(int index) { this.index = index; } @Override public Object value(Session session, Object... paramValues) { return paramValues[index]; } } protected abstract class MessageHandlerFactory { final MethodHandle method; final ParameterExtractor[] extractors; final Class<?> type; final long maxMessageSize; MessageHandlerFactory(MethodHandle method, ParameterExtractor[] extractors, Class<?> type, long maxMessageSize) { this.method = method; this.extractors = extractors; this.type = Primitives.isWrapperType(type) ? type : Primitives.wrap(type); this.maxMessageSize = maxMessageSize; } abstract MessageHandler create(Session session); protected void sendObject(final Session session, Object msg) { session.getAsyncRemote().sendObject(msg, result -> { Throwable e = result.getException(); if (e != null) { onError(session, e); } }); } } protected class WholeHandler extends MessageHandlerFactory { public WholeHandler(MethodHandle method, ParameterExtractor[] extractors, Class<?> type, long maxMessageSize) { super(method, extractors, type, maxMessageSize); } @Override public MessageHandler create(final Session session) { return new BasicMessageHandler() { @Override public void onMessage(Object message) { checkMessageSize(message, getMaxMessageSize()); Object result = callMethod(method, extractors, session, true, message); if (result != null) { sendObject(session, result); } } @Override public Class<?> getType() { return getMessageType(type); } @Override public long getMaxMessageSize() { return maxMessageSize; } }; } } protected class PartialHandler extends MessageHandlerFactory { public PartialHandler(MethodHandle method, ParameterExtractor[] extractors, Class<?> type, long maxMessageSize) { super(method, extractors, type, maxMessageSize); } @Override public MessageHandler create(final Session session) { return new AsyncMessageHandler() { @Override public void onMessage(Object partialMessage, boolean last) { checkMessageSize(partialMessage, getMaxMessageSize()); Object result = callMethod(method, extractors, session, true, partialMessage, last); if (result != null) { sendObject(session, result); } } @Override public Class<?> getType() { return getMessageType(type); } @Override public long getMaxMessageSize() { return maxMessageSize; } }; } } }