package org.async.rmi.client; import org.async.rmi.*; import org.async.rmi.messages.CancelInvokeRequest; import org.async.rmi.messages.Message; import org.async.rmi.messages.InvokeRequest; import org.async.rmi.messages.Response; import org.async.rmi.netty.NettyClientConnectionFactory; import org.async.rmi.pool.Pool; import org.async.rmi.pool.ShrinkableConnectionPool; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.lang.reflect.Method; import java.rmi.Remote; import java.rmi.RemoteException; import java.util.Arrays; import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicLong; /** * Created by Barak Bar Orion * 27/10/14. */ @SuppressWarnings({"UnusedDeclaration", "SpellCheckingInspection"}) public class UnicastRef implements RemoteRef { private static final Logger logger = LoggerFactory.getLogger(UnicastRef.class); private RemoteObjectAddress remoteObjectAddress; private Class[] remoteInterfaces; private static final AtomicLong nextRequestId = new AtomicLong(0); private Pool<Connection<Message>> pool; private long objectid; private Map<Long, Trace> traceMap; private String callDescription; private UUID clientId; private RequestQueue requestQueue; public UnicastRef() { clientId = UUID.randomUUID(); requestQueue = new RequestQueue(this); } public UnicastRef(RemoteObjectAddress remoteObjectAddress, Class[] remoteInterfaces, long objectid , Map<Long, Trace> traceMap, String callDescription) { this.remoteObjectAddress = remoteObjectAddress; this.remoteInterfaces = remoteInterfaces; this.objectid = objectid; this.traceMap = traceMap; this.callDescription = callDescription; } public long getObjectid() { return objectid; } @Override public Object invoke(Remote obj, Method method, Object[] params, long opHash, OneWay oneWay, boolean isResultSet) throws Throwable { Modules.getInstance().getTransport().startClassLoaderServer(Thread.currentThread().getContextClassLoader()); MarshalledObject [] marshalledParams = Modules.getInstance().getUtil().marshalParams(params); final InvokeRequest invokeRequest = new InvokeRequest(nextRequestId.getAndIncrement() , remoteObjectAddress.getObjectId(), opHash, oneWay != null , marshalledParams, method.getName(), callDescription); final CompletableFuture<Object> result = new ClientCompletableFuture<>(mayInterruptIfRunning -> send(new CancelInvokeRequest(invokeRequest, mayInterruptIfRunning), null, false, null)); SendResult sendResult = send(invokeRequest, oneWay, isResultSet, result); CompletableFuture<Response> future = sendResult != null ? sendResult.responseFuture : null; if (future != null && oneWay == null && Future.class.isAssignableFrom(method.getReturnType())) { //noinspection unchecked future.handle((response, throwable) -> { if (null != throwable) { result.completeExceptionally(throwable); } else if (response.isError()) { result.completeExceptionally(response.getError()); } else { result.complete(response.getResult()); } return null; }); return result; } else if (oneWay != null) { if (future != null && Future.class.isAssignableFrom(method.getReturnType())) { //noinspection unchecked future.handle((response, throwable) -> { if (null != throwable) { result.completeExceptionally(throwable); } else { result.complete(null); } return null; }); return result; } else if (oneWay.full()) { return null; } else { return getResponseResult(translateClientError(future)); } } else if(isResultSet && (sendResult != null)){ ClientResultSet clientResultSet = new ClientResultSet(sendResult.connectionFuture); sendResult.connectionFuture.get().attach(clientResultSet); clientResultSet.readyFuture().get(); return clientResultSet; } else { return getResponseResult(translateClientError(future)); } } @Override public void close() throws IOException { if(pool != null) { pool.close(); } } public synchronized void redirect(long objectId, String host, int port) { RemoteObjectAddress redirectedAddress = new RemoteObjectAddress("rmi://" + host + ":" + port, objectId); logger.info("redirecting client from {} to {}", remoteObjectAddress, redirectedAddress); this.objectid = objectId; remoteObjectAddress = redirectedAddress; Pool<Connection<Message>> oldPool = pool; pool = createPool(); if(oldPool != null) { try { oldPool.close(); }catch(Exception e){ logger.error(e.toString(), e); } } } private SendResult send(InvokeRequest invokeRequest, OneWay oneWay, boolean isResultSet, CompletableFuture<Object> result) { if(!requestQueue.add(invokeRequest, oneWay, result)){ return null; } final CompletableFuture<Response> responseFuture = new CompletableFuture<>(); if(!isResultSet) { Modules.getInstance().getTransport().addResponseFuture(invokeRequest, responseFuture, traceMap.get(invokeRequest.getMethodId())); } CompletableFuture<Connection<Message>> connectionFuture = pool.get(); connectionFuture.whenComplete((connection, throwable) -> { requestQueue.processRequest(connection, throwable, responseFuture); if(!isResultSet) { pool.free(connection); } }); return new SendResult(responseFuture, connectionFuture); } class SendResult{ public SendResult(CompletableFuture<Response> responseFuture, CompletableFuture<Connection<Message>> connectionFuture) { this.responseFuture = responseFuture; this.connectionFuture = connectionFuture; } CompletableFuture<Response> responseFuture; CompletableFuture<Connection<Message>> connectionFuture; } void trace(InvokeRequest invokeRequest, Connection<Message> connection) { Trace trace = traceMap.get(invokeRequest.getMethodId()); if(trace != null && trace.value() != TraceType.OFF) { if(trace.value() == TraceType.DETAILED){ logger.debug("{} --> {} : {}", connection.getLocalAddress(), connection.getRemoteAddress(), invokeRequest.toDetailedString()); }else{ logger.debug("{} --> {} : {}", connection.getLocalAddress(), connection.getRemoteAddress(), invokeRequest); } } } public <T> T translateClientError(Future<T> future) throws Throwable { try { return future.get(); } catch (Exception e) { if (e instanceof InterruptedException) { throw e; } else if (e instanceof ExecutionException) { throw e.getCause(); } else { throw e; } } } private Object getResponseResult(Response response) throws Throwable { if (response == null) { return null; } if (response.isError()) { //noinspection ThrowableResultOfMethodCallIgnored Throwable t = response.getError(); return translateServerError(t); } else { return response.getResult(); } } private Object translateServerError(Throwable t) throws Throwable { if (t instanceof RemoteException) { throw t; } else { throw new RemoteException(t.toString() + " from server", t); } } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(remoteObjectAddress); out.writeObject(remoteInterfaces); out.writeObject(traceMap); out.writeLong(objectid); out.writeUTF(callDescription); } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { remoteObjectAddress = (RemoteObjectAddress) in.readObject(); remoteInterfaces = (Class[]) in.readObject(); //noinspection unchecked traceMap = (Map<Long, Trace>) in.readObject(); objectid = in.readLong(); callDescription = in.readUTF(); pool = createPool(); } @Override protected void finalize() throws Throwable { super.finalize(); close(); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; UnicastRef that = (UnicastRef) o; return remoteObjectAddress.equals(that.remoteObjectAddress); } @Override public int hashCode() { return remoteObjectAddress.hashCode(); } private Pool<Connection<Message>> createPool() { pool = new ShrinkableConnectionPool(2); NettyClientConnectionFactory factory = new NettyClientConnectionFactory(Modules.getInstance().getTransport().getClientEventLoopGroup(), remoteObjectAddress, clientId); factory.setPool(pool); pool.setFactory(factory); return pool; } @Override public String toString() { return "UnicastRef{" + "remoteObjectAddress=" + remoteObjectAddress + ", remoteInterfaces=" + Arrays.toString(remoteInterfaces) + '}'; } }