package org.async.rmi.server; import io.netty.channel.ChannelHandlerContext; import io.netty.util.AttributeKey; import org.async.rmi.*; import org.async.rmi.messages.CancelInvokeRequest; import org.async.rmi.messages.InvokeRequest; import org.async.rmi.messages.Response; import org.async.rmi.modules.Util; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.rmi.Remote; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; /** * Created by Barak Bar Orion * 28/10/14. */ public class ObjectRef { private static final Logger logger = LoggerFactory.getLogger(ObjectRef.class); private final Remote impl; private final Map<Long, OneWay> oneWayMap; private final String implClassName; private Map<Long, Method> methodIdToMethodMap; private final Set<Long> resultSetSet; private Map<Long, Trace> traceMap; private Map<String, CompletableFuture> inProgressCalls; public static final AttributeKey<ServerResultSetCallback> SERVER_RESULT_SET_CALLBACK_ATTRIBUTE_KEY = AttributeKey.valueOf("serverResultSetCallback"); public ObjectRef(Remote impl, Class[] remoteInterfaces, Map<Long, OneWay> oneWayMap, Set<Long> resultSetSet, Map<Long, Trace> traceMap, String implClassName) { this.impl = impl; this.oneWayMap = oneWayMap; this.resultSetSet = resultSetSet; this.traceMap = traceMap; this.methodIdToMethodMap = createMethodIdToMethodMap(remoteInterfaces); this.implClassName = implClassName; this.inProgressCalls = new ConcurrentHashMap<>(); } public void invoke(InvokeRequest invokeRequest, ChannelHandlerContext ctx) { Method method = methodIdToMethodMap.get(invokeRequest.getMethodId()); boolean isResultSet = resultSetSet.contains(invokeRequest.getMethodId()); OneWay oneWay = oneWayMap.get(invokeRequest.getMethodId()); if (method == null) { logger.error("Unknown method id {} in request {} of object ", invokeRequest.getMethodId(), invokeRequest, impl); if (oneWay == null) { writeResponse(ctx, new Response(invokeRequest.getRequestId() , null, invokeRequest.callDescription(), new IllegalArgumentException("Unknown method id " + invokeRequest.getMethodId() + " in request " + invokeRequest + " of object " + impl)) , invokeRequest); } return; } invokeRequest.setMethodName(method.getName()); invokeRequest.setImplClassName(implClassName); trace(invokeRequest, ctx); try { Object[] params = Modules.getInstance().getUtil().unMarshalParams(invokeRequest.getParams()); Object res; if(isResultSet){ invokeInThread(method, impl, params, ctx); return; }else{ res = method.invoke(impl, params); } if (oneWay != null) { return; } if (res instanceof Future) { final CompletableFuture<Object> completableFuture = toCompletableFuture((Future) res); inProgressCalls.put(invokeRequest.getUniqueId(), completableFuture); //noinspection unchecked completableFuture.whenComplete((o, e) -> { if(null == inProgressCalls.remove(invokeRequest.getUniqueId())){ // future was canceled by client. return; } if (o != null) { writeResponse(ctx, new Response(invokeRequest.getRequestId(), o, invokeRequest.callDescription()), invokeRequest); } else { writeResponse(ctx, new Response(invokeRequest.getRequestId(), null, invokeRequest.callDescription(), e), invokeRequest); } }); } else { writeResponse(ctx, new Response(invokeRequest.getRequestId(), res, invokeRequest.callDescription()), invokeRequest); } } catch (IllegalAccessException e) { logger.error("error while processing request {} object is {} method is {}", invokeRequest, impl, method.toGenericString(), e); writeResponse(ctx, new Response(invokeRequest.getRequestId(), null, invokeRequest.callDescription(), e), invokeRequest); } catch (InvocationTargetException e) { logger.error("error while processing request {} object is {} method is {}", invokeRequest, impl, method.toGenericString(), e); writeResponse(ctx, new Response(invokeRequest.getRequestId(), null, invokeRequest.callDescription(), e.getTargetException()), invokeRequest); } catch (Throwable e) { logger.error("error while processing request {} object is {} method is {}", invokeRequest, impl, method.toGenericString(), e); writeResponse(ctx, new Response(invokeRequest.getRequestId(), null, invokeRequest.callDescription(), e), invokeRequest); } } private Object invokeInThread(final Method method, final Remote impl, final Object[] params, final ChannelHandlerContext ctx) { new Thread(() -> { try { ServerResultSetCallback<Object> serverResultSetCallback = new ServerResultSetCallback<>(ctx); ResultSets.set(serverResultSetCallback); ctx.attr(SERVER_RESULT_SET_CALLBACK_ATTRIBUTE_KEY).set(serverResultSetCallback); method.invoke(impl, params); }catch(Exception ignored){ } }).start(); return null; } public void cancelRequest(CancelInvokeRequest request) { CompletableFuture future = inProgressCalls.remove(request.getUniqueId()); if(future != null) { future.cancel(request.isMayInterruptIfRunning()); } } private void trace(InvokeRequest invokeRequest, ChannelHandlerContext ctx) { Trace trace = traceMap.get(invokeRequest.getMethodId()); if (trace != null && trace.value() != TraceType.OFF) { if(trace.value() == TraceType.DETAILED) { logger.debug("{} <-- {} : {}", getTo(ctx), getFrom(ctx), invokeRequest.toDetailedString()); }else { logger.debug("{} <-- {} : {}", getTo(ctx), getFrom(ctx), invokeRequest); } } } private void trace(ChannelHandlerContext ctx, Response response, long methodId) { Trace trace = traceMap.get(methodId); if (trace != null && trace.value() != TraceType.OFF) { logger.debug("{} --> {} : {}", getFrom(ctx), getTo(ctx), response); } } private CompletableFuture<Object> toCompletableFuture(Future future) { if (future instanceof CompletableFuture) { //noinspection unchecked return (CompletableFuture<Object>) future; } else { CompletableFuture<Object> res = new CompletableFuture<>(); CompletableFuture.runAsync(() -> { try { //noinspection unchecked res.complete(future.get()); } catch (InterruptedException e) { res.completeExceptionally(e); Thread.currentThread().interrupt(); } catch (ExecutionException e) { res.completeExceptionally(e.getCause()); } }); return res; } } private void writeResponse(ChannelHandlerContext ctx, Response response, InvokeRequest invokeRequest) { trace(ctx, response, invokeRequest.getMethodId()); ctx.writeAndFlush(response); } private String getFrom(ChannelHandlerContext ctx) { return addressAsString((InetSocketAddress) ctx.channel().localAddress()); } private String addressAsString(InetSocketAddress socketAddress) { return socketAddress.getHostString() + ":" + socketAddress.getPort(); } private String getTo(ChannelHandlerContext ctx) { return addressAsString((InetSocketAddress) ctx.channel().remoteAddress()); } private Map<Long, Method> createMethodIdToMethodMap(Class[] remoteInterfaces) { Util util = Modules.getInstance().getUtil(); List<Method> sortedMethodList = util.getSortedMethodList(remoteInterfaces); Map<Long, Method> mapping = new HashMap<>(sortedMethodList.size()); for (Method method : sortedMethodList) { mapping.put(util.computeMethodHash(method), method); } return mapping; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ObjectRef objectRef = (ObjectRef) o; return impl.equals(objectRef.impl); } @Override public int hashCode() { return impl.hashCode(); } }