/* * Copyright 2016 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.client.thrift; import static com.linecorp.armeria.common.util.Functions.voidFunction; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.UndeclaredThrowableException; import java.net.URI; import java.util.Arrays; import java.util.concurrent.ExecutionException; import org.apache.thrift.async.AsyncMethodCallback; import com.linecorp.armeria.client.ClientBuilderParams; import com.linecorp.armeria.client.ClientFactory; import com.linecorp.armeria.client.ClientOptions; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.util.CompletionActions; final class THttpClientInvocationHandler implements InvocationHandler, ClientBuilderParams { private static final Object[] NO_ARGS = new Object[0]; private final ClientBuilderParams params; private final THttpClient thriftClient; private final String path; private final String fragment; THttpClientInvocationHandler(ClientBuilderParams params, THttpClient thriftClient, String path, String fragment) { this.params = params; this.thriftClient = thriftClient; this.path = path; this.fragment = fragment; } @Override public ClientFactory factory() { return params.factory(); } @Override public URI uri() { return params.uri(); } @Override public Class<?> clientType() { return params.clientType(); } @Override public ClientOptions options() { return params.options(); } @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { final Class<?> declaringClass = method.getDeclaringClass(); if (declaringClass == Object.class) { // Handle the methods in Object return invokeObjectMethod(proxy, method, args); } assert declaringClass == params.clientType(); // Handle the methods in the interface. return invokeClientMethod(method, args); } private Object invokeObjectMethod(Object proxy, Method method, Object[] args) { final String methodName = method.getName(); switch (methodName) { case "toString": return params.clientType().getSimpleName() + '(' + path + ')'; case "hashCode": return System.identityHashCode(proxy); case "equals": return proxy == args[0]; default: throw new Error("unknown method: " + methodName); } } private Object invokeClientMethod(Method method, Object[] args) throws Throwable { final AsyncMethodCallback<Object> callback; if (args == null) { args = NO_ARGS; callback = null; } else { final int lastIdx = args.length - 1; if (args.length > 0 && args[lastIdx] instanceof AsyncMethodCallback) { @SuppressWarnings("unchecked") final AsyncMethodCallback<Object> lastArg = (AsyncMethodCallback<Object>) args[lastIdx]; callback = lastArg; args = Arrays.copyOfRange(args, 0, lastIdx); } else { callback = null; } } try { final RpcResponse reply = thriftClient.executeMultiplexed( path, params.clientType(), fragment, method.getName(), args); if (callback != null) { reply.handle(voidFunction((result, cause) -> { if (cause == null) { callback.onComplete(result); } else { invokeOnError(callback, cause); } })).exceptionally(CompletionActions::log); return null; } else { try { return reply.get(); } catch (ExecutionException e) { throw e.getCause(); } } } catch (Throwable cause) { if (callback != null) { invokeOnError(callback, cause); return null; } else { throw cause; } } } private static void invokeOnError(AsyncMethodCallback<Object> callback, Throwable cause) { callback.onError(cause instanceof Exception ? (Exception) cause : new UndeclaredThrowableException(cause)); } }